{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Copyright: © 2018-2021 IOHK
-- License: Apache-2.0
--
-- Generating and verifying hashes of wallet passwords.
--

module Cardano.Wallet.Primitive.Passphrase.Current
    ( encryptPassphrase
    , checkPassphrase
    , preparePassphrase
    , genSalt
    ) where

import Prelude

import Cardano.Wallet.Primitive.Passphrase.Types
    ( ErrWrongPassphrase (..), Passphrase (..), PassphraseHash (..) )
import Control.Monad
    ( unless )
import Crypto.KDF.PBKDF2
    ( Parameters (..), fastPBKDF2_SHA512 )
import Crypto.Random.Types
    ( MonadRandom (..) )
import Data.ByteArray
    ( ScrubbedBytes )
import Data.ByteString
    ( ByteString )
import Data.Coerce
    ( coerce )
import Data.Function
    ( on )

import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS

-- | Encrypt a 'Passphrase' into a format that is suitable for storing on disk.
--
-- A salt is obtained from 'genSalt' and the result is laid out as:
--
-- @
-- [ salt length (1 byte) | salt | PBKDF2-HMAC-SHA512(passphrase, salt) ]
-- @
--
-- The key derivation uses 20000 iterations and a 64-byte output (see
-- @params@ below). 'checkPassphrase' parses this layout to recover the salt,
-- so the layout and the parameters must stay in sync with it.
--
-- Because the salt comes from 'MonadRandom', the hash can also be recomputed
-- deterministically by choosing a monad whose 'getRandomBytes' yields a known
-- salt; 'checkPassphrase' instantiates @m@ as a function of the stored salt
-- for that purpose.
encryptPassphrase
    :: MonadRandom m
    => Passphrase "encryption"
    -> m PassphraseHash
encryptPassphrase (Passphrase bytes) = mkPassphraseHash <$> genSalt
  where
    -- Serialise length prefix, salt and derived key into one hash value.
    -- The prefix is a single byte, so this layout can only describe salts of
    -- at most 255 bytes ('genSalt' requests 16).
    mkPassphraseHash :: Passphrase "salt" -> PassphraseHash
    mkPassphraseHash (Passphrase salt) =
        PassphraseHash $ BA.convert $ BS.concat
            [ BS.singleton (fromIntegral (BA.length salt))
            , BA.convert salt
            , fastPBKDF2_SHA512 params bytes salt
            ]

    params :: Parameters
    params = Parameters
        { iterCounts = 20000
        , outputLength = 64
        }

-- | Draw a 16-byte salt from the 'MonadRandom' source, for use by
-- 'encryptPassphrase'.
genSalt :: MonadRandom m => m (Passphrase "salt")
genSalt = Passphrase <$> getRandomBytes saltLength
  where
    -- Must fit in the single length byte that prefixes the stored hash.
    saltLength :: Int
    saltLength = 16

-- | Turn a passphrase as entered by the user into the form that is hashed by
-- 'encryptPassphrase' and verified by 'checkPassphrase'.
--
-- For this scheme the bytes are used unchanged: only the phantom purpose tag
-- changes, which 'coerce' does without copying.
preparePassphrase :: Passphrase "user" -> Passphrase "encryption"
preparePassphrase = coerce

-- | Check a prepared passphrase against a hash produced by
-- 'encryptPassphrase'.
--
-- The salt is read back from the stored hash, the passphrase is re-hashed
-- with that salt, and the two hashes are compared. Returns
-- @'Left' 'ErrWrongPassphrase'@ when they differ, or when the stored value is
-- empty and so has no salt-length byte.
checkPassphrase
    :: Passphrase "encryption"
    -> PassphraseHash
    -> Either ErrWrongPassphrase ()
checkPassphrase prepared stored = do
    salt <- getSalt (BA.convert stored)
    -- 'encryptPassphrase' is run in the function monad
    -- @(->) (Passphrase "salt")@, so applying it to the stored salt
    -- recomputes the hash instead of drawing a new salt.
    -- NOTE(review): this relies on a 'MonadRandom' instance for that monad
    -- defined outside this module; presumably it returns the supplied salt.
    let recomputed = encryptPassphrase prepared salt
    unless (constantTimeEq recomputed stored) $
        Left ErrWrongPassphrase
  where
    -- Decode the salt from the @[length | salt | ...]@ prefix written by
    -- 'encryptPassphrase'. A truncated salt is not rejected here: its
    -- recomputed length byte differs from the stored one, so the comparison
    -- above fails.
    getSalt :: ByteString -> Either ErrWrongPassphrase (Passphrase "salt")
    getSalt bytes = case BS.uncons bytes of
        Just (len, rest) ->
            Right $ Passphrase $ BA.convert $ BS.take (fromIntegral len) rest
        Nothing ->
            Left ErrWrongPassphrase

    -- Compare without short-circuiting, so timing does not reveal how many
    -- leading bytes matched.
    -- NOTE(review): this relies on the 'Eq' instance of 'ScrubbedBytes' being
    -- constant-time (the @memory@ package implements it with memConstEqual);
    -- confirm if that dependency changes.
    constantTimeEq :: PassphraseHash -> PassphraseHash -> Bool
    constantTimeEq = (==) `on` BA.convert @_ @ScrubbedBytes