{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}

-- |Strict Decoder
module Flat.Decoder.Strict
  ( decodeArrayWith
  , decodeListWith
  , dByteString
  , dLazyByteString
  , dShortByteString
  , dShortByteString_
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
  , dUTF16
#endif
  , dUTF8
  , dInteger
  , dNatural
  , dChar
  , dWord8
  , dWord16
  , dWord32
  , dWord64
  , dWord
  , dInt8
  , dInt16
  , dInt32
  , dInt64
  , dInt
  ) where

import           Data.Bits
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as L
import qualified Data.ByteString.Short          as SBS
import qualified Data.ByteString.Short.Internal as SBS
import qualified Data.DList                     as DL
import           Flat.Decoder.Prim
import           Flat.Decoder.Types
import           Data.Int
import           Data.Primitive.ByteArray
import qualified Data.Text                      as T
import qualified Data.Text.Encoding             as T

#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
import qualified Data.Text.Array                as TA
import qualified Data.Text.Internal             as T
#endif

import           Data.Word
import           Data.ZigZag
import           GHC.Base                       (unsafeChr)
import           Numeric.Natural
#include "MachDeps.h"

{-# INLINE decodeListWith #-}
-- | Decode a list in which every element is preceded by a 'Bool' flag:
-- 'True' means another element follows, 'False' terminates the list.
decodeListWith :: Get a -> Get [a]
decodeListWith dec = loop
  where
    loop = dBool >>= \more ->
      if more
        then do
          x <- dec
          xs <- loop
          return (x : xs)
        else return []

-- | Decode an array stored as a sequence of length-prefixed blocks
-- (see 'getAsL_') and return its elements, in order, as a list.
decodeArrayWith :: Get a -> Get [a]
decodeArrayWith dec = fmap DL.toList (getAsL_ dec)

-- TODO: test whether it would be faster with DList.unfoldr :: (b -> Maybe (a, b)) -> b -> Data.DList.DList a
--  getAsL_ :: Flat a => Get (DL.DList a)
-- | Decode an array as a series of blocks. Each block starts with a 'Word8'
-- count of the elements it contains, followed by that many elements; a count
-- of 0 terminates the array.
getAsL_ :: Get a -> Get (DL.DList a)
getAsL_ dec = dWord8 >>= fromTag
  where
    -- A zero count ends the array; otherwise decode this block and recurse.
    fromTag 0 = return DL.empty
    fromTag count = do
      block <- elems count
      rest <- getAsL_ dec
      return (block `DL.append` rest)
    -- Decode exactly @k@ elements, preserving their order.
    elems 0 = return DL.empty
    elems k = DL.cons <$> dec <*> elems (k - 1)

{-# INLINE dNatural #-}
-- | Decode an arbitrary-precision 'Natural' stored as a variable-length run of
-- 7-bit groups, least significant group first (see 'dUnsigned').
dNatural :: Get Natural
dNatural = dUnsigned

{-# INLINE dInteger #-}
-- | Decode an 'Integer': an unbounded unsigned value ('Natural') that is
-- mapped back to a signed value with 'zagZig'.
dInteger :: Get Integer
dInteger = zagZig <$> (dUnsigned :: Get Natural)

{-# INLINE dWord #-}
{-# INLINE dInt #-}
-- | Decode a 'Word' with the fixed-width unsigned decoder whose size matches
-- the platform word (64 or 32 bits).
dWord :: Get Word
-- | Decode an 'Int' with the fixed-width signed decoder whose size matches
-- the platform word (64 or 32 bits).
dInt :: Get Int
#if WORD_SIZE_IN_BITS == 64
dWord = fromIntegral <$> dWord64
dInt = fromIntegral <$> dInt64
#elif WORD_SIZE_IN_BITS == 32
dWord = fromIntegral <$> dWord32
dInt = fromIntegral <$> dInt32
#else
#error expected WORD_SIZE_IN_BITS to be 32 or 64
#endif

{-# INLINE dInt8 #-}
-- | Decode an 'Int8' stored as a 'Word8' and mapped back with 'zagZig'.
dInt8 :: Get Int8
dInt8 = fmap zagZig dWord8

{-# INLINE dInt16 #-}
-- | Decode an 'Int16' stored as a 'Word16' and mapped back with 'zagZig'.
dInt16 :: Get Int16
dInt16 = fmap zagZig dWord16

{-# INLINE dInt32 #-}
-- | Decode an 'Int32' stored as a 'Word32' and mapped back with 'zagZig'.
dInt32 :: Get Int32
dInt32 = fmap zagZig dWord32

{-# INLINE dInt64 #-}
-- | Decode an 'Int64' stored as a 'Word64' and mapped back with 'zagZig'.
dInt64 :: Get Int64
dInt64 = fmap zagZig dWord64

-- {-# INLINE dWord16  #-}
-- | Decode a 'Word16' from at most three 7-bit groups, least significant
-- first. Groups are placed at bit offsets 0 and 7; the third (final) group at
-- offset 14 is validated by 'lastStep'.
dWord16 :: Get Word16
dWord16 =
  wordStep 0
    (wordStep 7
      (lastStep 14))
    0

-- {-# INLINE dWord32  #-}
-- | Decode a 'Word32' from at most five 7-bit groups, least significant
-- first. Groups are placed at bit offsets 0, 7, 14 and 21; the fifth (final)
-- group at offset 28 is validated by 'lastStep'.
dWord32 :: Get Word32
dWord32 =
  wordStep 0
    (wordStep 7
      (wordStep 14
        (wordStep 21
          (lastStep 28))))
    0

-- {-# INLINE dWord64  #-}
-- | Decode a 'Word64' from at most ten 7-bit groups, least significant
-- first. Groups are placed at bit offsets 0, 7, ..., 56; the tenth (final)
-- group at offset 63 is read by 'lastStep', which rejects it unless it
-- carries at most one significant bit.
--
-- The chain must stop at offset 63: a further group would be shifted past
-- bit 63 and silently discarded, so overflowing encodings would be accepted
-- instead of reported.
dWord64 :: Get Word64
dWord64 =
  wordStep 0
    (wordStep 7
      (wordStep 14
        (wordStep 21
          (wordStep 28
            (wordStep 35
              (wordStep 42
                (wordStep 49
                  (wordStep 56
                    (lastStep 63)))))))))
    0

{-# INLINE dChar #-}
-- | Decode a 'Char' from at most three 7-bit groups of its code point, least
-- significant first. The final group is read by 'lastCharStep', which rejects
-- code points above 0x10FFFF.
dChar :: Get Char
-- Simpler alternative: dChar = chr . fromIntegral <$> dWord32
-- (the step-based version below is not really faster than it).
dChar = charStep 0 (charStep 7 (lastCharStep 14)) 0

{-# INLINE charStep #-}
-- | Read one 7-bit group of a code point and OR it into the accumulator @n@
-- at bit offset @shl@. If the byte's continuation bit (0x80) is clear the
-- character is complete; otherwise the accumulator is passed on to @cont@.
charStep :: Int -> (Int -> Get Char) -> Int -> Get Char
charStep !shl !cont !n = do
  !byte <- fromIntegral <$> dWord8
  let !payload = byte .&. 127
      !acc = n .|. (payload `shift` shl)
  if byte /= payload
    then cont acc
    else return (unsafeChr acc)

{-# INLINE lastCharStep #-}
-- | Read the final 7-bit group of a code point at bit offset @shl@ and OR it
-- into the accumulator @n@. Fails if the continuation bit (0x80) is set or if
-- the resulting code point exceeds 0x10FFFF.
lastCharStep :: Int -> Int -> Get Char
lastCharStep !shl !n = do
  !byte <- fromIntegral <$> dWord8
  let !payload = byte .&. 127
      !acc = n .|. (payload `shift` shl)
  if byte == payload && acc <= 0x10FFFF
    then return (unsafeChr acc)
    else charErr acc
  where
    charErr v = fail ("Unexpected extra byte or non unicode char" ++ show v)

{-# INLINE wordStep #-}
-- | Read one 7-bit group and OR it into the accumulator @n@ at bit offset
-- @shl@. If the byte's continuation bit (0x80) is clear the value is
-- complete; otherwise the accumulator is handed to the continuation @k@.
--
-- The offset, byte and accumulator are strict (as in 'charStep') so that a
-- chain of steps passes evaluated values instead of building thunks.
wordStep :: (Bits a, Num a) => Int -> (a -> Get a) -> a -> Get a
wordStep !shl k !n = do
  !tw <- fromIntegral <$> dWord8
  let !w = tw .&. 127
      !v = n .|. (w `shift` shl)
  if tw == w
    then return v
    --else oneShot k v
    else k v

{-# INLINE lastStep #-}
-- | Read the final 7-bit group and OR it into the accumulator @n@ at bit
-- offset @shl@.
--
-- Fails if the continuation bit (0x80) is set. It also fails if the group
-- has more than @finiteBitSize - shl@ significant bits, i.e. bits that would
-- be shifted past the width of @b@ (checked via 'countLeadingZeros').
--
-- Offset, byte and accumulator are strict, consistent with 'wordStep' and
-- 'lastCharStep'.
lastStep :: (FiniteBits b, Show b, Num b) => Int -> b -> Get b
lastStep !shl !n = do
  !tw <- fromIntegral <$> dWord8
  let !w = tw .&. 127
      !v = n .|. (w `shift` shl)
  if tw == w
    then if countLeadingZeros w < shl
           then wordErr v
           else return v
    else wordErr v
  where
    wordErr v = fail $ concat ["Unexpected extra byte in unsigned integer", show v]

-- {-# INLINE dUnsigned #-}
-- | Decode an unsigned value of any 'Bits' type from 7-bit groups (see
-- 'dUnsigned_').
--
-- When 'bitSizeMaybe' reports a fixed size, decoding fails if the final group
-- starts at or beyond that size. Types whose 'bitSizeMaybe' is 'Nothing' are
-- never rejected.
dUnsigned :: (Num b, Bits b) => Get b
dUnsigned = do
  (v, shl) <- dUnsigned_ 0 0
  case bitSizeMaybe v of
    Just size
      | shl >= size -> fail "Unexpected extra data in unsigned integer"
    _ -> return v

-- {-# INLINE dUnsigned_ #-}
-- | Accumulate 7-bit groups into @n@, starting at bit offset @shl@ and
-- advancing by 7 per byte, until a byte with a clear continuation bit (0x80)
-- is read.
--
-- Returns the decoded value together with the bit offset of the final group.
--
-- The offset and accumulator are strict. Otherwise each byte of a long
-- encoding (e.g. a large 'Natural') would add another thunk to a chain that
-- is only collapsed when the result is finally demanded.
dUnsigned_ :: (Bits t, Num t) => Int -> t -> Get (t, Int)
dUnsigned_ !shl !n = do
  tw <- dWord8
  let w = tw .&. 127
      !v = n .|. (fromIntegral w `shift` shl)
  if tw == w
    then return (v, shl)
    else dUnsigned_ (shl + 7) v
--encode = encode . blob UTF8Encoding . L.fromStrict . T.encodeUtf8
--decode = T.decodeUtf8 . L.toStrict . (unblob :: BLOB UTF8Encoding -> L.ByteString) <$> decode
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
-- BLOB UTF16Encoding
-- | Decode a 'T.Text' whose payload (after the filler) is a byte array used
-- directly as the text's internal array, without any validation.
--
-- NOTE(review): the raw bytes become the 'TA.Array', and the length is counted
-- in 16-bit units (bytes divided by 2). This presumably assumes a
-- UTF-16-backed text representation; confirm against the text version bounds.
dUTF16 :: Get T.Text
dUTF16 = do
  _ <- dFiller
  -- Checked decoding:
  --   T.decodeUtf16LE <$> dByteString_
  -- Unchecked decoding (used here):
  (ByteArray arr, lengthInBytes) <- dByteArray_
  let units = lengthInBytes `div` 2
  return $ T.Text (TA.Array arr) 0 units
#endif
-- | Decode a UTF-8 'T.Text': filler bits followed by a byte string that is
-- validated with @T.decodeUtf8'@. Invalid UTF-8 makes the decoder fail.
dUTF8 :: Get T.Text
dUTF8 = do
  _ <- dFiller
  bytes <- dByteString_
  either invalid pure (T.decodeUtf8' bytes)
  where
    invalid e = fail $ concat ["Input contains invalid UTF-8 data", show e]
-- | Skip filler bits: keep reading 'Bool's until the first 'True'.
dFiller :: Get ()
dFiller = dBool >>= \done -> if done then return () else dFiller

-- | Decode a lazy 'L.ByteString' that is preceded by filler bits.
dLazyByteString :: Get L.ByteString
dLazyByteString = do
  dFiller
  dLazyByteString_

-- | Decode a 'SBS.ShortByteString' that is preceded by filler bits.
dShortByteString :: Get SBS.ShortByteString
dShortByteString = do
  dFiller
  dShortByteString_

-- | Decode a 'SBS.ShortByteString' (without a leading filler) by wrapping the
-- byte array produced by 'dByteArray_'; the length it returns is not used.
dShortByteString_ :: Get SBS.ShortByteString
dShortByteString_ = do
  (ByteArray bytes, _len) <- dByteArray_
  pure (SBS.SBS bytes)

-- | Decode a strict 'B.ByteString' that is preceded by filler bits.
dByteString :: Get B.ByteString
dByteString = do
  dFiller
  dByteString_