{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ViewPatterns          #-}
{-# LANGUAGE GADTs                 #-}
module Crypto.ECC.Ed25519BIP32 where

import qualified Crypto.Hash as C (SHA512, SHA256)
import qualified Crypto.MAC.HMAC as C
import qualified Data.ByteArray as B
import qualified Data.ByteString as BS
import           Data.Bits
import           Data.Kind (Type)
import           Data.Word
import           Data.Proxy
import qualified Crypto.Math.Edwards25519 as ED25519
import           Data.Type.Bool
import           Data.Type.Equality
import           GHC.TypeLits
import           Data.Function (on)
import           Unsafe.Coerce (unsafeCoerce)

import           Crypto.Math.Bits
import           Crypto.Math.Bytes (Bytes)
import qualified Crypto.Math.Bytes as Bytes

-- | A master secret: a 256-bit random quantity.
type MasterSecret = FBits 256

-- | A child key has the same structure as a 'Key'. The 'DerivationHier'
-- parameter records the derivation indexes leading to it from a base
-- 'Key' (usually the root key). Because this is a type synonym, the
-- parameter is documentation only and is not checked by the compiler.
type ChildKey (didxs :: DerivationHier) = Key

-- | An extended private key: a 512-bit secret split into a left and a
-- right 256-bit half, together with a 'ChainCode'.
--
-- The left half must have:
--
-- * its lowest 3 bits clear
-- * its highest bit clear
-- * its second highest bit set
-- * its third highest bit clear
--
-- The right half has no particular structure.
type Key = (FBits 256, FBits 256, ChainCode)

-- | The public part of a key: a compressed point and the chain code.
type Public = (PointCompressed, ChainCode)

-- | A compressed point: 1 bit of x sign and 255 bits of y coordinate
-- (y's 256th bit is always 0).
type PointCompressed = FBits 256

-- | A 256-bit (32-byte) chain code, used as the HMAC key during
-- derivation.
newtype ChainCode = ChainCode { unChainCode :: Bytes 32 }
    deriving (Eq)

-- | An @n@-bit digest.
newtype Hash n = Hash { unHash :: FBits n }
    deriving (Eq)

-- | The serialized one-byte tag placed in front of the HMAC input.
type Tag = Bytes 1

-- | A derivation index serialized to 4 bytes.
newtype SerializedIndex = SerializedIndex (Bytes 4)
    deriving (Eq)

-- | The 64-byte output of HMAC-SHA512.
type HMAC_SHA512 = Bytes 64

-- | Hardened ('Hard') or non-hardened ('Soft') derivation.
data DerivationType = Hard | Soft

-- | Which derivation output an HMAC is computed for: the child chain code
-- or the child key.
data DerivationMaterial = ChainCodeMaterial | KeyMaterial

-- | A derivation index of type @k@ and value @n@, known only at the type
-- level (the single constructor carries no runtime data).
data DerivationIndex (k :: DerivationType) (n :: Nat) = DerivationIndex

-- | A list of derivation indexes, terminated by 'DerivationEnd'.
data DerivationHier = Nat :> DerivationHier | DerivationEnd

-- Bounds of the derivation index space: soft indexes fill the range
-- below 'MinHardIndex', hard indexes the range from it up to 0xffffffff.
type MinSoftIndex = 0
type MaxSoftIndex = MinHardIndex - 1
type MinHardIndex = 0x80000000
type MaxHardIndex = 0xffffffff

-- | Runtime evidence for 'ValidDerivationIndex': a proof that @n@ is in
-- the full index range, or a proof that it is not.
data ValidIndex :: Nat -> Type where
    IsValidIndex    :: (ValidDerivationIndex n :~: 'True)  -> ValidIndex n
    IsNotValidIndex :: (ValidDerivationIndex n :~: 'False) -> ValidIndex n

-- | Runtime evidence for 'ValidDerivationHardIndex'.
data ValidHardIndex :: Nat -> Type where
    IsValidHardIndex    :: (ValidDerivationHardIndex n :~: 'True)  -> ValidHardIndex n
    IsNotValidHardIndex :: (ValidDerivationHardIndex n :~: 'False) -> ValidHardIndex n

-- | Runtime evidence for 'ValidDerivationSoftIndex'.
data ValidSoftIndex :: Nat -> Type where
    IsValidSoftIndex    :: (ValidDerivationSoftIndex n :~: 'True)  -> ValidSoftIndex n
    IsNotValidSoftIndex :: (ValidDerivationSoftIndex n :~: 'False) -> ValidSoftIndex n

-- | Return a proof that @n@ is a valid derivation index, or 'Nothing'
-- when it is outside ['MinSoftIndex', 'MaxHardIndex'].
getValidIndex :: KnownNat n => Proxy n -> Maybe (ValidDerivationIndex n :~: 'True)
getValidIndex proxy =
    case isValidIndex proxy of
        IsValidIndex Refl -> Just Refl
        IsNotValidIndex _ -> Nothing

-- | Decide at runtime, using 'natVal', whether @n@ lies in
-- ['MinSoftIndex', 'MaxHardIndex'], and return the matching evidence.
--
-- GHC cannot reduce 'ValidDerivationIndex' for an abstract @n@, so the
-- equality proofs are produced with 'unsafeCoerce'. This is sound only
-- because the runtime comparison uses exactly the same bounds as the
-- type-level definition of 'ValidDerivationIndex'; keep the two in sync.
isValidIndex :: KnownNat n => Proxy n -> ValidIndex n
isValidIndex proxy =
    if lower <= value && value <= upper
        then IsValidIndex (unsafeCoerce Refl)
        else IsNotValidIndex (unsafeCoerce Refl)
  where
    value = natVal proxy
    lower = natVal (Proxy @MinSoftIndex)
    upper = natVal (Proxy @MaxHardIndex)

-- | Return a proof that @n@ is a valid hardened index, or 'Nothing' when
-- it is outside ['MinHardIndex', 'MaxHardIndex'].
getValidHardIndex :: KnownNat n => Proxy n -> Maybe (ValidDerivationHardIndex n :~: 'True)
getValidHardIndex proxy =
    case isValidHardIndex proxy of
        IsValidHardIndex Refl -> Just Refl
        IsNotValidHardIndex _ -> Nothing

-- | Decide at runtime, using 'natVal', whether @n@ lies in
-- ['MinHardIndex', 'MaxHardIndex'], and return the matching evidence.
--
-- The proofs are produced with 'unsafeCoerce' because GHC cannot reduce
-- 'ValidDerivationHardIndex' for an abstract @n@. This is sound only
-- while the runtime bounds here match the type-level definition.
isValidHardIndex :: KnownNat n => Proxy n -> ValidHardIndex n
isValidHardIndex proxy =
    if lower <= value && value <= upper
        then IsValidHardIndex (unsafeCoerce Refl)
        else IsNotValidHardIndex (unsafeCoerce Refl)
  where
    value = natVal proxy
    lower = natVal (Proxy @MinHardIndex)
    upper = natVal (Proxy @MaxHardIndex)

-- | Return a proof that @n@ is a valid soft index, or 'Nothing' when it is
-- outside ['MinSoftIndex', 'MaxSoftIndex'].
getValidSoftIndex :: KnownNat n => Proxy n -> Maybe (ValidDerivationSoftIndex n :~: 'True)
getValidSoftIndex proxy =
    case isValidSoftIndex proxy of
        IsValidSoftIndex Refl -> Just Refl
        IsNotValidSoftIndex _ -> Nothing

-- | Decide at runtime, using 'natVal', whether @n@ lies in
-- ['MinSoftIndex', 'MaxSoftIndex'], and return the matching evidence.
--
-- The proofs are produced with 'unsafeCoerce' because GHC cannot reduce
-- 'ValidDerivationSoftIndex' for an abstract @n@. This is sound only
-- while the runtime bounds here match the type-level definition.
isValidSoftIndex :: KnownNat n => Proxy n -> ValidSoftIndex n
isValidSoftIndex proxy =
    if lower <= value && value <= upper
        then IsValidSoftIndex (unsafeCoerce Refl)
        else IsNotValidSoftIndex (unsafeCoerce Refl)
  where
    value = natVal proxy
    lower = natVal (Proxy @MinSoftIndex)
    upper = natVal (Proxy @MaxSoftIndex)

-- | Whether @n@ lies anywhere in ['MinSoftIndex', 'MaxHardIndex'].
type ValidDerivationIndex (n :: Nat) =
    (MinSoftIndex <=? n) && (n <=? MaxHardIndex)

-- | Whether @n@ lies in the hardened range ['MinHardIndex', 'MaxHardIndex'].
type ValidDerivationHardIndex (n :: Nat) =
    (MinHardIndex <=? n) && (n <=? MaxHardIndex)

-- | Whether @n@ lies in the soft range ['MinSoftIndex', 'MaxSoftIndex'].
type ValidDerivationSoftIndex (n :: Nat) =
    (MinSoftIndex <=? n) && (n <=? MaxSoftIndex)

-- | The range check that applies to an index of derivation type @k@.
type family ValidDerivationIndexForType (k :: DerivationType) (n :: Nat) :: Bool where
    ValidDerivationIndexForType 'Hard n = ValidDerivationHardIndex n
    ValidDerivationIndexForType 'Soft n = ValidDerivationSoftIndex n

-- | Tag value for each derivation type and material: 0x0 / 0x1 for hard
-- key / chain-code derivation, 0x2 / 0x3 for soft key / chain-code
-- derivation.
type family DerivationTag (ty :: DerivationType) (material :: DerivationMaterial) :: Nat where
    DerivationTag 'Hard 'KeyMaterial       = 0x0
    DerivationTag 'Hard 'ChainCodeMaterial = 0x1
    DerivationTag 'Soft 'KeyMaterial       = 0x2
    DerivationTag 'Soft 'ChainCodeMaterial = 0x3

-- | Check that the left half @kL@ of an extended private key has the
-- required structure:
--
-- * lowest 3 bits (bits 0, 1, 2) clear
-- * highest bit (bit 255) clear
-- * second highest bit (bit 254) set
-- * third highest bit (bit 253) clear
--
-- Bit indexes follow 'testBit', with bit 0 the least significant bit of
-- the 256-bit value. The "lowest 3 bits" check on bits 0-2 already relies
-- on this convention, so for a 256-bit value the top three bits are 255,
-- 254 and 253, not bits inside the low 32 bits.
leftHalfValid :: FBits 256 -> Bool
leftHalfValid v =
    and [ not (testBit v 0)    -- lowest 3 bits clear
        , not (testBit v 1)
        , not (testBit v 2)
        , not (testBit v 255)  -- highest bit clear
        , testBit v 254        -- second highest bit set
        , not (testBit v 253)  -- third highest bit clear
        ]

-- | The public part of a key: the point for the left secret half, paired
-- with the key's chain code. The right secret half is not used.
toPublic :: Key -> Public
toPublic key = case key of
    (leftHalf, _, chainCode) -> (kToPoint leftHalf, chainCode)

-- | Map a 256-bit scalar to its compressed point. The scalar is encoded as
-- 32 little-endian bytes, turned into an 'ED25519.Scalar' with
-- 'ED25519.scalarP', and passed to 'ED25519.scalarToPoint'.
kToPoint :: FBits 256 -> PointCompressed
kToPoint =
    pointFromRepr
        . ED25519.scalarToPoint
        . ED25519.scalarP
        . Bytes.fromBits Bytes.LittleEndian

-- | Add two points. Both operands are converted with 'pointToRepr',
-- summed with 'ED25519.pointAdd', and the result is converted back with
-- 'pointFromRepr'.
pointAdd :: PointCompressed -> PointCompressed -> PointCompressed
pointAdd a b = pointFromRepr ((ED25519.pointAdd `on` pointToRepr) a b)

-- | Convert this module's 'PointCompressed' (an 'FBits' 256) into
-- 'ED25519.PointCompressed', via its 32-byte little-endian encoding.
pointToRepr :: PointCompressed -> ED25519.PointCompressed
pointToRepr = ED25519.pointCompressedP . Bytes.fromBits Bytes.LittleEndian

-- | Inverse of 'pointToRepr': read the 32 bytes of an
-- 'ED25519.PointCompressed' as a little-endian 256-bit value.
pointFromRepr :: ED25519.PointCompressed -> PointCompressed
pointFromRepr point =
    Bytes.toBits Bytes.LittleEndian (ED25519.unPointCompressedP point)

-- | The hash algorithm whose digest is @n@ bits wide (256 or 512).
type family BitsToHashScheme (n :: Nat) :: Type where
    BitsToHashScheme 256 = C.SHA256
    BitsToHashScheme 512 = C.SHA512

-- | Constraint stating that @tag@ lies in the range 0..3.
type ValidTag (tag :: Nat) = (0 <= tag, tag <= 3)

-- | Compute HMAC-SHA512, keyed with the chain code, over the message
-- @tag || material || index@, where:
--
-- * @tag@ is the single byte 'DerivationTag' @deriveType deriveMaterial@;
-- * @material@ is the caller-supplied derivation material;
-- * @index@ is the 4-byte serialization of @idx@ from 'indexSerialized'.
--
-- The two leading proxies and the 'DerivationIndex' argument only fix type
-- variables; their values are never inspected.
fcp :: forall tag idx deriveType deriveMaterial
     . ( KnownNat (DerivationTag deriveType deriveMaterial)
       , KnownNat idx
       , (DerivationTag deriveType deriveMaterial) ~ tag
       , ValidDerivationIndex idx ~ 'True
       , ValidDerivationIndexForType deriveType idx ~ 'True
       )
    => Proxy deriveMaterial          -- ^ selects key vs chain-code tag
    -> Proxy deriveType              -- ^ selects hard vs soft tag
    -> Proxy idx                     -- ^ derivation index to serialize
    -> ChainCode                     -- ^ HMAC key
    -> DerivationIndex deriveType idx
    -> [Word8]                       -- ^ derivation material
    -> HMAC_SHA512
fcp _ _ indexProxy (ChainCode hmacKey) _ material =
    Bytes.packSome (hmacSHA512 hmacKey) message
  where
    message :: [Word8]
    message = concat [Bytes.unpack tagByte, material, Bytes.unpack serializedIndex]

    SerializedIndex serializedIndex = indexSerialized indexProxy

    tagByte :: Tag
    tagByte = Bytes.pack [fromInteger (natVal (Proxy @tag))]

-- | HMAC-SHA512 of @msg@ under @key@, returned as 64 bytes. Both inputs
-- are converted to strict 'BS.ByteString's before calling 'C.hmac'.
hmacSHA512 :: Bytes keyLength -> Bytes input -> HMAC_SHA512
hmacSHA512 key msg = Bytes.pack (B.unpack digest)
  where
    toByteString :: Bytes n -> BS.ByteString
    toByteString = BS.pack . Bytes.unpack

    mac :: C.HMAC C.SHA512
    mac = C.hmac (toByteString key) (toByteString msg)

    digest = C.hmacGetDigest mac

-- | The bytes passed to 'fcp' as derivation material, for a derivation
-- type @dtype@ and a parent value of type @mat@.
class GetDerivationMaterial (dtype :: DerivationType) mat where
    getDerivationMaterial :: Proxy dtype -> mat -> [Word8]

-- | Soft derivation from a private key: use the material of its public
-- point (the left half of 'toPublic').
instance GetDerivationMaterial 'Soft Key where
    getDerivationMaterial p = getDerivationMaterial p . fst . toPublic

-- | Hard derivation: both secret halves, @kL@ followed by @kR@, each
-- encoded as 32 little-endian bytes.
instance GetDerivationMaterial 'Hard Key where
    getDerivationMaterial _ (kl, kr, _) =
        Bytes.unpack (Bytes.append (toBytes kl) (toBytes kr))
      where
        toBytes :: FBits 256 -> Bytes 32
        toBytes = Bytes.fromBits Bytes.LittleEndian

-- | Soft derivation from a public point: its 32-byte little-endian
-- encoding.
instance GetDerivationMaterial 'Soft PointCompressed where
    getDerivationMaterial _ = Bytes.unpack . Bytes.fromBits Bytes.LittleEndian

-- | Derive a child private key from a parent 'Key' at the type-level index
-- @idx@, using hard or soft derivation according to @dtype@.
--
-- 1. @Z@ is 'fcp' with the key-material tag over the parent's derivation
--    material.
-- 2. 'step2' splits @Z@ into @8*ZL@ and @ZR@. The child halves are
--    @8*ZL + kL@ and @ZR + kR@, computed with 'FBits' 256 arithmetic.
-- 3. The child chain code is a second 'fcp' with the chain-code tag,
--    reduced to 32 bytes by 'Bytes.drop'.
--
-- NOTE(review): no validity check is performed on the resulting left
-- half; confirm whether callers are expected to reject any children.
derive :: forall dtype idx .
          ( KnownNat (DerivationTag dtype 'KeyMaterial)
          , KnownNat (DerivationTag dtype 'ChainCodeMaterial)
          , KnownNat idx
          , ValidDerivationIndex idx ~ 'True
          , ValidDerivationIndexForType dtype idx ~ 'True
          , GetDerivationMaterial dtype Key)
       => DerivationIndex dtype idx
       -> Key
       -> Key
derive idx key@(kl, kr, parentCC) = (zl8 + kl, zr + kr, childCC)
  where
    -- The same material feeds both HMACs.
    material :: [Word8]
    material = getDerivationMaterial (Proxy @dtype) key

    keyHmac, ccHmac :: HMAC_SHA512
    keyHmac = fcp (Proxy @'KeyMaterial) (Proxy @dtype) (Proxy @idx) parentCC idx material
    ccHmac  = fcp (Proxy @'ChainCodeMaterial) (Proxy @dtype) (Proxy @idx) parentCC idx material

    (zl8, zr) = step2 keyHmac

    childCC :: ChainCode
    childCC = ChainCode (Bytes.drop ccHmac)

-- | Derive a child public point and chain code from a parent public point
-- and chain code. Only soft derivation is possible.
--
-- 1. @Z@ is 'fcp' with the key-material tag over the parent point's
--    derivation material; only the @8*ZL@ part from 'step2' is used.
-- 2. The child point is @kToPoint (8*ZL)@ added to the parent point.
-- 3. The child chain code is a second 'fcp' with the chain-code tag,
--    reduced to 32 bytes by 'Bytes.drop'.
derivePublic :: forall idx dtype .
          ( dtype ~ 'Soft -- can only derive public stuff with Soft Derivation
          , KnownNat (DerivationTag dtype 'KeyMaterial)
          , KnownNat (DerivationTag dtype 'ChainCodeMaterial)
          , KnownNat idx
          , ValidDerivationIndex idx ~ 'True
          , ValidDerivationIndexForType dtype idx ~ 'True
          , GetDerivationMaterial dtype PointCompressed)
       => DerivationIndex 'Soft idx
       -> PointCompressed
       -> ChainCode
       -> (PointCompressed, ChainCode)
derivePublic idx parentPoint parentCC = (childPoint, childCC)
  where
    -- The same material feeds both HMACs.
    material :: [Word8]
    material = getDerivationMaterial (Proxy @dtype) parentPoint

    keyHmac, ccHmac :: HMAC_SHA512
    keyHmac = fcp (Proxy @'KeyMaterial) (Proxy @dtype) (Proxy @idx) parentCC idx material
    ccHmac  = fcp (Proxy @'ChainCodeMaterial) (Proxy @dtype) (Proxy @idx) parentCC idx material

    -- Only 8*ZL is needed for a public derivation; ZR is ignored.
    (zl8, _) = step2 keyHmac

    childPoint :: PointCompressed
    childPoint = pointAdd (kToPoint zl8) parentPoint

    childCC :: ChainCode
    childCC = ChainCode (Bytes.drop ccHmac)

-- | Given @Z@, return @(8 * ZL', ZR)@.
--
-- @Z@ is split into two 32-byte halves @ZL@ and @ZR@. @ZL'@ is the first
-- 28 bytes of @ZL@ (kept by 'Bytes.take'), read little-endian and
-- zero-extended back to 256 bits by appending a 32-bit zero. @ZR@ is read
-- little-endian as-is.
step2 :: Bytes 64 -> (FBits 256, FBits 256)
step2 z = (8 * extendedZl, Bytes.toBits Bytes.LittleEndian zRight)
  where
    zLeft, zRight :: Bytes 32
    (zLeft, zRight) = Bytes.splitHalf z

    truncatedZl :: Bytes 28
    truncatedZl = Bytes.take zLeft

    extendedZl :: FBits 256
    extendedZl = (0 :: FBits 32) `append` Bytes.toBits Bytes.LittleEndian truncatedZl

-- | Serialize a type-level derivation index as 4 little-endian bytes.
-- The 'ValidDerivationIndex' constraint bounds @idx@ by 'MaxHardIndex'
-- (0xffffffff), so the value fits in 32 bits.
indexSerialized :: forall idx . (KnownNat idx, ValidDerivationIndex idx ~ 'True)
                => Proxy idx
                -> SerializedIndex
indexSerialized p =
    let indexBits = fromInteger (natVal p) :: FBits 32
     in SerializedIndex (Bytes.fromBits Bytes.LittleEndian indexBits)