{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}
module Network.DNS.StateBinary (
    PState(..)
  , initialState
  , SPut
  , runSPut
  , put8
  , put16
  , put32
  , putInt8
  , putInt16
  , putInt32
  , putByteString
  , SGet
  , runSGet
  , runSGetWithLeftovers
  , get8
  , get16
  , get32
  , getInt8
  , getInt16
  , getInt32
  , getNByteString
  , getPosition
  , getInput
  , wsPop
  , wsPush
  , wsPosition
  , addPositionW
  , push
  , pop
  , getNBytes
  ) where

import Control.Monad.State (State, StateT)
import qualified Control.Monad.State as ST
import qualified Data.Attoparsec.ByteString as A
import qualified Data.Attoparsec.Types as T
import qualified Data.ByteString as BS
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.Map (Map)
import qualified Data.Map as M
import Data.Semigroup as Sem

import Network.DNS.Imports
import Network.DNS.Types

----------------------------------------------------------------

-- | Serialiser: a state computation that yields a 'Builder'. The 'WState'
-- it threads counts the bytes emitted so far and holds the domain names
-- recorded with 'wsPush'.
type SPut = State WState Builder

-- | State threaded through an 'SPut' serialiser.
data WState = WState
  { wsDomain   :: Map Domain Int
    -- ^ Offsets recorded with 'wsPush', keyed by domain name.
  , wsPosition :: Int
    -- ^ Running byte count, advanced by 'addPositionW'.
  }

-- | Starting serialiser state: no domains recorded, position zero.
initialWState :: WState
initialWState = WState { wsDomain = M.empty, wsPosition = 0 }

-- | Run the left serialiser, then the right one (threading the state through
-- both in that order), and concatenate their builders.
instance Sem.Semigroup SPut where
    p1 <> p2 = do
        left  <- p1
        right <- p2
        return (left Sem.<> right)

-- | The empty serialiser: produces an empty builder and leaves the state
-- unchanged.
instance Monoid SPut where
    mempty = pure mempty
#if !(MIN_VERSION_base(4,11,0))
    mappend = (Sem.<>)
#endif

-- | Write a single byte, advancing the position by 1.
put8 :: Word8 -> SPut
put8 w = fixedSized 1 BB.word8 w

-- | Write a 16-bit word in big-endian order, advancing the position by 2.
put16 :: Word16 -> SPut
put16 w = fixedSized 2 BB.word16BE w

-- | Write a 32-bit word in big-endian order, advancing the position by 4.
put32 :: Word32 -> SPut
put32 w = fixedSized 4 BB.word32BE w

-- | Write an 'Int' as one byte. The value is narrowed with 'fromIntegral',
-- so only its low 8 bits are emitted. Advances the position by 1.
putInt8 :: Int -> SPut
putInt8 n = fixedSized 1 BB.int8 (fromIntegral n)

-- | Write the low 16 bits of an 'Int' in big-endian order, advancing the
-- position by 2.
putInt16 :: Int -> SPut
putInt16 n = fixedSized 2 BB.int16BE (fromIntegral n)

-- | Write the low 32 bits of an 'Int' in big-endian order, advancing the
-- position by 4.
putInt32 :: Int -> SPut
putInt32 n = fixedSized 4 BB.int32BE (fromIntegral n)

-- | Write a byte string verbatim, advancing the position by its length.
putByteString :: ByteString -> SPut
putByteString bs = writeSized BS.length BB.byteString bs

-- | Advance the serialiser's byte position by @n@, leaving the recorded
-- domains untouched.
addPositionW :: Int -> State WState ()
addPositionW n = ST.modify $ \st -> st { wsPosition = wsPosition st + n }

-- | Serialise a value whose encoded width is the fixed byte count @n@:
-- advance the position by @n@ and render the value with @f@.
fixedSized :: Int -> (a -> Builder) -> a -> SPut
fixedSized n f a = f a <$ addPositionW n

-- | Serialise a value whose encoded width depends on the value: @n@ computes
-- the byte count used to advance the position, and @f@ renders it.
writeSized :: (a -> Int) -> (a -> Builder) -> a -> SPut
writeSized n f a = fixedSized (n a) f a

-- | Look up the offset previously recorded for a domain with 'wsPush'.
-- Despite the name, the entry is not removed from the state.
wsPop :: Domain -> State WState (Maybe Int)
wsPop dom = ST.gets (M.lookup dom . wsDomain)

-- | Record @pos@ as the offset of domain @dom@, replacing any earlier entry.
-- The byte position is left unchanged.
wsPush :: Domain -> Int -> State WState ()
wsPush dom pos =
    ST.modify $ \st -> st { wsDomain = M.insert dom pos (wsDomain st) }

----------------------------------------------------------------

-- | Deserialiser: a 'PState' state transformer layered over an attoparsec
-- parser of strict 'ByteString' input.
type SGet = StateT PState (T.Parser ByteString)

-- | State threaded through an 'SGet' parser.
data PState = PState
  { psDomain   :: IntMap Domain
    -- ^ Domains recorded with 'push', keyed by the offset given there.
  , psPosition :: Int
    -- ^ Byte count advanced by the get* functions in this module.
  , psInput    :: ByteString
    -- ^ The input the parser was started on; never modified here.
  }

----------------------------------------------------------------

-- | Current parser byte position.
getPosition :: SGet Int
getPosition = ST.gets psPosition

-- | The input the parser was started on (as stored by 'initialState').
getInput :: SGet ByteString
getInput = ST.gets psInput

-- | Advance the parser byte position by @n@; other state is unchanged.
addPosition :: Int -> SGet ()
addPosition n = ST.modify $ \st -> st { psPosition = psPosition st + n }

-- | Record domain @d@ under offset @n@, replacing any earlier entry.
push :: Int -> Domain -> SGet ()
push n d = ST.modify $ \st -> st { psDomain = IM.insert n d (psDomain st) }

-- | Look up the domain recorded under offset @n@ with 'push'.
-- Despite the name, the entry is not removed.
pop :: Int -> SGet (Maybe Domain)
pop n = ST.gets (IM.lookup n . psDomain)

----------------------------------------------------------------

-- | Read one byte and advance the position by 1.
get8 :: SGet Word8
get8 = do
    w <- ST.lift A.anyWord8
    addPosition 1
    return w

-- | Read a big-endian 16-bit word and advance the position by 2.
get16 :: SGet Word16
get16 = ST.lift beWord16 <* addPosition 2
  where
    -- High byte first, then low byte.
    beWord16 :: A.Parser Word16
    beWord16 = do
        hi <- A.anyWord8
        lo <- A.anyWord8
        return (fromIntegral hi * 0x100 + fromIntegral lo)

-- | Read a big-endian 32-bit word and advance the position by 4.
get32 :: SGet Word32
get32 = ST.lift beWord32 <* addPosition 4
  where
    byte :: A.Parser Word32
    byte = fromIntegral <$> A.anyWord8

    -- Most significant byte first; combined Horner-style.
    beWord32 :: A.Parser Word32
    beWord32 = do
        b3 <- byte
        b2 <- byte
        b1 <- byte
        b0 <- byte
        return (((b3 * 0x100 + b2) * 0x100 + b1) * 0x100 + b0)

-- | Read one byte as an unsigned 'Int' (0..255), despite the name.
getInt8 :: SGet Int
getInt8 = fmap fromIntegral get8

-- | Read a big-endian 16-bit word as an unsigned 'Int' (0..65535).
getInt16 :: SGet Int
getInt16 = fmap fromIntegral get16

-- | Read a big-endian 32-bit word converted to 'Int' with 'fromIntegral'
-- (non-negative when 'Int' is wider than 32 bits).
getInt32 :: SGet Int
getInt32 = fmap fromIntegral get32

----------------------------------------------------------------

-- | Read @len@ bytes and return each one as an unsigned 'Int'.
getNBytes :: Int -> SGet [Int]
getNBytes len = do
    bs <- getNByteString len
    return [fromIntegral w | w <- BS.unpack bs]

-- | Consume @n@ bytes of input (failing if fewer remain) and return them.
--
-- The position is advanced by the number of bytes actually consumed rather
-- than by @n@: 'A.take' treats a negative count as zero, so a malformed,
-- negative length would otherwise move 'psPosition' backwards and put every
-- later offset recorded with 'push' out of step with the input. For
-- @n >= 0@ the behaviour is unchanged, since 'A.take' returns exactly @n@
-- bytes on success.
getNByteString :: Int -> SGet ByteString
getNByteString n = do
    bs <- ST.lift (A.take n)
    addPosition (BS.length bs)
    return bs

----------------------------------------------------------------

-- | Starting parser state for input @inp@: no domains, position zero.
initialState :: ByteString -> PState
initialState inp = PState { psDomain = IM.empty, psPosition = 0, psInput = inp }

-- | Run a parser over the whole input, starting from 'initialState'.
--
-- A parser failure becomes a 'DecodeError' carrying attoparsec's message. A
-- parser still asking for more input when the input runs out becomes
-- @DecodeError "incomplete input"@. Unconsumed trailing bytes are ignored.
runSGet :: SGet a -> ByteString -> Either DNSError (a, PState)
runSGet parser inp =
    case A.parse (ST.runStateT parser (initialState inp)) inp of
        A.Done _ res     -> Right res
        A.Fail _ _ msg   -> Left (DecodeError msg)
        A.Partial _      -> Left (DecodeError "incomplete input")

-- | Like 'runSGet', but also return the input left unconsumed by the parser.
--
-- If the parser asks for more input, end of input is signalled by feeding
-- an empty string instead of failing straight away. A parser failure becomes
-- a 'DecodeError' carrying attoparsec's message.
runSGetWithLeftovers :: SGet a -> ByteString -> Either DNSError ((a, PState), ByteString)
runSGetWithLeftovers parser inp =
    finish (A.parse (ST.runStateT parser (initialState inp)) inp)
  where
    finish :: A.Result r -> Either DNSError (r, ByteString)
    finish res = case res of
        A.Done rest r  -> Right (r, rest)
        A.Partial k    -> finish (k BS.empty)
        A.Fail _ _ err -> Left (DecodeError err)

-- | Run a serialiser from 'initialWState' and return its output as a strict
-- byte string; the final state is discarded.
runSPut :: SPut -> ByteString
runSPut sput =
    LBS.toStrict (BB.toLazyByteString (ST.evalState sput initialWState))