{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | Utilities for working with the FFI: type-level sized pointers and
-- bindings to C memory functions.
module Cardano.Foreign (
    -- * Sized pointer
    SizedPtr (..),
    allocaSized,
    memcpySized,
    memsetSized,
    -- * Low-level C functions
    c_memcpy,
    c_memset,
) where

import Control.Monad (void)
import Data.Void (Void)
import Data.Word (Word8)
import Data.Proxy (Proxy (..))
import Foreign.Ptr (Ptr)
import Foreign.C.Types (CSize (..))
import Foreign.Marshal.Alloc (allocaBytes)
import GHC.TypeLits

-------------------------------------------------------------------------------
-- Sized pointer
-------------------------------------------------------------------------------

-- | A pointer which knows the size of the underlying memory block.
--
-- The size in bytes is carried at the type level as @n@, and the helpers in
-- this module recover it with 'natVal'. The constructor is exported, so
-- nothing checks that the wrapped 'Ptr' really points at @n@ usable bytes;
-- that is the responsibility of whoever builds the 'SizedPtr'.
newtype SizedPtr (n :: Nat) = SizedPtr (Ptr Void)

-- | Like 'allocaBytes', but the number of bytes to allocate comes from the
-- type-level size @n@. The continuation receives the block as a 'SizedPtr'.
--
-- As with 'allocaBytes', the memory is valid only while the continuation
-- runs. The pointer must not be used after the continuation returns.
allocaSized :: forall n b. KnownNat n => (SizedPtr n -> IO b) -> IO b
allocaSized k = allocaBytes size (k . SizedPtr)
  where
    -- Byte count taken from the type-level natural @n@.
    size :: Int
    size = fromInteger (natVal (Proxy @n))

-- | Copy all @n@ bytes from the second block (source) into the first block
-- (destination), using C @memcpy@.
--
-- Argument order follows @memcpy@: destination first, source second.
-- The two blocks must not overlap, because @memcpy@ on overlapping memory is
-- undefined behaviour in C.
memcpySized :: forall n. KnownNat n => SizedPtr n -> SizedPtr n -> IO ()
memcpySized (SizedPtr dest) (SizedPtr src) = void (c_memcpy dest src size)
  where
    -- Byte count taken from the type-level natural @n@.
    size :: CSize
    size = fromInteger (natVal (Proxy @n))

-- | Fill all @n@ bytes of the block with the given byte value, using C
-- @memset@.
--
-- For guaranteed zeroing of memory, prefer @c_sodium_memzero@; see the note
-- on 'c_memset'.
memsetSized :: forall n. KnownNat n => SizedPtr n -> Word8 -> IO ()
memsetSized (SizedPtr s) c = void (c_memset s (fromIntegral c) size)
  where
    -- Byte count taken from the type-level natural @n@.
    size :: CSize
    size = fromInteger (natVal (Proxy @n))

-------------------------------------------------------------------------------
-- Some C functions
-------------------------------------------------------------------------------

-- | @void *memcpy(void *dest, const void *src, size_t n);@
--
-- Copies @n@ bytes from the second argument (@src@) to the first argument
-- (@dest@) and returns @dest@. The regions must not overlap.
--
-- Note: this is a safe foreign import (neither @safe@ nor @unsafe@ is
-- specified, so the default @safe@ applies).
foreign import ccall "memcpy"
    c_memcpy :: Ptr a -> Ptr a -> CSize -> IO (Ptr ())

-- | @void *memset(void *s, int c, size_t n);@
--
-- Fills the first @n@ bytes pointed to by @s@ with the byte value @c@ and
-- returns @s@.
--
-- Note: for guaranteed zeroing of memory, use @c_sodium_memzero@ instead.
--
-- NOTE(review): the C prototype takes an @int@, but this binding uses the
-- Haskell 'Int' type rather than 'Foreign.C.Types.CInt'. On 64-bit platforms
-- 'Int' is wider than C @int@. Switching to 'CInt' would change this
-- exported type, so confirm with callers before changing it.
foreign import ccall "memset"
    c_memset :: Ptr a -> Int -> CSize -> IO (Ptr ())