{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Cardano.Crypto.Libsodium.Hash (
    SodiumHashAlgorithm (..),
    digestMLockedStorable,
    digestMLockedBS,
    expandHash,
) where

import Control.Monad (unless)
import Data.Proxy (Proxy (..))
import Foreign.C.Error (errnoToIOError, getErrno)
import Foreign.C.Types (CSize)
import Foreign.Ptr (Ptr, castPtr, nullPtr, plusPtr)
import Foreign.Storable (Storable (sizeOf, poke))
import Data.Word (Word8)
import Data.Type.Equality ((:~:)(..))
import GHC.IO.Exception (ioException)
import GHC.TypeLits
import System.IO.Unsafe (unsafeDupablePerformIO)

import qualified Data.ByteString as BS

import Cardano.Foreign
import Cardano.Crypto.Hash (HashAlgorithm(SizeHash), SHA256, Blake2b_256)
import Cardano.Crypto.PinnedSizedBytes (ptrPsbToSizedPtr)
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.MLockedBytes.Internal

-------------------------------------------------------------------------------
-- Type-Class
-------------------------------------------------------------------------------

-- | Hash algorithms whose digest can be computed through the libsodium
-- bindings and returned as an 'MLockedSizedBytes'.
class HashAlgorithm h => SodiumHashAlgorithm h where
    -- | Hash the given number of bytes starting at the input pointer.
    --
    -- This function lives in 'IO', but it is "morally pure", so callers
    -- may run it with 'unsafeDupablePerformIO' (as 'digestMLockedStorable',
    -- 'digestMLockedBS' and 'expandHash' do).
    naclDigestPtr
        :: proxy h
        -> Ptr a  -- ^ input
        -> Int    -- ^ input length, in bytes
        -> IO (MLockedSizedBytes (SizeHash h))

    -- TODO: provide interface for multi-part?
    -- That will be useful to hashing ('1' <> oldseed).

-- | Hash the raw in-memory representation of a 'Storable' value: the
-- @'sizeOf' (undefined :: a)@ bytes that start at the given pointer.
--
-- The digest is computed via 'naclDigestPtr' under
-- 'unsafeDupablePerformIO', relying on that method being morally pure.
digestMLockedStorable
    :: forall h a proxy. (SodiumHashAlgorithm h, Storable a)
    => proxy h -> Ptr a -> MLockedSizedBytes (SizeHash h)
digestMLockedStorable p ptr =
    unsafeDupablePerformIO (naclDigestPtr p ptr byteCount)
  where
    -- 'sizeOf' does not inspect its argument, so 'undefined' is never forced.
    byteCount = sizeOf (undefined :: a)

-- | Hash the contents of a strict 'BS.ByteString'.
--
-- The bytestring's buffer is exposed with 'BS.useAsCStringLen' and handed
-- to 'naclDigestPtr'; the whole computation runs under
-- 'unsafeDupablePerformIO', relying on 'naclDigestPtr' being morally pure.
digestMLockedBS
    :: forall h proxy. (SodiumHashAlgorithm h)
    => proxy h -> BS.ByteString -> MLockedSizedBytes (SizeHash h)
digestMLockedBS p bs =
    unsafeDupablePerformIO (BS.useAsCStringLen bs hashBuffer)
  where
    -- Hash the (pointer, length) view of the bytestring's contents.
    hashBuffer (buf, len) = naclDigestPtr p (castPtr buf) len

-------------------------------------------------------------------------------
-- Hash expansion
-------------------------------------------------------------------------------

-- | Derive two hashes from a seed of @SizeHash h@ bytes.
--
-- The left result is @H(0x01 || seed)@ and the right result is
-- @H(0x02 || seed)@, where @H@ is 'naclDigestPtr' for @h@. Each tagged
-- input buffer (one tag byte followed by the seed) is allocated with
-- 'mlockedAlloca' for the duration of that single hash.
expandHash
    :: forall h proxy. SodiumHashAlgorithm h
    => proxy h
    -> MLockedSizedBytes (SizeHash h)
    -> (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHash h (MLSB sfptr) = unsafeDupablePerformIO $
    withMLockedForeignPtr sfptr $ \seedPtr -> do
        l <- digestTagged 1 seedPtr
        r <- digestTagged 2 seedPtr
        return (l, r)
  where
    -- Build the buffer @tag || seed@ (size1 bytes) and hash it.
    digestTagged :: Word8 -> Ptr seed -> IO (MLockedSizedBytes (SizeHash h))
    digestTagged tag seedPtr = mlockedAlloca size1 $ \buf -> do
        poke buf tag
        -- Copy the seed right after the tag byte.
        _ <- c_memcpy (castPtr (plusPtr buf 1)) seedPtr size
        naclDigestPtr h buf (fromIntegral size1)

    -- Length of the tagged buffer: one tag byte plus the seed.
    size1 :: CSize
    size1 = size + 1

    -- Seed length in bytes, taken from the type-level digest size.
    size :: CSize
    size = fromInteger $ natVal (Proxy @(SizeHash h))

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

instance SodiumHashAlgorithm SHA256 where
    -- SHA-256 via 'c_crypto_hash_sha256'. The digest is written into a buffer
    -- obtained from 'allocMLockedForeignPtr'. A non-zero return code is
    -- reported as an 'IOException' built from the current errno.
    naclDigestPtr :: forall proxy a. proxy SHA256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash SHA256))
    naclDigestPtr _ input inputlen = do
        digest <- allocMLockedForeignPtr
        withMLockedForeignPtr digest $ \digestPtr -> do
            status <- c_crypto_hash_sha256
                (ptrPsbToSizedPtr digestPtr)
                (castPtr input)
                (fromIntegral inputlen)
            unless (status == 0) raiseErrno
        pure (MLSB digest)
      where
        -- Turn the errno left by the failed call into an IOException.
        raiseErrno = do
            errno <- getErrno
            ioException $ errnoToIOError "digestMLocked @SHA256: c_crypto_hash_sha256" errno Nothing Nothing

-- Compile-time check: the digest size declared for 'SHA256' must equal
-- libsodium's @CRYPTO_SHA256_BYTES@. If the two ever diverge, this
-- definition stops typechecking.
_testSHA256 :: SizeHash SHA256 :~: CRYPTO_SHA256_BYTES
_testSHA256 = Refl

instance SodiumHashAlgorithm Blake2b_256 where
    -- Unkeyed BLAKE2b via 'c_crypto_generichash_blake2b', producing
    -- @CRYPTO_BLAKE2B_256_BYTES@ bytes. The digest is written into a buffer
    -- obtained from 'allocMLockedForeignPtr'. A non-zero return code is
    -- reported as an 'IOException' built from the current errno.
    naclDigestPtr :: forall proxy a. proxy Blake2b_256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash Blake2b_256))
    naclDigestPtr _ input inputlen = do
        output <- allocMLockedForeignPtr
        withMLockedForeignPtr output $ \output' -> do
            res <- c_crypto_generichash_blake2b
                output' (fromInteger $ natVal (Proxy @CRYPTO_BLAKE2B_256_BYTES))  -- output
                (castPtr input) (fromIntegral inputlen)  -- input
                nullPtr 0                                -- key, unused
            unless (res == 0) $ do
                errno <- getErrno
                -- The message names the libsodium call that actually failed.
                ioException $ errnoToIOError "digestMLocked @Blake2b_256: c_crypto_generichash_blake2b" errno Nothing Nothing

        return (MLSB output)

-- Compile-time check: the digest size declared for 'Blake2b_256' must equal
-- libsodium's @CRYPTO_BLAKE2B_256_BYTES@. If the two ever diverge, this
-- definition stops typechecking.
_testBlake2b256 :: SizeHash Blake2b_256 :~: CRYPTO_BLAKE2B_256_BYTES
_testBlake2b256 = Refl