-- |
-- Module      : Crypto.Math.Edwards25519
-- Description : Edwards 25519 arithmetic
-- Maintainer  : vincent@typed.io
--
-- Simple module for experimenting with the arithmetic of the twisted Edwards curve
-- Ed25519, using extended twisted Edwards coordinates. Unlike standard Ed25519
-- (which hashes and clamps the secret key), scalars here are used directly modulo
-- the group order, so the usual DH property holds:
--
-- for all valid scalars s1 and s2:
--
-- > scalarToPoint (s1 + s2) = pointAdd (scalarToPoint s1) (scalarToPoint s2)
--
-- For further useful references about Ed25519:
--
-- * RFC 8032
-- * <http://ed25519.cr.yp.to/>
-- * <http://ed25519.cr.yp.to/ed25519-20110926.pdf>
-- * <http://eprint.iacr.org/2008/522.pdf>
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE CPP #-}

module Crypto.Math.Edwards25519
    (
    -- * Basic types
      Scalar
    , PointCompressed
    , Signature(..)
    -- * smart constructor & destructor
    , scalar
    , scalarP
    , unScalar
    , pointCompressed
    , pointCompressedP
    , unPointCompressed
    , unPointCompressedP
    -- * Arithmetic
    , scalarFromInteger
    , scalarAdd
    , scalarToPoint
    , pointAdd
    -- * Signature & Verify
    , sign
    , verify
    ) where

import           Control.DeepSeq             (NFData)
import           Crypto.Hash
import           Crypto.Number.ModArithmetic
import           Crypto.Number.Serialize
import           Data.Bits
#if MIN_VERSION_memory(0,14,18)
import qualified Data.ByteArray              as B hiding (append, reverse)
#else
import qualified Data.ByteArray              as B hiding (append)
#endif
import           Data.ByteString             (ByteString)
import qualified Data.ByteString             as B (append, reverse)
import           Data.Hashable               (Hashable)
import           GHC.Stack

import           Crypto.Math.Bytes (Bytes)
import qualified Crypto.Math.Bytes as Bytes

-- | A curve scalar, stored as its 32-byte little-endian encoding.
--
-- Scalars act on the curve modulo the group order 'q' (not the base
-- field prime 'p'). No 'Show' instance is derived.
newtype Scalar = Scalar
    { unScalar :: ByteString -- ^ raw little-endian bytes
    }

-- | A point on the Edwards 25519 curve in its 32-byte compressed form:
-- the little-endian affine y-coordinate with the parity of x in bit 255.
newtype PointCompressed = PointCompressed
    { unPointCompressed :: ByteString -- ^ raw encoded bytes
    }
    deriving (Show, Eq, Ord, NFData, Hashable)

-- | A signature: the encoded commitment point @R@ followed by the
-- little-endian scalar @S@ (see 'sign').
newtype Signature = Signature
    { unSignature :: ByteString -- ^ raw @R || S@ bytes
    }
    deriving (Show, Eq, Ord, NFData, Hashable)

-- | An integer intended to be used modulo the group order 'q'.
-- Nothing in the type enforces that the value is actually reduced.
newtype Fq = Fq
    { unFq :: Integer
    }
-- newtype Fp = Fp { unFp :: Integer }

{- for debugging
fq :: HasCallStack => Integer -> Fq
fq n
    | n >= 0 && n < q = Fq n
    | otherwise       = error "fq"
-}

-- | Unchecked smart constructor for 'Fq': the integer is wrapped as-is,
-- with no range check.
fq :: Integer -> Fq
fq n = Fq n

-- | Build an Ed25519 scalar from raw bytes.
--
-- Only the length is validated (it must be exactly 32 bytes). On purpose,
-- the value is not reduced or range-checked; callers that need a reduced
-- scalar must arrange that themselves.
-- Calls 'error' with "invalid scalar" on any other length.
scalar :: ByteString -> Scalar
scalar bs =
    if B.length bs == 32
        then Scalar bs
        else error "invalid scalar"

-- | Build a scalar from a length-indexed 'Bytes' value by going through
-- 'scalar'.
scalarP :: Bytes 32 -> Scalar
scalarP bytes = scalar (B.pack (Bytes.unpack bytes))


-- Check if a scalar is valid and all the bits are properly set/cleared (not yet implemented):
-- scalarValid :: Scalar -> Bool
-- scalarValid _s = True -- TODO

-- | Smart constructor for a compressed point.
--
-- Only the length is validated (exactly 32 bytes). The bytes are not
-- checked to encode an actual curve point.
-- Calls 'error' with the offending length otherwise.
pointCompressed :: HasCallStack => ByteString -> PointCompressed
pointCompressed bs
    | len == 32 = PointCompressed bs
    | otherwise =
        error ("invalid compressed point: expecting 32 bytes, got " ++ show len ++ " bytes")
  where
    len = B.length bs

-- | Build a compressed point from a length-indexed 'Bytes' value by going
-- through 'pointCompressed'.
pointCompressedP :: Bytes 32 -> PointCompressed
pointCompressedP bytes = pointCompressed (B.pack (Bytes.unpack bytes))

-- | Expose a compressed point as a length-indexed 'Bytes' value.
unPointCompressedP :: PointCompressed -> Bytes 32
unPointCompressedP = Bytes.pack . B.unpack . unPointCompressed

-- | Sign a message with a variant of the Ed25519 signature scheme.
--
-- Instead of hashing the secret key to derive a key and a prefix, the
-- secret scalar @a@ is used directly, and the nonce prefix is derived from
-- @a@ together with an explicit @salt@:
--
-- * @prefix = SHA512(a || salt)@
-- * @r = SHA512(prefix || msg) mod q@, @R = [r]B@
-- * @h = SHA512(R || A || msg) mod q@, with @A = [a]B@
-- * @S = (r + h * a) mod q@
--
-- The result is @R || S@, each part 32 bytes.
sign :: B.ByteArrayAccess msg => Scalar -> ByteString -> msg -> Signature
sign a salt msg = Signature (unPointCompressed pR `B.append` toBytes s)
  where
    msgBytes :: ByteString
    msgBytes = B.convert msg

    secret :: Integer
    secret = fromBytes (unScalar a)

    prefix :: Digest SHA512
    prefix = hash (unScalar a `B.append` salt)

    -- deterministic nonce and its commitment point
    r :: Fq
    r = sha512_modq (B.convert prefix `B.append` msgBytes)
    pR :: PointCompressed
    pR = ePointCompress (ePointMul r pG)

    pA :: PointCompressed
    pA = scalarToPoint a

    -- challenge over R, the public key and the message
    h :: Fq
    h = sha512_modq (unPointCompressed pR `B.append` unPointCompressed pA `B.append` msgBytes)

    s :: Integer
    s = (unFq r + unFq h * secret) `mod` q

-- | Verify a signature produced by 'sign' against a public key @A@.
--
-- The signature is the 32-byte compressed point @R@ followed by the
-- 32-byte little-endian scalar @S@. It is accepted when
-- @[S]B == R + [h]A@, where @h = SHA512(R || A || msg) mod q@.
--
-- Signatures that are not exactly 64 bytes long, or whose @S@ is not
-- reduced modulo the group order @q@, are rejected up front. 'sign' never
-- produces them, and accepting them made signatures malleable: @S + q@
-- verified just like @S@, and trailing bytes were silently folded into
-- @S@. RFC 8032, section 5.1.7 requires @0 <= S < q@.
--
-- NOTE(review): @R@ and @A@ are not validated as curve points, because
-- 'ePointDecompress' has no failure mode.
verify :: B.ByteArrayAccess msg => PointCompressed -> msg -> Signature -> Bool
verify pA msg (Signature signature)
    | B.length signature /= 64 = False
    | unFq s >= q              = False
    | otherwise                = pS `pointEqual` ePointAdd (ePointDecompress pR) hA
  where
    -- Split the signature into the commitment R and the response S.
    (pR, s) =
        let (sig0, sig1) = B.splitAt 32 signature
         in (PointCompressed sig0, fq (fromBytes sig1))

    -- Equality of projective points: compare X/Z and Y/Z by
    -- cross-multiplying, so no inversion is needed.
    pointEqual :: ExtendedPoint -> ExtendedPoint -> Bool
    pointEqual (ExtendedPoint pX pY pZ _) (ExtendedPoint qX qY qZ _) =
        ((pX * qZ - qX * pZ) `mod` p == 0) && ((pY * qZ - qY * pZ) `mod` p == 0)

    h :: Fq
    h = sha512_modq (unPointCompressed pR `B.append` unPointCompressed pA `B.append` B.convert msg)

    pS :: ExtendedPoint
    pS = ePointMul s pG

    hA :: ExtendedPoint
    hA = ePointMul h (ePointDecompress pA)

-- | Add two scalars modulo the group order 'q'.
-- The result is re-encoded as 32 little-endian bytes.
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd (Scalar s1) (Scalar s2) = Scalar (toBytes total)
  where
    total :: Integer
    total = (fromBytes s1 + fromBytes s2) `mod` q

-- | Build a scalar from an integer, reduced modulo the group order 'q'.
-- Negative inputs are mapped into @[0, q)@. Mainly meant for debugging.
scalarFromInteger :: Integer -> Scalar
scalarFromInteger n = Scalar (toBytes reduced)
  where
    reduced :: Integer
    reduced = n `mod` q

-- | Add two compressed points.
--
-- Both points are decompressed, added in extended coordinates, and the
-- sum is compressed again. Decompression does not validate its input.
pointAdd :: PointCompressed -> PointCompressed -> PointCompressed
pointAdd p1 p2 = ePointCompress sumPoint
  where
    sumPoint :: ExtendedPoint
    sumPoint = ePointDecompress p1 `ePointAdd` ePointDecompress p2

-- | Map a scalar @s@ to the compressed point @[s mod q]B@, where @B@ is
-- the base point 'pG'.
scalarToPoint :: Scalar -> PointCompressed
scalarToPoint (Scalar sec) = ePointCompress (ePointMul k pG)
  where
    k :: Fq
    k = fq (fromBytes sec `mod` q)

-- | A curve point in extended twisted Edwards coordinates @(X, Y, Z, T)@,
-- where
--
-- >   x = X/Z
-- >   y = Y/Z
-- > x*y = T/Z
--
-- The derived 'Eq' compares raw coordinates. Two different
-- representations of the same affine point therefore compare unequal;
-- use cross-multiplication to test geometric equality.
data ExtendedPoint = ExtendedPoint !Integer !Integer !Integer !Integer
    deriving (Eq, Show)

-- | Add two points in extended coordinates.
--
-- This is the unified addition law for twisted Edwards curves with
-- @a = -1@ ("add-2008-hwcd-3" in the Explicit-Formulas Database):
--
-- > A = (Y1-X1)(Y2-X2)   B = (Y1+X1)(Y2+X2)   C = 2 d T1 T2   D = 2 Z1 Z2
-- > E = B-A   F = D-C   G = D+C   H = B+A
-- > X3 = E F   Y3 = G H   Z3 = F G   T3 = E H
--
-- A, B, C and D are reduced modulo 'p'. The output coordinates are left
-- unreduced; the next addition or the final compression reduces them.
ePointAdd :: ExtendedPoint -> ExtendedPoint -> ExtendedPoint
ePointAdd (ExtendedPoint x1 y1 z1 t1) (ExtendedPoint x2 y2 z2 t2) =
    let modP :: Integer -> Integer
        modP v = v `mod` p
        a = modP ((y1 - x1) * (y2 - x2))
        b = modP ((y1 + x1) * (y2 + x2))
        c = modP (2 * t1 * t2 * curveD)
        d = modP (2 * z1 * z2)
        e = b - a
        f = d - c
        g = d + c
        h = b + a
     in ExtendedPoint (e * f) (g * h) (f * g) (e * h)

-- | Scalar multiplication @[k]P@ using a Montgomery ladder over bits
-- 255 down to 0 of @k@. Bits of @k@ above 255 are ignored.
--
-- The ladder keeps the pair @(R0, R1)@ with @R1 - R0 = P@, starting from
-- @(identity, P)@.
--
-- NOTE(review): arithmetic is on 'Integer'; do not assume constant time.
ePointMul :: Fq -> ExtendedPoint -> ExtendedPoint
ePointMul (Fq n) = go 255 identity
  where
    identity :: ExtendedPoint
    identity = ExtendedPoint 0 1 1 0

    go :: Int -> ExtendedPoint -> ExtendedPoint -> ExtendedPoint
    go !idx !r0 !r1
        | idx < 0   = r1 `seq` r0
        | otherwise =
            let (r0', r1') = step (testBit n idx) r0 r1
             in go (idx - 1) r0' r1'

    -- One ladder step: a set bit advances R0, a clear bit doubles it.
    step :: Bool -> ExtendedPoint -> ExtendedPoint -> (ExtendedPoint, ExtendedPoint)
    step True  r0 r1 = (ePointAdd r0 r1, ePointAdd r1 r1)
    step False r0 r1 = (ePointAdd r0 r0, ePointAdd r0 r1)

-- | Compress a point into 32 bytes.
--
-- The point is converted to affine @(x, y)@ by dividing by @Z@. The
-- output is @y@ in little-endian form, with the parity of @x@ stored in
-- bit 255. Because @y < p < 2^255@, that bit is otherwise clear.
ePointCompress :: ExtendedPoint -> PointCompressed
ePointCompress (ExtendedPoint pX pY pZ _) = PointCompressed (toBytes encoded)
  where
    zInv :: Integer
    zInv = modp_inv pZ

    affineX, affineY :: Integer
    affineX = (pX * zInv) `mod` p
    affineY = (pY * zInv) `mod` p

    encoded :: Integer
    encoded
        | odd affineX = setBit affineY 255
        | otherwise   = affineY

-- | Decompress 32 bytes into an extended point with @Z = 1@.
--
-- Bit 255 carries the parity of @x@; the remaining bits are @y@.
-- @x@ is recovered with 'recoverX'.
--
-- NOTE(review): no validation is performed. @y >= p@ is not rejected,
-- and an encoding whose @y@ has no matching @x@ still yields a (bogus)
-- point.
ePointDecompress :: PointCompressed -> ExtendedPoint
ePointDecompress (PointCompressed bs) = ExtendedPoint x y 1 ((x * y) `mod` p)
  where
    encoded :: Integer
    encoded = fromBytes bs

    y :: Integer
    y = clearBit encoded 255

    x :: Integer
    x = recoverX y (testBit encoded 255)

-- | Recover the x-coordinate from @y@ and the requested parity of @x@.
--
-- Solves @x^2 = (y^2 - 1) / (d*y^2 + 1) (mod p)@:
--
-- 1. Take the candidate root @u^((p+3)/8)@, which works since @p = 5 mod 8@.
-- 2. If its square is not @u@, multiply it by a square root of -1.
-- 3. Negate it if its parity does not match @xSign@.
--
-- NOTE(review): it is never checked that a square root exists. Also,
-- @x = 0@ with @xSign = True@ returns @p@ rather than failing.
recoverX :: Integer -> Bool -> Integer
recoverX y xSign = fixSign (fixRoot candidate)
  where
    xx :: Integer
    xx = (y*y - 1) * modp_inv (curveD*y*y + 1)

    candidate :: Integer
    candidate = expFast xx ((p + 3) `div` 8) p

    fixRoot :: Integer -> Integer
    fixRoot r
        | (r*r - xx) `mod` p == 0 = r
        | otherwise               = (r * sqrtM1) `mod` p

    fixSign :: Integer -> Integer
    fixSign r
        | odd r == xSign = r
        | otherwise      = p - r

    -- 2^((p-1)/4) is a square root of -1 modulo p
    sqrtM1 :: Integer
    !sqrtM1 = expFast 2 ((p - 1) `div` 4) p

-- | Decode bytes as an unsigned little-endian integer.
-- The bytes are reversed first, then decoded big-endian with 'os2ip'.
fromBytes :: ByteString -> Integer
fromBytes bs = os2ip (B.reverse bs)

-- | Encode an integer as exactly 32 little-endian bytes.
-- It is encoded big-endian with 'i2ospOf_' and the bytes are then reversed.
--
-- NOTE(review): values >= 2^256 are presumably truncated by 'i2ospOf_'
-- rather than rejected; confirm against the cryptonite documentation.
toBytes :: Integer -> ByteString
toBytes n = B.reverse (i2ospOf_ 32 n)

-- | Multiplicative inverse modulo 'p', computed via Fermat's little theorem
-- as @x^(p-2) mod p@. The value 0 maps to 0.
modp_inv :: Integer -> Integer
modp_inv x = expFast x fermatExponent p
  where
    fermatExponent :: Integer
    fermatExponent = p - 2

-- | The base field prime @2^255 - 19@ (the "25519" in Ed25519).
p :: Integer
p = (1 `shiftL` 255) - 19

-- | The curve constant @d = -121665 / 121666 (mod p)@.
curveD :: Integer
curveD = (numerator * modp_inv denominator) `mod` p
  where
    numerator :: Integer
    numerator = negate 121665

    denominator :: Integer
    denominator = 121666

-- | The order of the prime-order subgroup generated by 'pG':
-- @2^252 + 27742317777372353535851937790883648493@.
q :: Integer
q = (1 `shiftL` 252) + 27742317777372353535851937790883648493

-- | The base point @B@ in extended coordinates (@Z = 1@).
--
-- Its affine y-coordinate is @4/5 mod p@. Its x-coordinate is the root
-- returned by 'recoverX' with the parity flag cleared, i.e. the even one.
pG :: ExtendedPoint
pG = ExtendedPoint baseX baseY 1 ((baseX * baseY) `mod` p)
  where
    baseY :: Integer
    !baseY = (4 * modp_inv 5) `mod` p

    baseX :: Integer
    !baseX = recoverX baseY False

-- | Hash the input with SHA-512, read the 64-byte digest as a
-- little-endian integer, and reduce it modulo the group order 'q'.
sha512_modq :: B.ByteArrayAccess ba => ba -> Fq
sha512_modq bs = Fq (fromBytes digestBytes `mod` q)
  where
    digest :: Digest SHA512
    digest = hash bs

    digestBytes :: ByteString
    digestBytes = B.convert digest