{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Cardano.Crypto.Libsodium.Memory.Internal (
  -- * High-level memory management
  MLockedForeignPtr (..),
  withMLockedForeignPtr,
  allocMLockedForeignPtr,
  finalizeMLockedForeignPtr,
  traceMLockedForeignPtr,
  -- * Low-level memory functions
  mlockedAlloca,
  mlockedAllocaSized,
  sodiumMalloc,
  sodiumFree,
) where

import Control.Exception (bracket, mask_)
import Control.Monad (when)
import Data.Coerce (coerce)
import Data.Proxy (Proxy (..))
import Foreign.C.Error (errnoToIOError, getErrno)
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (ForeignPtr, newForeignPtr, withForeignPtr, finalizeForeignPtr)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (Storable (alignment, sizeOf, peek))
import GHC.IO.Exception (ioException)
import GHC.TypeLits (KnownNat, natVal)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))

import Cardano.Crypto.Libsodium.C
import Cardano.Foreign

-- | Foreign pointer to securely allocated memory.
--
-- A thin @newtype@ around 'ForeignPtr'; values built by
-- 'allocMLockedForeignPtr' point at memory obtained from 'sodiumMalloc' and
-- carry 'c_sodium_free_funptr' as their finalizer.
--
-- The 'NoThunks' instance is borrowed from 'OnlyCheckWhnfNamed', reported
-- under the name @"MLockedForeignPtr"@.
newtype MLockedForeignPtr a = SFP
  { _unwrapMLockedForeignPtr :: ForeignPtr a
  }
  deriving NoThunks via OnlyCheckWhnfNamed "MLockedForeignPtr" (MLockedForeignPtr a)

-- | Run an action on the raw 'Ptr' behind an 'MLockedForeignPtr'.
--
-- This is 'withForeignPtr' with the newtype wrapper coerced away, so it has
-- exactly the same semantics as 'withForeignPtr' on the wrapped pointer.
withMLockedForeignPtr :: forall a b. MLockedForeignPtr a -> (Ptr a -> IO b) -> IO b
withMLockedForeignPtr = coerce onForeignPtr
  where
    -- 'withForeignPtr' pinned to the element and result types of the signature.
    onForeignPtr :: ForeignPtr a -> (Ptr a -> IO b) -> IO b
    onForeignPtr = withForeignPtr

-- | Run the finalizers of the wrapped 'ForeignPtr' right away.
--
-- This is 'finalizeForeignPtr' with the newtype wrapper coerced away. For
-- pointers created by 'allocMLockedForeignPtr' the registered finalizer is
-- 'c_sodium_free_funptr'.
finalizeMLockedForeignPtr :: forall a. MLockedForeignPtr a -> IO ()
finalizeMLockedForeignPtr = coerce runFinalizers
  where
    -- 'finalizeForeignPtr' pinned to the element type of the signature.
    runFinalizers :: ForeignPtr a -> IO ()
    runFinalizers = finalizeForeignPtr

-- | Debugging aid: 'peek' the value stored behind the pointer and 'print' it.
--
-- The stored value is written to standard output, which is why this function
-- carries a @DEPRECATED@ pragma warning against leaving it in production code.
traceMLockedForeignPtr :: (Storable a, Show a) => MLockedForeignPtr a -> IO ()
traceMLockedForeignPtr fptr =
    withMLockedForeignPtr fptr (\ptr -> peek ptr >>= print)

{-# DEPRECATED traceMLockedForeignPtr "Don't leave traceMLockedForeignPtr in production" #-}

-- | Allocate secure memory for one value of type @a@ using 'sodiumMalloc'
-- (a checked wrapper around 'c_sodium_malloc').
--
-- <https://libsodium.gitbook.io/doc/memory_management>
--
-- The number of bytes requested is @sizeOf@ of the element type rounded up to
-- the next multiple of its @alignment@. The resulting pointer has
-- 'c_sodium_free_funptr' registered as its finalizer, so the memory is
-- released by 'finalizeMLockedForeignPtr' or when the 'ForeignPtr' is
-- finalized by the runtime.
--
-- Allocation and finalizer registration run with asynchronous exceptions
-- masked: otherwise an exception arriving between the two steps would leave
-- the freshly allocated block without any finalizer and leak it.
--
-- Throws an 'IOError' (raised by 'sodiumMalloc') when the allocation fails.
allocMLockedForeignPtr :: Storable a => IO (MLockedForeignPtr a)
allocMLockedForeignPtr = impl undefined
  where
    -- The argument is only a type witness: the 'Storable' contract says
    -- 'sizeOf' and 'alignment' do not use the value of their argument, so the
    -- 'undefined' passed above is never forced.
    impl :: forall b. Storable b => b -> IO (MLockedForeignPtr b)
    impl witness = mask_ $ do
        ptr <- sodiumMalloc allocSize
        SFP <$> newForeignPtr c_sodium_free_funptr ptr
      where
        rawSize :: Int
        rawSize = sizeOf witness

        align :: Int
        align = alignment witness

        -- Smallest multiple of 'align' that is >= 'rawSize'.
        -- NOTE(review): presumably needed because sodium_malloc only returns
        -- suitably aligned addresses when the size is a multiple of the
        -- alignment -- confirm against the libsodium documentation.
        paddedSize :: Int
        paddedSize = case rawSize `divMod` align of
            (_, 0) -> rawSize
            (q, _) -> (q + 1) * align

        allocSize :: CSize
        allocSize = fromIntegral paddedSize

-- | Checked wrapper around 'c_sodium_malloc'.
--
-- Returns the pointer produced by 'c_sodium_malloc' for the given size. If
-- that pointer is 'nullPtr', the current @errno@ is read and turned into an
-- 'IOError' tagged with the location string @"c_sodium_malloc"@, which is
-- then thrown.
sodiumMalloc :: CSize -> IO (Ptr a)
sodiumMalloc size = do
    ptr <- c_sodium_malloc size
    when (ptr == nullPtr) failWithErrno
    pure ptr
  where
    -- Convert the current errno into an IOError and throw it.
    failWithErrno :: IO ()
    failWithErrno =
        getErrno >>= \errno ->
            ioException (errnoToIOError "c_sodium_malloc" errno Nothing Nothing)

-- | Release a pointer by handing it to 'c_sodium_free'.
--
-- Used by 'mlockedAlloca' as the release action for memory obtained from
-- 'sodiumMalloc'.
sodiumFree :: Ptr a -> IO ()
sodiumFree ptr = c_sodium_free ptr

-- | Allocate @size@ bytes with 'sodiumMalloc', run the action on the
-- resulting pointer, and release the memory with 'sodiumFree' afterwards.
--
-- Built on 'bracket', so the release step also runs when the action throws.
-- The pointer must not be used after the action returns.
mlockedAlloca :: forall a b. CSize -> (Ptr a -> IO b) -> IO b
mlockedAlloca size action = bracket acquire sodiumFree action
  where
    acquire :: IO (Ptr a)
    acquire = sodiumMalloc size

-- | Variant of 'mlockedAlloca' whose allocation size is the type-level
-- natural @n@.
--
-- The size passed to 'mlockedAlloca' is @natVal@ of @n@ converted to 'CSize',
-- and the raw pointer is wrapped in 'SizedPtr' before being handed to the
-- continuation.
mlockedAllocaSized :: forall n b. KnownNat n => (SizedPtr n -> IO b) -> IO b
mlockedAllocaSized k =
    mlockedAlloca byteCount (\rawPtr -> k (SizedPtr rawPtr))
  where
    byteCount :: CSize
    byteCount = fromInteger (natVal (Proxy :: Proxy n))