{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Cardano.Crypto.Wallet.Pure
    ( XPrv(..)
    , XPub(..)
    , xprvPub
    , deriveXPrv
    , deriveXPrvHardened
    , deriveXPub
    , hInitSeed
    , hFinalize
    ) where

import           Control.DeepSeq             (NFData)
import qualified Crypto.Math.Edwards25519    as Edwards25519
import           Crypto.Hash                 (SHA512)
import qualified Crypto.MAC.HMAC             as HMAC
--import qualified Crypto.PubKey.Ed25519       as Ed25519
import           Data.Bits
import           Data.ByteArray              (ByteArrayAccess, convert)
import qualified Data.ByteArray              as B (splitAt)
import           Data.ByteString             (ByteString)
import qualified Data.ByteString             as B (pack)
import           Data.Hashable               (Hashable)
import           Data.Word
import           GHC.Generics                (Generic)

import           Cardano.Crypto.Wallet.Types

-- | Extended private key: a secret Edwards25519 scalar paired with the
-- chain code used for child-key derivation. Both fields are strict.
data XPrv
    = XPrv
        !Edwards25519.Scalar -- ^ secret scalar
        !ChainCode           -- ^ chain code

-- | Extended public key: a compressed Edwards25519 point paired with the
-- chain code used for child-key derivation. Both fields are strict.
data XPub = XPub !Edwards25519.PointCompressed !ChainCode
    deriving (Eq, Ord, Show, Generic)

-- Both instances use the 'Generic'-based default methods.
instance NFData XPub
instance Hashable XPub

-- | Raw bytes of the compressed public point that corresponds to the secret
-- scalar of an extended private key. The chain code is ignored.
xprvPub :: XPrv -> ByteString
xprvPub (XPrv s _) = Edwards25519.unPointCompressed (Edwards25519.scalarToPoint s)

-- | Derive the non-hardened child @n@ of an extended private key.
--
-- The steps are:
--
-- 1. Compute the parent public point from the secret scalar.
-- 2. Hash it with the parent chain code and the index ('walletHash').
-- 3. Read the left half of the MAC as a scalar and add it to the parent
--    scalar to form the child scalar.
-- 4. Use the right half of the MAC as the child chain code.
deriveXPrv :: XPrv -> Word32 -> XPrv
deriveXPrv (XPrv sec ccode) n =
    let !pub     = Edwards25519.scalarToPoint sec
        (iL, iR) = walletHash (DerivationHashNormal pub ccode n)
        -- forced (with pub) before the child key is built
        !derived = Edwards25519.scalar iL
     in XPrv (Edwards25519.scalarAdd sec derived) iR

-- | Derive the hardened child @n@ of an extended private key.
--
-- The parent secret scalar (not the public point) is hashed with the parent
-- chain code and the index ('walletHash'). The left half of the MAC, read as
-- a scalar, becomes the child scalar. The right half becomes the child chain
-- code.
--
-- NOTE(review): unlike 'deriveXPrv', the parent scalar is not added to the
-- derived scalar here. Confirm this asymmetry is intended by the scheme.
deriveXPrvHardened :: XPrv -> Word32 -> XPrv
deriveXPrvHardened (XPrv sec ccode) n =
    let (iL, iR) = walletHash (DerivationHashHardened sec ccode n)
     in XPrv (Edwards25519.scalar iL) iR

-- | Derive the non-hardened child @n@ of an extended public key.
--
-- The steps are:
--
-- 1. Hash the parent point with the parent chain code and the index
--    ('walletHash').
-- 2. Read the left half of the MAC as a scalar and map it to a point.
-- 3. Add that point to the parent point to form the child point.
-- 4. Use the right half of the MAC as the child chain code.
deriveXPub :: XPub -> Word32 -> XPub
deriveXPub (XPub pub ccode) n =
    let (iL, iR) = walletHash (DerivationHashNormal pub ccode n)
        -- forced before the child key is built
        !derived = Edwards25519.scalarToPoint (Edwards25519.scalar iL)
     in XPub (Edwards25519.pointAdd pub derived) iR


-- | Input to 'walletHash': the derivation kind, together with the parent key
-- material, the parent chain code and the child index.
data DerivationHash
    = DerivationHashHardened Edwards25519.Scalar ChainCode Word32
      -- ^ hardened derivation, using the parent secret scalar
    | DerivationHashNormal Edwards25519.PointCompressed ChainCode Word32
      -- ^ normal derivation, using the parent compressed public point

-- | HMAC-SHA512 for one derivation step.
--
-- The MAC is keyed with the parent chain code ('hInit'). Its message is, in
-- order:
--
-- 1. the 4-byte tag ('hardenedTag' or 'normalTag');
-- 2. the raw bytes of the parent secret scalar (hardened) or of the
--    compressed public point (normal);
-- 3. the big-endian index ('encodeIndex').
--
-- 'hFinalize' splits the 64-byte output into its left 32 bytes and the
-- child 'ChainCode'.
walletHash :: DerivationHash -> (ByteString, ChainCode)
walletHash derivation = case derivation of
    DerivationHashHardened sec cc w ->
        macStep cc hardenedTag (Edwards25519.unScalar sec) w
    DerivationHashNormal pub cc w ->
        macStep cc normalTag (Edwards25519.unPointCompressed pub) w
  where
    -- Shared pipeline for both derivation kinds. The absorption order matters:
    -- reordering the chunks changes every derived key.
    macStep :: ChainCode -> ByteString -> ByteString -> Word32 -> (ByteString, ChainCode)
    macStep cc tag keyBytes w =
        hFinalize (HMAC.updates (hInit cc) [tag, keyBytes, encodeIndex w])

-- | Tag that 'walletHash' feeds into the HMAC first for hardened derivation:
-- the ASCII bytes of @HARD@.
hardenedTag :: ByteString
hardenedTag = B.pack (map (fromIntegral . fromEnum) "HARD")
-- | Tag that 'walletHash' feeds into the HMAC first for normal derivation:
-- the ASCII bytes of @NORM@.
normalTag :: ByteString
normalTag = B.pack (map (fromIntegral . fromEnum) "NORM")

-- | Encode a derivation index as 4 bytes, most significant byte first
-- (big endian).
--
-- >>> B.unpack (encodeIndex 0x01020304)
-- [1,2,3,4]
encodeIndex :: Word32 -> ByteString
encodeIndex w = B.pack [ fromIntegral (w `shiftR` s) | s <- [24, 16, 8, 0] ]

-- | Start an HMAC-SHA512 computation keyed with the raw bytes of a chain
-- code.
hInit :: ChainCode -> HMAC.Context SHA512
hInit (ChainCode key) = HMAC.initialize key

-- | Start an HMAC-SHA512 computation keyed with arbitrary seed bytes.
hInitSeed :: ByteArrayAccess seed => seed -> HMAC.Context SHA512
hInitSeed = HMAC.initialize

-- | Finish an HMAC-SHA512 computation and split the 64-byte digest. The
-- first 32 bytes are returned as-is; the remaining bytes become the
-- 'ChainCode'.
hFinalize :: HMAC.Context SHA512 -> (ByteString, ChainCode)
hFinalize ctx = (iL, ChainCode iR)
  where
    digest :: ByteString
    digest = convert (HMAC.finalize ctx)
    (iL, iR) = B.splitAt 32 digest