{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Mux.Bearer.Socket (socketAsMuxBearer) where
import Control.Monad (when)
import Control.Tracer
import qualified Data.ByteString.Lazy as BL
import Data.Int
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer hiding (timeout)
import qualified Network.Socket as Socket
#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as Socket (recv, sendAll)
#else
import qualified System.Win32.Async.Socket.ByteString.Lazy as Win32.Async
#endif
import qualified Network.Mux as Mx
import qualified Network.Mux.Codec as Mx
import qualified Network.Mux.Time as Mx
import qualified Network.Mux.Timeout as Mx
import qualified Network.Mux.Trace as Mx
import Network.Mux.Types (MuxBearer)
import qualified Network.Mux.Types as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif
-- | Wrap a connected 'Socket.Socket' as a 'MuxBearer' running in 'IO'.
--
-- * __Reads__: the first bytes of the 8-byte SDU header are awaited with no
--   timeout. The rest of the header and the whole payload must then arrive
--   within @sduTimeout@, otherwise 'Mx.MuxSDUReadTimeout' is thrown.
-- * __Writes__: the SDU is stamped with the low 32 bits of the current
--   monotonic time (via 'Mx.timestampMicrosecondsLow32Bits'), encoded, and
--   must be sent within @sduTimeout@, otherwise 'Mx.MuxSDUWriteTimeout' is
--   thrown.
-- * 'IOException's raised by the socket calls are handed to
--   'Mx.handleIOException'.
-- * A zero-length read is reported as 'Mx.MuxBearerClosed'.
--
-- The bearer advertises an SDU size of 12288.
socketAsMuxBearer
  :: DiffTime
  -- ^ timeout for the remainder of each read, and for each write
  -> Tracer IO Mx.MuxTrace
  -- ^ receives send/receive progress and timeout events
  -> Socket.Socket
  -- ^ connected socket to read from and write to
  -> MuxBearer IO
socketAsMuxBearer sduTimeout tracer sd =
    Mx.MuxBearer
      { Mx.read    = readSocket
      , Mx.write   = writeSocket
      , Mx.sduSize = Mx.SDUSize 12288
      }
  where
    -- Number of bytes read and passed to Mx.decodeMuxSDU as the SDU header.
    hdrLength :: Int64
    hdrLength = 8

    -- Read one complete SDU.
    -- The first (possibly partial) chunk of the header is awaited with no
    -- timeout. Everything after it is bounded by sduTimeout.
    readSocket :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time)
    readSocket timeout = do
      traceWith tracer Mx.MuxTraceRecvHeaderStart
      h0  <- recvAtMost True hdrLength
      r_m <- timeout sduTimeout (recvRem h0)
      case r_m of
        Just r  -> return r
        Nothing -> do
          traceWith tracer Mx.MuxTraceSDUReadTimeoutException
          throwIO $ Mx.MuxError Mx.MuxSDUReadTimeout "Mux SDU Timeout"

    -- Complete the header that begins with h0 and decode it.
    -- Then read exactly Mx.mhLength bytes of payload.
    -- Returns the SDU with its blob filled in, plus the monotonic time
    -- sampled after the payload was read.
    -- Throws the decoder's MuxError if the header cannot be decoded.
    recvRem :: BL.ByteString -> IO (Mx.MuxSDU, Time)
    recvRem !h0 = do
      hbuf <- recvLen' (hdrLength - BL.length h0) [h0]
      case Mx.decodeMuxSDU hbuf of
        Left e -> throwIO e
        Right header@Mx.MuxSDU { Mx.msHeader } -> do
          traceWith tracer $ Mx.MuxTraceRecvHeaderEnd msHeader
          !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
          !ts   <- getMonotonicTime
          let !header' = header { Mx.msBlob = blob }
          traceWith tracer $ Mx.MuxTraceRecvDeltaQObservation msHeader ts
          return (header', ts)

    -- Read until @l@ more bytes have been received.
    -- Chunks are accumulated in reverse order and joined at the end.
    -- A non-positive remaining count terminates the loop, so the recursion
    -- cannot run away if a read ever returns more than was requested.
    recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
    recvLen' l bufs
      | l <= 0    = return $ BL.concat $ reverse bufs
      | otherwise = do
          buf <- recvAtMost False l
          recvLen' (l - BL.length buf) (buf : bufs)

    -- One socket read of at most @l@ bytes.
    -- An empty result means the peer closed the connection, and raises
    -- MuxBearerClosed. When the read was for the start of a new header, the
    -- thread first pauses for @1 :: DiffTime@ before throwing.
    -- NOTE(review): the pause presumably avoids a tight loop in callers that
    -- retry; confirm.
    recvAtMost :: Bool -> Int64 -> IO BL.ByteString
    recvAtMost waitingOnNxtHeader l = do
      traceWith tracer $ Mx.MuxTraceRecvStart $ fromIntegral l
#if defined(mingw32_HOST_OS)
      buf <- Win32.Async.recv sd (fromIntegral l)
#else
      buf <- Socket.recv sd l
#endif
                 `catch` Mx.handleIOException "recv errored"
      if BL.null buf
        then do
          when waitingOnNxtHeader $
            threadDelay 1
          throwIO $ Mx.MuxError Mx.MuxBearerClosed
            (show sd ++ " closed when reading data, waiting on next header "
                     ++ show waitingOnNxtHeader)
        else do
          traceWith tracer $ Mx.MuxTraceRecvEnd (fromIntegral $ BL.length buf)
          return buf

    -- Timestamp, encode and send one SDU within sduTimeout.
    -- Returns the monotonic time sampled before sending, which is also the
    -- time used for the SDU's timestamp.
    writeSocket :: Mx.TimeoutFn IO -> Mx.MuxSDU -> IO Time
    writeSocket timeout sdu = do
      ts <- getMonotonicTime
      let ts32 = Mx.timestampMicrosecondsLow32Bits ts
          sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
          buf  = Mx.encodeMuxSDU sdu'
      traceWith tracer $ Mx.MuxTraceSendStart (Mx.msHeader sdu')
      r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
             Win32.Async.sendAll sd buf
#else
             Socket.sendAll sd buf
#endif
               `catch` Mx.handleIOException "sendAll errored"
      case r of
        Nothing -> do
          traceWith tracer Mx.MuxTraceSDUWriteTimeoutException
          throwIO $ Mx.MuxError Mx.MuxSDUWriteTimeout "Mux SDU Timeout"
        Just _ -> do
          traceWith tracer Mx.MuxTraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
          tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
          traceWith tracer $ Mx.MuxTraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
          return ts