module Crypto.Cipher.ChaChaPoly1305
( State
, Nonce
, nonce12
, nonce8
, incrementNonce
, initialize
, appendAAD
, finalizeAAD
, encrypt
, decrypt
, finalize
) where
import Control.Monad (when)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Imports
import Crypto.Error
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Memory.Endian
import qualified Data.ByteArray.Pack as P
import Foreign.Ptr
import Foreign.Storable
-- | Running state of a ChaCha20-Poly1305 AEAD computation.
data State = State
    !ChaCha.State   -- ^ ChaCha keystream state used by 'encrypt' / 'decrypt'
    !Poly1305.State -- ^ Poly1305 state accumulating the AAD and ciphertext
    !Word64         -- ^ number of AAD bytes appended so far
    !Word64         -- ^ number of plaintext/ciphertext bytes processed so far
-- | A ChaCha20-Poly1305 nonce, holding the exact bytes handed to ChaCha.
--
-- 'Nonce8' holds a 4-byte constant followed by an 8-byte IV (built by
-- 'nonce8'); 'Nonce12' holds a 12-byte IV (built by 'nonce12').
data Nonce
    = Nonce8 Bytes
    | Nonce12 Bytes
-- | Expose the underlying nonce bytes, whichever constructor built them.
instance ByteArrayAccess Nonce where
    length nonce = case nonce of
        Nonce8 bs  -> B.length bs
        Nonce12 bs -> B.length bs
    withByteArray nonce action = case nonce of
        Nonce8 bs  -> B.withByteArray bs action
        Nonce12 bs -> B.withByteArray bs action
-- | Zero bytes needed to round a byte count up to the next multiple of 16.
--
-- Used to align the AAD and ciphertext sections of the Poly1305 input.
-- An already-aligned count yields an empty array.
pad16 :: Word64 -> Bytes
pad16 n = case fromIntegral (n `mod` 16) :: Int of
    0 -> B.empty
    r -> B.replicate (16 - r) 0
-- | Build a 'Nonce' from a 12-byte IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' unless the input is exactly 12 bytes.
nonce12 :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce12 iv =
    if B.length iv == 12
        then CryptoPassed (Nonce12 (B.convert iv))
        else CryptoFailed CryptoError_IvSizeInvalid
-- | Build a 'Nonce' from a 4-byte constant and an 8-byte IV.
--
-- The nonce bytes are the constant followed by the IV (12 bytes in total).
-- Fails with 'CryptoError_IvSizeInvalid' if either argument has the wrong
-- length.
nonce8 :: ByteArrayAccess ba
       => ba -- ^ 4-byte constant
       -> ba -- ^ 8-byte IV
       -> CryptoFailable Nonce
nonce8 constant iv
    | badConstant || badIv = CryptoFailed CryptoError_IvSizeInvalid
    | otherwise            = CryptoPassed $ Nonce8 (B.concat [constant, iv])
  where
    badConstant = B.length constant /= 4
    badIv       = B.length iv /= 8
-- | Increment the counter part of a nonce.
--
-- For 'Nonce8' the counter starts after the 4-byte constant (byte offset 4).
-- For 'Nonce12' the whole nonce is the counter (byte offset 0).
incrementNonce :: Nonce -> Nonce
incrementNonce nonce = case nonce of
    Nonce8 bs  -> Nonce8  (incrementNonce' bs 4)
    Nonce12 bs -> Nonce12 (incrementNonce' bs 0)
-- | Return a copy of @b@ in which the little-endian counter occupying bytes
-- @offset .. length b - 1@ has been incremented by one.
--
-- Bytes before @offset@ are left untouched. A carry out of the last byte is
-- discarded, so a counter of all @0xff@ bytes wraps around to all zeros.
incrementNonce' :: Bytes -> Int -> Bytes
incrementNonce' b offset = B.copyAndFreeze b $ \p -> go p offset
  where
    len :: Int
    len = B.length b

    -- Increment byte @i@; only when it overflows to zero, propagate the carry
    -- to byte @i + 1@. Stopping once @i@ reaches @len@ keeps every read and
    -- write inside the copied buffer, even for a full carry chain.
    go :: Ptr Word8 -> Int -> IO ()
    go p i
        | i >= len  = return ()
        | otherwise = do
            v <- (+ 1) <$> (peekByteOff p i :: IO Word8)
            pokeByteOff p i v
            when (v == 0) $ go p (i + 1)
-- | Create a ChaCha20-Poly1305 state from a 32-byte key and a 'Nonce'.
--
-- Fails with 'CryptoError_KeySizeInvalid' if the key is not 32 bytes long.
initialize :: ByteArrayAccess key
           => key -> Nonce -> CryptoFailable State
initialize key nonce = case nonce of
    Nonce8 bs  -> initialize' key bs
    Nonce12 bs -> initialize' key bs
-- | Shared implementation of 'initialize', taking the raw nonce bytes.
--
-- A 64-byte block of ChaCha (20 rounds) keystream is drawn first. Its first
-- 32 bytes become the one-time Poly1305 key. Encryption continues with the
-- keystream state left after that block. Both length counters start at zero.
initialize' :: ByteArrayAccess key
            => key -> Bytes -> CryptoFailable State
initialize' key nonce
    | B.length key /= 32 = CryptoFailed CryptoError_KeySizeInvalid
    | otherwise          =
        let chacha0           = ChaCha.initialize 20 key nonce
            (block0, chacha1) = ChaCha.generate chacha0 64
            macKey            = B.take 32 (block0 :: ScrubbedBytes)
            mac0              = throwCryptoError (Poly1305.initialize macKey)
         in CryptoPassed (State chacha1 mac0 0 0)
-- | Feed a chunk of additional authenticated data into the MAC.
--
-- The data is authenticated but not encrypted, and the AAD byte counter grows
-- by the chunk length. May be called several times. Call 'finalizeAAD' once
-- all AAD has been appended and before 'encrypt' / 'decrypt', so the AAD
-- section of the MAC input is padded ahead of the ciphertext.
appendAAD :: ByteArrayAccess ba => ba -> State -> State
appendAAD aad (State encState macState aadLength plainLength) =
    State encState macState' aadLength' plainLength
  where
    macState'  = Poly1305.update macState aad
    aadLength' = aadLength + fromIntegral (B.length aad)
-- | Close the AAD section by zero-padding the MAC input to a 16-byte boundary.
--
-- The AAD byte count itself is left unchanged; it is needed later by 'finalize'.
finalizeAAD :: State -> State
finalizeAAD (State encState macState aadLength plainLength) =
    let padded = Poly1305.update macState (pad16 aadLength)
     in State encState padded aadLength plainLength
-- | Encrypt a chunk of plaintext, returning the ciphertext and the new state.
--
-- The produced ciphertext (not the plaintext) is fed to the MAC, and the
-- plaintext length counter grows by the chunk length.
encrypt :: ByteArray ba => ba -> State -> (ba, State)
encrypt input (State encState macState aadLength plainLength) =
    let (ciphertext, encState') = ChaCha.combine encState input
        macState'               = Poly1305.update macState ciphertext
        plainLength'            = plainLength + fromIntegral (B.length input)
     in (ciphertext, State encState' macState' aadLength plainLength')
-- | Decrypt a chunk of ciphertext, returning the plaintext and the new state.
--
-- The incoming ciphertext is fed to the MAC, and the length counter grows by
-- the chunk length. This mirrors 'encrypt', so both sides compute the same tag.
decrypt :: ByteArray ba => ba -> State -> (ba, State)
decrypt input (State encState macState aadLength plainLength) =
    let (plaintext, encState') = ChaCha.combine encState input
        macState'              = Poly1305.update macState input
        plainLength'           = plainLength + fromIntegral (B.length input)
     in (plaintext, State encState' macState' aadLength plainLength')
-- | Finish the computation and return the Poly1305 authentication tag.
--
-- Zero-pads the ciphertext section to a 16-byte boundary. Then it appends the
-- AAD length and the ciphertext length, each as a 64-bit little-endian
-- integer, before finalizing the MAC.
finalize :: State -> Poly1305.Auth
finalize (State _ macState aadLength plainLength) =
    Poly1305.finalize (Poly1305.updates macState [pad16 plainLength, lengths])
  where
    -- Two 8-byte fields exactly fill the 16-byte buffer, so P.fill cannot
    -- fail here; the error branch is unreachable.
    lengths :: Bytes
    lengths = either (error "finalize: internal error") id $
        P.fill 16 $ do
            P.putStorable (toLE aadLength)
            P.putStorable (toLE plainLength)