{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples       #-}
{-# LANGUAGE UnliftedFFITypes    #-}

module Cardano.Prelude.GHC.Heap.Size (
    CountFailure(..)
  , PerformGC(..)
  , computeHeapSize
  , computeHeapSize'
  , computeHeapSizeWorkList
  ) where

import Cardano.Prelude.Base hiding (Any)

import Foreign.C.Types
import Foreign.Marshal.Alloc
import Foreign.StablePtr
import Foreign.Storable
import GHC.Exts.Heap.ClosureTypes (ClosureType)
import GHC.Prim
import GHC.Types
import System.Mem (performMajorGC)

{-------------------------------------------------------------------------------
  Failure
-------------------------------------------------------------------------------}

-- | Status codes read back from the error out-parameter of the C counting
-- routine and decoded by 'toCountFailure'.
--
-- Any code at or above 'cUNSUPPORTED_CLOSURE' denotes an unsupported closure,
-- with the closure type encoded as the offset past 'cUNSUPPORTED_CLOSURE'.
--
-- NOTE(review): these presumably mirror constants in the C implementation of
-- @hs_cardanoprelude_closureSize@; keep the two in sync.
cNO_FAILURE, cWORK_LIST_FULL, cVISITED_FULL, cOUT_OF_MEMORY, cUNSUPPORTED_CLOSURE :: CUInt
cNO_FAILURE          = 0
cWORK_LIST_FULL      = 1
cVISITED_FULL        = 2
cOUT_OF_MEMORY       = 3
cUNSUPPORTED_CLOSURE = 4

-- | Reasons why counting the heap size of a closure can fail.
data CountFailure
  = WorkListFull
    -- ^ The worklist reached its capacity (the worklist argument of
    -- 'computeHeapSize'').
  | VisitedFull
    -- ^ The visited set reached its maximum capacity.
  | OutOfMemory
    -- ^ The C side reported running out of memory (NOTE(review): presumably
    -- an allocation failure in the C code — confirm).
  | UnsupportedClosure ClosureType
    -- ^ A closure of this type was encountered and could not be measured.
  deriving (Eq, Show)

-- | Decode a status code written by the C counting routine.
--
-- 'Nothing' means success. Every code at or above 'cUNSUPPORTED_CLOSURE'
-- denotes an unsupported closure; its 'ClosureType' is recovered from the
-- offset past 'cUNSUPPORTED_CLOSURE'. Note that 'toEnum' is partial, so an
-- offset outside the range of 'ClosureType' will throw.
toCountFailure :: CUInt -> Maybe CountFailure
toCountFailure code
  | code >= cUNSUPPORTED_CLOSURE = Just (UnsupportedClosure offendingType)
  | code == cNO_FAILURE          = Nothing
  | code == cWORK_LIST_FULL      = Just WorkListFull
  | code == cVISITED_FULL        = Just VisitedFull
  | code == cOUT_OF_MEMORY       = Just OutOfMemory
  | otherwise                    = panic "getCountFailure: impossible"
  where
    -- Only forced in the unsupported-closure branch.
    offendingType :: ClosureType
    offendingType = toEnum (fromIntegral (code - cUNSUPPORTED_CLOSURE))

{-------------------------------------------------------------------------------
  Main API
-------------------------------------------------------------------------------}

-- | Binding to the C function that counts the closure size.
--
-- Arguments, in the order 'closureSize' passes them: the worklist capacity,
-- the initial capacity of the visited set, the maximum capacity of the visited
-- set, a pointer through which a status code is reported (decoded by
-- 'toCountFailure'), and a stable pointer to the closure to measure. The
-- result is the computed size (NOTE(review): the unit — presumably bytes — is
-- determined by the C side; confirm against the C implementation).
--
-- The import must be marked @unsafe@: GHC guarantees that no garbage
-- collection occurs during an unsafe foreign call
-- (see <https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/ffi-chap.html#guaranteed-call-safety>).
-- This is essential for reliable counting. Haskell's GC moves objects around,
-- so if a GC occurred during counting, we might count objects more than once,
-- or miss them entirely if they were moved to a location we had already marked
-- as visited. This is also why the size is counted in C rather than in Haskell.
foreign import ccall unsafe "hs_cardanoprelude_closureSize"
  closureSize_ :: CUInt -> CUInt -> CUInt -> Ptr CUInt -> StablePtr a -> IO CULong

-- | Should we perform a GC call before counting the size?
data PerformGC =
    -- | Yes, first perform GC before counting
    --
    -- This should be used for the most accurate results. Without calling GC
    -- first, the computed size might be larger than expected due to leftover
    -- indirections (black holes, selector thunks, etc.)
    FirstPerformGC

    -- | No, do not perform GC before counting
    --
    -- If pinpoint accuracy is not required, then GC can be skipped, making the
    -- call much less expensive.
  | DontPerformGC
  -- Derived for consistency with 'CountFailure', so callers can log and
  -- compare the chosen mode.
  deriving (Show, Eq)

-- | Wrapper around 'closureSize_' that optionally runs a major GC and then
-- hands the closure to C via a 'StablePtr'.
--
-- Passing the closure's raw address would be unsound: nothing prevents a GC
-- from moving the object between the moment we take its address and the
-- moment the C call starts. A stable pointer stays valid across GCs, and
-- 'bracket' frees it again afterwards, including on the exception path.
closureSize :: PerformGC -> CUInt -> CUInt -> CUInt -> Ptr CUInt -> a -> IO CULong
closureSize performGC workListCapacity visitedInitCapacity visitedMaxCapacity err a = do
    maybeCollect performGC
    bracket (newStablePtr a) freeStablePtr measure
  where
    -- Honour the caller's choice about running a major GC first.
    maybeCollect :: PerformGC -> IO ()
    maybeCollect FirstPerformGC = performMajorGC
    maybeCollect DontPerformGC  = return ()

    -- Forward all limits and the status slot unchanged to the C routine.
    measure sp =
      closureSize_ workListCapacity visitedInitCapacity visitedMaxCapacity err sp

-- | Compute the size of the given closure
--
-- The size of the worklist should be set to the maximum expected /depth/ of
-- the closure; the size of the visited set should be set to the maximum /number
-- of nodes/ in the closure.
--
-- 'computeHeapSizeWorkList' can be used to estimate the size of the worklist
-- required.
--
-- The capacities are handed to C as 'CUInt'. Values that do not fit are
-- clamped to @maxBound :: CUInt@ rather than being truncated modulo the
-- 'CUInt' range (which would silently turn e.g. a capacity of 2^32 into 0).
computeHeapSize' :: PerformGC -- ^ Should we call GC before counting?
                 -> Word      -- ^ Capacity of the worklist
                 -> Word      -- ^ Initial capacity of the visited set
                 -> Word      -- ^ Maximum capacity of the visited set
                 -> a         -- ^ The closure to measure
                 -> IO (Either CountFailure Word64)
computeHeapSize' performGC
                 workListCapacity
                 visitedInitCapacity
                 visitedMaxCapacity
                 a
               = alloca $ \(err :: Ptr CUInt) -> do
      -- 'alloca' does not initialise memory. Start from "no failure" so that we
      -- never decode garbage, whether or not the C code writes the slot on
      -- every path (that cannot be verified from here).
      poke err cNO_FAILURE
      size     <- closureSize performGC
                              (toCUInt workListCapacity)
                              (toCUInt visitedInitCapacity)
                              (toCUInt visitedMaxCapacity)
                              err
                              a
      mFailure <- toCountFailure <$> peek err
      return $ case mFailure of
                 Just failure -> Left failure
                 Nothing      -> Right (fromIntegral size)
  where
    -- Saturating 'Word' -> 'CUInt' conversion; plain 'fromIntegral' wraps.
    toCUInt :: Word -> CUInt
    toCUInt w = fromIntegral (min w (fromIntegral (maxBound :: CUInt)))

-- | Compute the size of the given closure
--
-- This is a wrapper around 'computeHeapSize'' with default capacities for the
-- worklist and the visited set: a worklist capacity of 10k (which, assuming
-- balanced data structures, should be more than enough), an initial visited
-- set capacity of 250k, and a maximum visited set capacity of 16M. With 8-byte
-- pointers this uses between 2 MB and 128 MB of heap space.
--
-- It does NOT perform GC before counting, for improved performance. Client
-- code can call 'performMajorGC' manually or use 'computeHeapSize''.
--
-- If these limits are insufficient, or the memory requirements are too large,
-- use 'computeHeapSize'' directly.
computeHeapSize :: a -> IO (Either CountFailure Word64)
computeHeapSize closure =
    computeHeapSize' DontPerformGC defaultWorkList defaultVisitedInit defaultVisitedMax closure
  where
    thousand, million :: Word
    thousand = 1000
    million  = thousand * thousand

    -- Memory estimates assume 64-bit (i.e. 8 byte) pointers
    defaultWorkList, defaultVisitedInit, defaultVisitedMax :: Word
    defaultWorkList    = 10  * thousand --  80 kB
    defaultVisitedInit = 250 * thousand --   2 MB
    defaultVisitedMax  = 16  * million  -- 128 MB

{-------------------------------------------------------------------------------
  Compute the depth of the closure
-------------------------------------------------------------------------------}

-- | Upper bound on the required work list size to compute closure size
--
-- NOTE: This ignores sharing, and so provides an upper bound only. It also
-- follows every pointer without remembering what it has already seen, so it
-- does not terminate on cyclic structures.
--
-- The size of a closure with no nested pointers can be computed without any
-- stack space.
--
-- When we have a closure with @(N + 1)@ nested pointers
--
-- > p0 p1 .. pN
--
-- We will
--
-- * Push @pN, .., p1, p0@ onto the stack
-- * Pop off @p0@ and count its children
-- * Pop off @p1@ and count its children
-- * ..
--
-- until we have processed all children. This means that the stack space
-- required will be the maximum of
--
-- > [ N + 1 -- For the initial list
-- > , requiredWorkList p0 + (N + 1) - 1
-- > , requiredWorkList p1 + (N + 1) - 2
-- > , ..
-- > , requiredWorkList pN + (N + 1) - (N + 1)
-- > ]
--
-- For example, for a list, we would get that
--
-- > requiredWorkList []     == 0
-- > requiredWorkList (x:xs) == max [ 2
-- >                                , requiredWorkList x + 1
-- >                                , requiredWorkList xs
-- >                                ]
--
-- which, for a list of @Int@ (which requires only a stack of size 1), equals 2
-- (unless the list is empty).
--
-- Similarly, for binary trees, we get
--
-- > requiredWorkList Leaf           == 0
-- > requiredWorkList (Branch l x r) == max [ 3
-- >                                        , requiredWorkList l + 2
-- >                                        , requiredWorkList x + 1
-- >                                        , requiredWorkList r
-- >                                        ]
--
-- which, for a tree of @Int@, is bound by @(height * 2) + 1@.
computeHeapSizeWorkList :: a -> Word64
computeHeapSizeWorkList a =
    maximum (fromIntegral width : map viaChild [0 .. width - 1])
  where
    -- The pointer fields of the closure, as seen by the RTS.
    ptrs :: Array# Any
    !(# _addr, _raw, ptrs #) = unpackClosure# a

    -- Number of pointer fields, i.e. @N + 1@ in the explanation above.
    width :: Int
    width = I# (sizeofArray# ptrs)

    -- Requirement while processing child @ix@: its own work list plus the
    -- @(N + 1) - (ix + 1)@ siblings still waiting on the stack.
    viaChild :: Int -> Word64
    viaChild ix@(I# ix#) =
      case indexArray# ptrs ix# of
        (# child #) ->
          computeHeapSizeWorkList child + fromIntegral (width - (ix + 1))