-- copied & adapted from cryptic
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ViewPatterns        #-}
module Crypto.Math.Bytes
    ( Bytes
    , Endian(..)
    , pack
    , packSome
    , unpack
    , fromBits
    , toBits
    , append
    , take
    , drop
    , splitHalf
    , trace
    ) where

import           Data.Proxy
import           Data.Word
import           Data.List (foldl')
import           GHC.Natural
import           GHC.TypeLits
import           Crypto.Math.NatMath
import           Data.Bits (shiftL)
import           Crypto.Math.Bits (FBits(..))
import           Prelude hiding (take, drop)
import qualified Prelude
import qualified Debug.Trace as Trace

-- | A sequence of bytes tagged with a type-level size @n@.
--
-- The constructor does not check that the list actually holds @n@ elements;
-- 'pack' is the length-checked way to build a value. 'unpack' returns the
-- underlying byte list.
newtype Bytes (n :: Nat) = Bytes { unpack :: [Word8] }
    deriving (Show, Eq)

-- | Byte order used when converting between byte sequences and bit values.
data Endian = LittleEndian | BigEndian
    deriving (Show, Eq)

-- | Wrap a list of bytes as @'Bytes' n@, checking that the list length
-- equals the type-level size @n@.
--
-- Calls 'error' when the length of the list differs from @n@.
pack :: forall n . KnownNat n => [Word8] -> Bytes n
pack l
    | n == len  = Bytes l
    | otherwise = error "packing failed: length not as expected"
  where
    -- actual number of bytes supplied
    len :: Int
    len = length l
    -- expected number of bytes, read from the type-level size
    n :: Int
    n = fromIntegral $ natVal (Proxy :: Proxy n)

-- | Pack a list of bytes whose length is only known at runtime, and hand the
-- resulting @'Bytes' n@ to a continuation that works for any known size.
--
-- The size is taken from the length of the list itself, so the length check
-- performed by 'pack' always succeeds here.
packSome :: (forall n . KnownNat n => Bytes n -> a) -> [Word8] -> a
packSome f l = case someNatVal (fromIntegral len) of
    -- a list length is never negative, so this branch is not expected to run
    Nothing                       -> error "impossible"
    Just (SomeNat (_ :: Proxy n)) -> f (pack l :: Bytes n)
  where
    len :: Int
    len = length l

-- | Reorder a byte list for the given endianness: the list is reversed for
-- 'LittleEndian' and returned unchanged for 'BigEndian'.
fixupBytes :: Endian -> [Word8] -> [Word8]
fixupBytes LittleEndian = reverse
fixupBytes BigEndian    = id

-- | Debugging helper: emit @cmd@ followed by @": "@ and the lowercase
-- hexadecimal rendering of the bytes through 'Debug.Trace.trace', then
-- return the bytes unchanged.
trace :: String -> Bytes n -> Bytes n
trace cmd b@(Bytes l) = Trace.trace (cmd ++ ": " ++ concatMap toHex l) b
  where
    -- two hex digits per byte, high nibble first
    toHex :: Word8 -> String
    toHex w = let (x, y) = w `divMod` 16 in [hex x, hex y]

    -- single hex digit for a value in the range 0..15
    hex :: Word8 -> Char
    hex i
        | i < 10    = toEnum (fromEnum '0' + fromIntegral i)
        | otherwise = toEnum (fromEnum 'a' + fromIntegral (i - 10))

-- | Transform bytes into bits with a specific endianness.
--
-- The bytes are first reordered with 'fixupBytes' and then accumulated from
-- the head of the list, each step shifting the accumulator left by 8 bits.
-- With 'BigEndian' the first byte of the input is therefore the most
-- significant one; with 'LittleEndian' it is the least significant one.
toBits :: Endian -> Bytes n -> FBits (n * 8)
toBits endian (Bytes l) = FBits $
    foldl' (\acc i -> (acc `shiftL` 8) + fromIntegral i) 0 (fixupBytes endian l)

-- | Transform bits into bytes with a specific endianness.
--
-- Bytes are peeled off the underlying natural number 8 bits at a time,
-- lowest byte first, until @n@ bits have been consumed; any bits of the
-- number above position @n@ are not looked at. The collected bytes are then
-- reordered with 'fixupBytes'.
--
-- If @n@ is not a multiple of 8 the bit counter steps past @n@ and the
-- resulting list is an 'error' ("binFromFBits over").
fromBits :: forall n . KnownNat n => Endian -> FBits n -> Bytes (Div8 n)
fromBits endian (unFBits -> allBits) = Bytes $ loop [] (0 :: Word) allBits
  where
    -- total number of bits to extract
    n :: Integer
    n = natVal (Proxy :: Proxy n)

    -- acc: bytes extracted so far (most significant at the head)
    -- i:   number of bits consumed so far
    -- nat: the part of the number not consumed yet
    loop :: [Word8] -> Word -> Natural -> [Word8]
    loop acc i nat
        | fromIntegral i > n  = error "binFromFBits over"
        | fromIntegral i == n = fixupBytes endian acc
        | otherwise           =
            let (nat', b) = divMod8 nat
             in loop (b : acc) (i + 8) nat'

    -- split off the lowest byte of a natural number
    divMod8 :: Natural -> (Natural, Word8)
    divMod8 i = let (q, m) = i `divMod` 256 in (q, fromIntegral m)


-- | Split a byte sequence of size @m = n * 2@ into two halves of size @n@.
--
-- The underlying list is cut after its first @n@ elements, so the second
-- half holds exactly @n@ bytes only when the input list really contains @m@
-- bytes.
splitHalf :: forall m n . (KnownNat n, (n * 2) ~ m) => Bytes m -> (Bytes n, Bytes n)
splitHalf (Bytes l) = (Bytes l1, Bytes l2)
  where
    (l1, l2) = splitAt n l
    -- size of one half, read from the type-level size
    n :: Int
    n = fromIntegral $ natVal (Proxy :: Proxy n)

-- | Concatenate two byte sequences; the size of the result is the sum of the
-- sizes of the arguments. The bytes of the first argument come first.
append :: forall m n r . ((m + n) ~ r)
       => Bytes n -> Bytes m -> Bytes r
append (Bytes a) (Bytes b) = Bytes (a ++ b)

-- | Keep the first @n@ bytes of a sequence of size @m@, where @n <= m@.
take :: forall n m . (KnownNat n, n <= m) => Bytes m -> Bytes n
take (Bytes l) = Bytes $ Prelude.take (fromIntegral $ natVal (Proxy :: Proxy n)) l

-- | Drop the first @m - n@ bytes of a sequence of size @m@, where @n <= m@.
--
-- The result holds the trailing @n@ bytes provided the input list really
-- contains @m@ bytes.
drop :: forall n m . (KnownNat m, KnownNat n, n <= m) => Bytes m -> Bytes n
drop (Bytes l) = Bytes $ Prelude.drop (fromIntegral diff) l
  where
    -- number of leading bytes to discard
    diff :: Integer
    diff = natVal (Proxy :: Proxy m) - natVal (Proxy :: Proxy n)