-- |
-- Module      : Crypto.MAC.KMAC
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Provide the KMAC (Keccak Message Authentication Code) algorithm, derived from
-- the SHA-3 base algorithm Keccak and defined in NIST SP800-185.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.MAC.KMAC
    ( HashSHAKE
    , kmac
    , KMAC(..)
    -- * Incremental
    , Context
    , initialize
    , update
    , updates
    , finalize
    ) where

import qualified Crypto.Hash as H
import           Crypto.Hash.SHAKE (HashSHAKE(..))
import           Crypto.Hash.Types (HashAlgorithm(..), Digest(..))
import qualified Crypto.Hash.Types as H
import           Foreign.Ptr (Ptr, plusPtr)
import           Foreign.Storable (poke)
import           Data.Bits (shiftR)
import           Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import           Data.Word (Word8)
import           Data.Memory.PtrMethods (memSet)


-- cSHAKE

-- | Create a cSHAKE context for function name @n@ and customization string
-- @s@, then absorb the additional bytes @p@ (for KMAC: the padded, encoded
-- key).
--
-- The state is initialized and then fed
-- @bytepad(encode_string(n) || encode_string(s), w)@ followed by @p@, where
-- @w@ is the algorithm's 'hashBlockSize'.
cshakeInit :: forall a name string prefix . (HashSHAKE a, ByteArrayAccess name, ByteArrayAccess string, ByteArrayAccess prefix)
           => name -> string -> prefix -> H.Context a
cshakeInit n s p = H.Context (B.allocAndFreeze ctxSize initCtx)
  where
    ctxSize :: Int
    ctxSize = hashInternalContextSize (undefined :: a)

    blockSize :: Int
    blockSize = hashBlockSize (undefined :: a)

    -- the padded function-name / customization header, built in one allocation
    header :: B.Bytes
    header = builderAllocAndFreeze (bytepad (encodeString n <+> encodeString s) blockSize)

    initCtx :: Ptr (H.Context a) -> IO ()
    initCtx ptr = do
        hashInternalInit ptr
        absorb ptr header
        absorb ptr p

    -- feed the whole contents of a byte array into the context at ptr
    absorb :: ByteArrayAccess ba => Ptr (H.Context a) -> ba -> IO ()
    absorb ptr bs =
        B.withByteArray bs $ \d -> hashInternalUpdate ptr d (fromIntegral (B.length bs))

-- | Absorb one more chunk of input into a cSHAKE context.  This is exactly
-- 'H.hashUpdate'; the cSHAKE-specific steps happen in 'cshakeInit' and
-- 'cshakeFinalize'.
cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba)
             => H.Context a -> ba -> H.Context a
cshakeUpdate ctx chunk = H.hashUpdate ctx chunk

-- | Absorb a list of input chunks, in order, into a cSHAKE context.  This is
-- exactly 'H.hashUpdates'.
cshakeUpdates :: (HashSHAKE a, ByteArrayAccess ba)
              => H.Context a -> [ba] -> H.Context a
cshakeUpdates ctx chunks = H.hashUpdates ctx chunks

-- | Produce the cSHAKE output of a context after absorbing a final @s@.
--
-- The given context is left untouched.  Its state is copied, @s@ is absorbed
-- into the copy, and the copy is finalized directly into the digest buffer
-- of 'hashDigestSize' bytes.
cshakeFinalize :: forall a suffix . (HashSHAKE a, ByteArrayAccess suffix)
               => H.Context a -> suffix -> Digest a
cshakeFinalize !c s = Digest (B.allocAndFreeze digestSize writeDigest)
  where
    digestSize :: Int
    digestSize = hashDigestSize (undefined :: a)

    writeDigest :: Ptr (Digest a) -> IO ()
    writeDigest dig = do
        -- the copied state is scratch space only: force it, then discard it
        ((!_) :: B.Bytes) <- B.copy c (finishInto dig)
        return ()

    finishInto :: Ptr (Digest a) -> Ptr (H.Context a) -> IO ()
    finishInto dig ctx = do
        B.withByteArray s $ \d ->
            hashInternalUpdate ctx d (fromIntegral (B.length s))
        cshakeInternalFinalize ctx dig


-- KMAC

-- | A KMAC tag.  The type parameter is a phantom naming the hash algorithm
-- that produced the tag; only the digest bytes are stored.
--
-- Equality is checked in constant time.  There is deliberately no 'Show'
-- instance, so a tag cannot be printed by mistake.
newtype KMAC a = KMAC { kmacGetDigest :: Digest a }
    deriving (ByteArrayAccess)

instance Eq (KMAC a) where
    -- compare through constEq so comparison time does not depend on where
    -- the two tags first differ
    KMAC x == KMAC y = x `B.constEq` y

-- | One-shot KMAC over a single message, using the customization string
-- @str@ and the key @key@.  Equivalent to 'initialize', then 'updates' with
-- the message, then 'finalize'.
kmac :: (HashSHAKE a, ByteArrayAccess string, ByteArrayAccess key, ByteArrayAccess ba)
     => string -> key -> ba -> KMAC a
kmac str key msg = finalize (updates keyed [msg])
  where
    keyed = initialize str key

-- | An in-progress KMAC computation.  Feed it message data with 'update' or
-- 'updates', then turn it into a 'KMAC' with 'finalize'.
newtype Context a = Context (H.Context a)

-- | Start an incremental KMAC computation with customization string @str@
-- and key @key@.
--
-- This is a cSHAKE context with function name @\"KMAC\"@, pre-loaded with
-- @bytepad(encode_string(key), w)@ where @w@ is the algorithm's
-- 'hashBlockSize'.
initialize :: forall a string key . (HashSHAKE a, ByteArrayAccess string, ByteArrayAccess key)
           => string -> key -> Context a
initialize str key = Context (cshakeInit functionName str paddedKey)
  where
    -- the ASCII bytes of "KMAC"
    functionName :: B.Bytes
    functionName = B.pack [0x4B, 0x4D, 0x41, 0x43]

    blockSize :: Int
    blockSize = hashBlockSize (undefined :: a)

    -- key-derived material: kept in ScrubbedBytes, which the memory package
    -- documents as being zeroed when released
    paddedKey :: B.ScrubbedBytes
    paddedKey = builderAllocAndFreeze (bytepad (encodeString key) blockSize)

-- | Absorb one more chunk of message data into a KMAC computation.
update :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> ba -> Context a
update (Context ctx) chunk = Context (cshakeUpdate ctx chunk)

-- | Absorb several chunks of message data, in list order, into a KMAC
-- computation.
updates :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> [ba] -> Context a
updates (Context ctx) chunks = Context (cshakeUpdates ctx chunks)

-- | Finish a KMAC computation and return the tag.
--
-- Before the output is produced, the value reported by 'cshakeOutputLength'
-- is absorbed in @right_encode@ form.
finalize :: forall a . HashSHAKE a => Context a -> KMAC a
finalize (Context ctx) = KMAC (cshakeFinalize ctx encodedLength)
  where
    outputLength :: Int
    outputLength = cshakeOutputLength (undefined :: a)

    encodedLength :: B.Bytes
    encodedLength = builderAllocAndFreeze (rightEncode outputLength)


-- Utilities

-- | SP800-185 @bytepad(x, w)@: prefix @x@ with @left_encode(w)@, then append
-- zero bytes until the total length is a multiple of @w@.
bytepad :: Builder -> Int -> Builder
bytepad x w = header <+> x <+> zero fill
  where
    header = leftEncode w
    used   = builderLength header + builderLength x
    -- `mod` (not `rem`) keeps the fill count in [0, w) for positive w
    fill   = negate used `mod` w

-- | SP800-185 @encode_string(s)@: the length of @s@ in bits, in
-- 'leftEncode' form, followed by the bytes of @s@.
encodeString :: ByteArrayAccess bin => bin -> Builder
encodeString s = leftEncode bitLength <+> bytes s
  where
    bitLength = B.length s * 8

-- | SP800-185 @left_encode(x)@: a single byte holding the length of the
-- big-endian encoding of @x@, followed by that encoding.
leftEncode :: Int -> Builder
leftEncode x = byte (fromIntegral (builderLength be)) <+> be
  where
    be = i2osp x

-- | SP800-185 @right_encode(x)@: the big-endian encoding of @x@, followed by
-- a single byte holding its length.
rightEncode :: Int -> Builder
rightEncode x = be <+> byte (fromIntegral (builderLength be))
  where
    be = i2osp x

-- | Minimal big-endian byte encoding of an 'Int' (always at least one byte,
-- so 0 becomes a single 0x00).  Intended for non-negative values; a negative
-- input yields only its lowest byte.
i2osp :: Int -> Builder
i2osp i
    | i < 256   = lowByte
    | otherwise = i2osp (i `shiftR` 8) <+> lowByte
  where
    lowByte = byte (fromIntegral i)


-- Delaying and merging ByteArray allocations

-- | A not-yet-allocated byte string.  It pairs the exact number of bytes
-- with an action that writes them starting at the given pointer.  Builders
-- are combined first and allocated once at the end.
data Builder = Builder !Int (Ptr Word8 -> IO ())

-- | Concatenate two builders.  The sizes add up, and the second writer starts
-- immediately after the bytes written by the first.
(<+>) :: Builder -> Builder -> Builder
Builder n1 w1 <+> Builder n2 w2 =
    Builder (n1 + n2) (\p -> w1 p >> w2 (p `plusPtr` n1))

-- | Number of bytes the builder will write.
builderLength :: Builder -> Int
builderLength b = case b of
    Builder n _ -> n

-- | Run a builder into a freshly allocated byte array of exactly its size.
builderAllocAndFreeze :: ByteArray ba => Builder -> ba
builderAllocAndFreeze (Builder size writer) = B.allocAndFreeze size writer

-- | A one-byte builder.  The bang evaluates the byte when the builder is
-- created rather than when it is written.
byte :: Word8 -> Builder
byte !b = Builder 1 (\p -> poke p b)

-- | Builder that copies the entire contents of a byte array.
bytes :: ByteArrayAccess ba => ba -> Builder
bytes src = Builder len (B.copyByteArrayToPtr src)
  where
    len = B.length src

-- | Builder of @n@ zero bytes.
zero :: Int -> Builder
zero n = Builder n fill
  where
    fill p = memSet p 0 n