{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Cardano.Crypto.Libsodium.MLockedBytes.Internal (
    MLockedSizedBytes (..),
    mlsbZero,
    mlsbFromByteString,
    mlsbFromByteStringCheck,
    mlsbToByteString,
    mlsbUseAsCPtr,
    mlsbUseAsSizedPtr,
    mlsbFinalize,
) where

import Control.DeepSeq (NFData (..))
import Data.Proxy (Proxy (..))
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (castForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import GHC.TypeLits (KnownNat, natVal)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))
import System.IO.Unsafe (unsafeDupablePerformIO)
import Data.Word (Word8)

import Cardano.Foreign
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.PinnedSizedBytes

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI

{- HLINT ignore "Reduce duplication" -}

-- | A buffer of exactly @n@ bytes, referenced through an 'MLockedForeignPtr'
-- to a 'PinnedSizedBytes' of the same size.
--
-- The 'NoThunks' instance is obtained via 'OnlyCheckWhnfNamed', i.e. only
-- the wrapper itself is checked to weak head normal form; the bytes behind
-- the pointer are not inspected.
newtype MLockedSizedBytes n = MLSB (MLockedForeignPtr (PinnedSizedBytes n))
  deriving (NoThunks) via OnlyCheckWhnfNamed "MLockedSizedBytes" (MLockedSizedBytes n)

-- | Equality is derived from the 'Ord' instance: two buffers are equal
-- exactly when 'compare' reports 'EQ'.
instance KnownNat n => Eq (MLockedSizedBytes n) where
    a == b = case compare a b of
        EQ -> True
        _  -> False

-- | Ordering is computed by @c_sodium_compare@ over all @n@ bytes of both
-- buffers, and its integer result is mapped to an 'Ordering' by comparing
-- it against zero.
--
-- NOTE(review): libsodium documents @sodium_compare@ as a constant-time
-- comparison that treats the buffers as little-endian numbers, so this order
-- is presumably not lexicographic byte order — confirm against the binding.
instance KnownNat n => Ord (MLockedSizedBytes n) where
    compare (MLSB fa) (MLSB fb) =
        unsafeDupablePerformIO $
        withMLockedForeignPtr fa $ \pa ->
        withMLockedForeignPtr fb $ \pb ->
            (`compare` 0) <$> c_sodium_compare pa pb byteCount
      where
        -- Number of bytes to compare: the full type-level size @n@.
        byteCount :: CSize
        byteCount = fromInteger (natVal (Proxy @n))

-- | Shows only the type-level size, never the buffer contents
-- (the value argument is ignored), e.g. @_ :: MLockedSizedBytes 32@.
instance KnownNat n => Show (MLockedSizedBytes n) where
    showsPrec prec _ =
        showParen (prec > 10) (showString "_ :: MLockedSizedBytes " . showsPrec 11 len)
      where
        len :: Integer
        len = natVal (Proxy @n)

-- | Forces only the wrapped 'MLockedForeignPtr' to weak head normal form;
-- the bytes it points to are not traversed.
instance NFData (MLockedSizedBytes n) where
    rnf (MLSB fptr) = fptr `seq` ()

-- | A buffer of @n@ zero bytes.
--
-- Note: this value doesn't strictly need mlocked memory, but it is
-- allocated there anyway for consistency with the other constructors.
mlsbZero :: forall n. KnownNat n => MLockedSizedBytes n
mlsbZero = unsafeDupablePerformIO $ do
    fptr <- allocMLockedForeignPtr
    -- Overwrite all @n@ bytes with 0; the returned pointer is discarded.
    withMLockedForeignPtr fptr $ \ptr ->
        () <$ c_memset (castPtr ptr) 0 byteCount
    pure (MLSB fptr)
  where
    byteCount :: CSize
    byteCount = fromInteger (natVal (Proxy @n))

-- | Copy a 'BS.ByteString' into a freshly allocated mlocked buffer of @n@
-- bytes.
--
-- * If the input is longer than @n@ bytes, only its first @n@ bytes are
--   copied (the rest is silently dropped).
-- * If the input is shorter than @n@ bytes, the remaining bytes of the
--   buffer are set to zero, so the result is fully determined by the input.
--
-- The input 'BS.ByteString' is only read; it is not modified or wiped.
mlsbFromByteString :: forall n. KnownNat n => BS.ByteString -> MLockedSizedBytes n
mlsbFromByteString bs = unsafeDupablePerformIO $ BS.useAsCStringLen bs $ \(ptrBS, len) -> do
    fptr <- allocMLockedForeignPtr
    withMLockedForeignPtr fptr $ \ptr -> do
        -- Pad short inputs with zeros. Without this, the bytes past the
        -- copied prefix would keep whatever the allocator left there, and
        -- this pure function (as well as 'Eq'/'Ord', which compare all @n@
        -- bytes) would observe that memory.
        if len < size
            then () <$ c_memset (castPtr ptr) 0 (fromIntegral size)
            else return ()
        -- Copy at most @n@ bytes; longer inputs are truncated.
        _ <- c_memcpy (castPtr ptr) ptrBS (fromIntegral (min len size))
        return ()
    return (MLSB fptr)
  where
    size :: Int
    size = fromInteger (natVal (Proxy @n))

-- | Like 'mlsbFromByteString', but returns 'Nothing' unless the input is
-- exactly @n@ bytes long, so no truncation or padding can happen.
mlsbFromByteStringCheck :: forall n. KnownNat n => BS.ByteString -> Maybe (MLockedSizedBytes n)
mlsbFromByteStringCheck bs
    | BS.length bs /= size = Nothing
    -- The length is now known to be exactly @n@, so 'mlsbFromByteString'
    -- copies the input verbatim; reusing it avoids duplicating the
    -- allocate-and-copy logic.
    | otherwise = Just (mlsbFromByteString bs)
  where
    size :: Int
    size = fromInteger (natVal (Proxy @n))

-- | View the buffer as a 'BS.ByteString' of length @n@, without copying.
--
-- /Note:/ the resulting 'BS.ByteString' still refers to the secure memory
-- itself (it is built on the same underlying @ForeignPtr@), and the types
-- don't prevent it from being exposed. It should presumably not be used
-- after 'mlsbFinalize' has been called on the argument.
mlsbToByteString :: forall n. KnownNat n => MLockedSizedBytes n -> BS.ByteString
mlsbToByteString (MLSB (SFP raw)) = BSI.fromForeignPtr (castForeignPtr raw) 0 byteLen
  where
    byteLen :: Int
    byteLen = fromInteger (natVal (Proxy @n))

-- | Run an action with a raw 'Word8' pointer to the start of the buffer.
--
-- The pointer is obtained through 'withMLockedForeignPtr', so it should not
-- be retained or used outside the action.
mlsbUseAsCPtr :: MLockedSizedBytes n -> (Ptr Word8 -> IO r) -> IO r
mlsbUseAsCPtr (MLSB fptr) action =
    withMLockedForeignPtr fptr $ \ptr -> action (castPtr ptr)

-- | Run an action with a 'SizedPtr' (a pointer tagged with the size @n@)
-- to the buffer.
--
-- The pointer is obtained through 'withMLockedForeignPtr', so it should not
-- be retained or used outside the action.
mlsbUseAsSizedPtr :: MLockedSizedBytes n -> (SizedPtr n -> IO r) -> IO r
mlsbUseAsSizedPtr (MLSB fptr) action =
    withMLockedForeignPtr fptr $ \ptr -> action (ptrPsbToSizedPtr ptr)

-- | Call 'finalizeMLockedForeignPtr' on the underlying pointer.
--
-- This invalidates its argument: neither it nor anything sharing its memory
-- (such as a 'BS.ByteString' obtained from 'mlsbToByteString') should be
-- used afterwards.
--
mlsbFinalize :: MLockedSizedBytes n -> IO ()
mlsbFinalize (MLSB fptr) = finalizeMLockedForeignPtr fptr