-- |
-- Module      : Crypto.Cipher.Types.AEAD
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : Stable
-- Portability : Excellent
--
-- AEAD cipher basic types
--
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Crypto.Cipher.Types.AEAD where

import           Crypto.Cipher.Types.Base
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B
import           Crypto.Internal.Imports

-- | AEAD mode implementation: a record of the primitive operations a
-- concrete AEAD mode provides over its internal state @st@.
--
-- Each operation is pure and returns the updated state rather than mutating
-- it. The byte-array operations are rank-2 polymorphic, so one
-- implementation works for any byte-array type the caller chooses.
data AEADModeImpl st = AEADModeImpl
    { aeadImplAppendHeader :: forall ba . ByteArrayAccess ba => st -> ba -> st
      -- ^ Feed header (associated) data into the state.
    , aeadImplEncrypt      :: forall ba . ByteArray ba => st -> ba -> (ba, st)
      -- ^ Encrypt a chunk, returning the output and the updated state.
    , aeadImplDecrypt      :: forall ba . ByteArray ba => st -> ba -> (ba, st)
      -- ^ Decrypt a chunk, returning the output and the updated state.
    , aeadImplFinalize     :: st -> Int -> AuthTag
      -- ^ Produce the authentication tag of the requested length.
    }

-- | Authenticated Encryption with Associated Data algorithms.
--
-- An AEAD context packs a mode implementation together with its current
-- state. The state type @st@ is existentially hidden, so different modes
-- share the single type @AEAD cipher@. The @cipher@ parameter does not occur
-- in any field; it only tags the context at the type level.
data AEAD cipher = forall st . AEAD
    { aeadModeImpl :: AEADModeImpl st
    , aeadState    :: st
    }

-- | Append some header information (associated data) to an AEAD context.
--
-- The mode's 'aeadImplAppendHeader' is applied to the hidden state. The
-- result is repacked with the same implementation; the input context is
-- left unchanged.
aeadAppendHeader :: ByteArrayAccess aad => AEAD cipher -> aad -> AEAD cipher
aeadAppendHeader (AEAD impl st) aad =
    AEAD impl (aeadImplAppendHeader impl st aad)

-- | Encrypt some data and update the AEAD context.
--
-- Returns the ciphertext chunk together with a new context that carries the
-- updated mode state. Further chunks can be encrypted, or the context can be
-- finalized with 'aeadFinalize'.
aeadEncrypt :: ByteArray ba => AEAD cipher -> ba -> (ba, AEAD cipher)
aeadEncrypt (AEAD impl st) ba =
    -- Rewrap only the state half of the (output, state) pair.
    second (AEAD impl) $ aeadImplEncrypt impl st ba

-- | Decrypt some data and update the AEAD context.
--
-- Returns the decrypted chunk together with a new context that carries the
-- updated mode state. The output has not been authenticated at this point.
-- The caller must finalize the context and check the tag before trusting it;
-- 'aeadSimpleDecrypt' does this.
aeadDecrypt :: ByteArray ba => AEAD cipher -> ba -> (ba, AEAD cipher)
aeadDecrypt (AEAD impl st) ba =
    -- Rewrap only the state half of the (output, state) pair.
    second (AEAD impl) $ aeadImplDecrypt impl st ba

-- | Finalize the AEAD context and return the authentication tag.
--
-- The 'Int' argument is the requested tag length. Within this module,
-- 'aeadSimpleDecrypt' passes the byte length of the expected tag.
aeadFinalize :: AEAD cipher -> Int -> AuthTag
aeadFinalize (AEAD impl st) n = aeadImplFinalize impl st n

-- | Simple one-shot AEAD encryption.
--
-- It appends the header to a fresh context, encrypts the whole input in one
-- chunk, then finalizes the resulting context to obtain a tag of the
-- requested length.
aeadSimpleEncrypt :: (ByteArrayAccess aad, ByteArray ba)
                  => AEAD a        -- ^ A new AEAD Context
                  -> aad           -- ^ Optional authenticated data header
                  -> ba            -- ^ Optional Plaintext
                  -> Int           -- ^ Tag length
                  -> (AuthTag, ba) -- ^ Authentication tag and ciphertext
aeadSimpleEncrypt aeadIni header input taglen = (tag, output)
  where
    aead                = aeadAppendHeader aeadIni header
    (output, aeadFinal) = aeadEncrypt aead input
    -- The tag must be computed from the context *after* encryption.
    tag                 = aeadFinalize aeadFinal taglen

-- | Simple one-shot AEAD decryption.
--
-- It appends the header to a fresh context and decrypts the whole input in
-- one chunk. It then recomputes a tag with the same length as the supplied
-- one.
--
-- The plaintext is returned only when the recomputed tag matches the
-- supplied tag; otherwise the result is 'Nothing'. Tags are compared with
-- 'B.constEq' so the check does not depend on the 'Eq' instance of
-- 'AuthTag'. That function is a constant-time comparison that also returns
-- 'False' on a length mismatch.
aeadSimpleDecrypt :: (ByteArrayAccess aad, ByteArray ba)
                  => AEAD a        -- ^ A new AEAD Context
                  -> aad           -- ^ Optional authenticated data header
                  -> ba            -- ^ Ciphertext
                  -> AuthTag       -- ^ The authentication tag
                  -> Maybe ba      -- ^ Plaintext
aeadSimpleDecrypt aeadIni header input authTag
    | tag `B.constEq` authTag = Just output
    | otherwise               = Nothing
  where
    aead                = aeadAppendHeader aeadIni header
    (output, aeadFinal) = aeadDecrypt aead input
    -- Recompute a tag of exactly the length the caller supplied.
    tag                 = aeadFinalize aeadFinal (B.length authTag)