module Foundation.Random.ChaChaDRG
    ( State(..)
    , keySize
    ) where

import           Foundation.Class.Storable (peek)
import           Basement.Imports
import           Basement.Types.OffsetSize
import           Basement.Monad
import           Foundation.Random.Class
import           Foundation.Random.DRG
import qualified Basement.UArray as A
import qualified Basement.UArray.Mutable as A
import           GHC.ST
import qualified Foreign.Marshal.Alloc (alloca)

-- | RNG state based on the ChaCha core.
--
-- The algorithm is identical to the arc4random found in recent BSDs,
-- namely a ChaCha core that produces 64 bytes of random output from a
-- 32-byte key.
--
-- The wrapped array is the current key; each generation step in this
-- module returns a new 'State' holding a freshly produced key.
newtype State = State (UArray Word8)

-- | 'RandomGen' instance delegating to this module's ChaCha-based generators.
--
-- * 'randomNew' seeds a fresh 'State' with 'keySize' bytes obtained from
--   'getRandomBytes'.
-- * 'randomNewFrom' accepts a seed only when it is exactly 'keySize' bytes
--   long and returns 'Nothing' for any other length.
-- * The generation methods forward to 'generate', 'generateWord64',
--   'generateF32' and 'generateF64' respectively.
instance RandomGen State where
    randomNew = State <$> getRandomBytes keySize

    randomNewFrom bs
        | A.length bs == keySize = Just (State bs)
        | otherwise              = Nothing

    randomGenerate       = generate
    randomGenerateWord64 = generateWord64
    randomGenerateF32    = generateF32
    randomGenerateF64    = generateF64

-- | Size of the ChaCha key carried by 'State', in bytes.
--
-- 'randomNewFrom' only accepts seeds of exactly this length, and every
-- generation step allocates a new key buffer of this size.
keySize :: CountOf Word8
keySize = 32

-- | Generate @n@ bytes of output together with the successor 'State'.
--
-- The current key is handed to @foundation_rngV1_generate@ (as its
-- "current key" argument). That routine writes into two buffers allocated
-- here:
--
-- * a fresh pinned array of @n@ bytes for the output, and
-- * a fresh pinned array of 'keySize' bytes for the next key.
--
-- 'unsafePrimFromIO' is used inside 'runST'. This presumes the C routine
-- only reads the current key and writes the two buffers allocated in this
-- function.
--
-- NOTE(review): the 'Word32' status returned by the C routine is discarded,
-- although the FFI declaration documents non-zero as failure. Confirm that
-- the C side cannot fail; otherwise the returned bytes and key would be
-- whatever the fresh allocations contained.
generate :: CountOf Word8 -> State -> (UArray Word8, State)
generate n (State key) = runST $ do
    dst    <- A.newPinned n
    newKey <- A.newPinned keySize
    A.withMutablePtr dst $ \dstP ->
        A.withMutablePtr newKey $ \newKeyP ->
            A.withPtr key $ \keyP -> do
                -- status code intentionally ignored; see NOTE above
                _ <- unsafePrimFromIO $ c_rngv1_generate newKeyP dstP keyP n
                return ()
    -- neither mutable array escapes this runST, so freezing in place is safe
    (,) <$> A.unsafeFreeze dst
        <*> (State <$> A.unsafeFreeze newKey)

-- | Generate one 'Word64' together with the successor 'State'.
--
-- @foundation_rngV1_generate_word64@ writes the following into buffers
-- allocated here:
--
-- * the value, into a temporary buffer from
--   'Foreign.Marshal.Alloc.alloca', which is then read back with 'peek';
-- * the next key, into a fresh pinned array of 'keySize' bytes.
--
-- The current key is only passed as the routine's "current key" argument.
--
-- NOTE(review): the 'Word32' status returned by the C routine is discarded,
-- although the FFI declarations document non-zero as failure. Confirm that
-- the C side cannot fail; otherwise 'peek' would read an unwritten buffer.
generateWord64 :: State -> (Word64, State)
generateWord64 (State key) = runST $ unsafePrimFromIO $
    Foreign.Marshal.Alloc.alloca $ \dst -> do
        newKey <- A.newPinned keySize
        A.withMutablePtr newKey $ \newKeyP ->
            A.withPtr key $ \keyP -> do
                -- status code intentionally ignored; see NOTE above
                _ <- c_rngv1_generate_word64 newKeyP dst keyP
                return ()
        -- 'newKey' never escapes as a mutable array, so freezing in place is safe
        (,) <$> peek dst <*> (State <$> A.unsafeFreeze newKey)

-- | Generate one 'Float' together with the successor 'State'.
--
-- @foundation_rngV1_generate_f32@ writes the following into buffers
-- allocated here:
--
-- * the value, into a temporary buffer from
--   'Foreign.Marshal.Alloc.alloca', which is then read back with 'peek';
-- * the next key, into a fresh pinned array of 'keySize' bytes.
--
-- The range of the produced 'Float' is decided by the C routine and is not
-- visible from this module. NOTE(review): confirm against the C source.
--
-- NOTE(review): the 'Word32' status returned by the C routine is discarded,
-- although the FFI declarations document non-zero as failure. Confirm that
-- the C side cannot fail; otherwise 'peek' would read an unwritten buffer.
generateF32 :: State -> (Float, State)
generateF32 (State key) = runST $ unsafePrimFromIO $
    Foreign.Marshal.Alloc.alloca $ \dst -> do
        newKey <- A.newPinned keySize
        A.withMutablePtr newKey $ \newKeyP ->
            A.withPtr key $ \keyP -> do
                -- status code intentionally ignored; see NOTE above
                _ <- c_rngv1_generate_f32 newKeyP dst keyP
                return ()
        -- 'newKey' never escapes as a mutable array, so freezing in place is safe
        (,) <$> peek dst <*> (State <$> A.unsafeFreeze newKey)

-- | Generate one 'Double' together with the successor 'State'.
--
-- @foundation_rngV1_generate_f64@ writes the following into buffers
-- allocated here:
--
-- * the value, into a temporary buffer from
--   'Foreign.Marshal.Alloc.alloca', which is then read back with 'peek';
-- * the next key, into a fresh pinned array of 'keySize' bytes.
--
-- The range of the produced 'Double' is decided by the C routine and is not
-- visible from this module. NOTE(review): confirm against the C source.
--
-- NOTE(review): the 'Word32' status returned by the C routine is discarded,
-- although the FFI declarations document non-zero as failure. Confirm that
-- the C side cannot fail; otherwise 'peek' would read an unwritten buffer.
generateF64 :: State -> (Double, State)
generateF64 (State key) = runST $ unsafePrimFromIO $
    Foreign.Marshal.Alloc.alloca $ \dst -> do
        newKey <- A.newPinned keySize
        A.withMutablePtr newKey $ \newKeyP ->
            A.withPtr key $ \keyP -> do
                -- status code intentionally ignored; see NOTE above
                _ <- c_rngv1_generate_f64 newKeyP dst keyP
                return ()
        -- 'newKey' never escapes as a mutable array, so freezing in place is safe
        (,) <$> peek dst <*> (State <$> A.unsafeFreeze newKey)

-- | Bulk generator: writes the requested number of random bytes to the
-- destination buffer and the next key to the new-key buffer, reading the
-- current key.
--
-- Returns 0 on success and a non-zero value on failure.
--
-- Declared @unsafe@, so the call cannot re-enter the Haskell runtime.
-- NOTE(review): presumably the 'CountOf' argument is marshalled as its
-- underlying machine integer; confirm it matches the C parameter type.
foreign import ccall unsafe "foundation_rngV1_generate"
   c_rngv1_generate :: Ptr Word8 -- new key
                    -> Ptr Word8 -- destination
                    -> Ptr Word8 -- current key
                    -> CountOf Word8 -- number of bytes to generate
                    -> IO Word32

-- | Writes one 'Word64' to the destination and the next key to the new-key
-- buffer, reading the current key.
--
-- Returns a 'Word32' status. NOTE(review): presumably it follows the same
-- 0-on-success convention as @foundation_rngV1_generate@; confirm.
foreign import ccall unsafe "foundation_rngV1_generate_word64"
   c_rngv1_generate_word64 :: Ptr Word8  -- new key
                           -> Ptr Word64 -- destination
                           -> Ptr Word8  -- current key
                           -> IO Word32

-- | Writes one 'Float' to the destination and the next key to the new-key
-- buffer, reading the current key.
--
-- Returns a 'Word32' status. NOTE(review): presumably it follows the same
-- 0-on-success convention as @foundation_rngV1_generate@; confirm.
foreign import ccall unsafe "foundation_rngV1_generate_f32"
   c_rngv1_generate_f32 :: Ptr Word8  -- new key
                        -> Ptr Float -- destination
                        -> Ptr Word8  -- current key
                        -> IO Word32

-- | Writes one 'Double' to the destination and the next key to the new-key
-- buffer, reading the current key.
--
-- Returns a 'Word32' status. NOTE(review): presumably it follows the same
-- 0-on-success convention as @foundation_rngV1_generate@; confirm.
foreign import ccall unsafe "foundation_rngV1_generate_f64"
   c_rngv1_generate_f64 :: Ptr Word8  -- new key
                        -> Ptr Double -- destination
                        -> Ptr Word8  -- current key
                        -> IO Word32