{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Socket (socketAsMuxBearer) where

import           Control.Monad (when)
import           Control.Tracer
import qualified Data.ByteString.Lazy as BL
import           Data.Int

import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer hiding (timeout)

import qualified Network.Socket as Socket
#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as Socket (recv, sendAll)
#else
import qualified System.Win32.Async.Socket.ByteString.Lazy as Win32.Async
#endif

import qualified Network.Mux as Mx
import qualified Network.Mux.Codec as Mx
import qualified Network.Mux.Time as Mx
import qualified Network.Mux.Timeout as Mx
import qualified Network.Mux.Trace as Mx
import           Network.Mux.Types (MuxBearer)
import qualified Network.Mux.Types as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import           Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif

-- |
-- Create @'MuxBearer'@ from a socket.
--
-- On Windows "System.Win32.Async" operations are used to read and write from
-- a socket.  This means that the socket must be associated with the I/O
-- completion port with
-- 'System.Win32.Async.IOManager.associateWithIOCompletionPort'.
--
-- Reading: the first part of every SDU header is awaited without a timeout;
-- the remainder of the header and the SDU payload are read under the
-- 'Mx.TimeoutFn' passed to 'Mx.read' with the SDU timeout, and a
-- 'Mx.MuxSDUReadTimeout' error is thrown if it expires.
--
-- Writing: the encoded SDU is sent under the 'Mx.TimeoutFn' passed to
-- 'Mx.write' with the SDU timeout, and a 'Mx.MuxSDUWriteTimeout' error is
-- thrown if it expires.
--
-- The bearer advertises an SDU size of 12288 bytes.
--
-- Note: 'IOException's thrown by 'sendAll' and 'recv' are wrapped in
-- 'Mx.MuxError'.
--
socketAsMuxBearer
  :: DiffTime
  -- ^ SDU timeout, applied to reading the rest of an SDU once its first
  -- header bytes arrived, and to sending an SDU.
  -> Tracer IO Mx.MuxTrace
  -> Socket.Socket
  -> MuxBearer IO
socketAsMuxBearer sduTimeout tracer sd =
      Mx.MuxBearer {
        Mx.read    = readSocket,
        Mx.write   = writeSocket,
        Mx.sduSize = Mx.SDUSize 12288
      }
    where
      -- Number of bytes read as an SDU header before it is handed to
      -- 'Mx.decodeMuxSDU'.
      hdrLength :: Int64
      hdrLength = 8

      -- Read one complete SDU.  Returns the decoded SDU together with the
      -- monotonic time taken right after its payload was received.
      readSocket :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time)
      readSocket timeout = do
          traceWith tracer Mx.MuxTraceRecvHeaderStart

          -- Wait for the first part of the header without any timeout.
          h0 <- recvAtMost True hdrLength

          -- Optionally wait at most sduTimeout seconds for the complete SDU.
          r_m <- timeout sduTimeout $ recvRem h0
          case r_m of
            Nothing -> do
              traceWith tracer Mx.MuxTraceSDUReadTimeoutException
              throwIO $ Mx.MuxError Mx.MuxSDUReadTimeout "Mux SDU Timeout"
            Just r -> return r

      -- Given the already received prefix @h0@ of a header, read the rest
      -- of the header, decode it, then read exactly as many payload bytes as
      -- the decoded header's 'Mx.mhLength' announces.  A decoding failure is
      -- rethrown unchanged.
      recvRem :: BL.ByteString -> IO (Mx.MuxSDU, Time)
      recvRem !h0 = do
          hbuf <- recvLen' (hdrLength - BL.length h0) [h0]
          case Mx.decodeMuxSDU hbuf of
            Left e -> throwIO e
            Right header@Mx.MuxSDU { Mx.msHeader } -> do
              traceWith tracer $ Mx.MuxTraceRecvHeaderEnd msHeader
              !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []

              !ts <- getMonotonicTime
              let !header' = header { Mx.msBlob = blob }
              traceWith tracer (Mx.MuxTraceRecvDeltaQObservation msHeader ts)
              return (header', ts)

      -- Read exactly @l@ more bytes.  Received chunks are pushed onto
      -- @bufs@ (newest first) and concatenated in arrival order at the end.
      recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          buf <- recvAtMost False l
          recvLen' (l - BL.length buf) (buf : bufs)

      -- Receive at most @l@ bytes from the socket.  An empty result is
      -- reported as a 'Mx.MuxBearerClosed' error; if that happens while
      -- waiting for the next header (@waitingOnNxtHeader@), we first wait
      -- one second before throwing.
      recvAtMost :: Bool -> Int64 -> IO BL.ByteString
      recvAtMost waitingOnNxtHeader l = do
          traceWith tracer $ Mx.MuxTraceRecvStart $ fromIntegral l
#if defined(mingw32_HOST_OS)
          buf <- Win32.Async.recv sd (fromIntegral l)
#else
          buf <- Socket.recv sd l
#endif
                    `catch` Mx.handleIOException "recv errored"
          if BL.null buf
              then do
                  when waitingOnNxtHeader $
                      {- This may not be an error, but could be an orderly shutdown.
                       - We wait one second to give the mux protocols time to
                       - perform a clean up and exit.
                       -}
                      threadDelay 1
                  throwIO $ Mx.MuxError Mx.MuxBearerClosed (show sd ++
                      " closed when reading data, waiting on next header " ++
                      show waitingOnNxtHeader)
              else do
                  traceWith tracer $ Mx.MuxTraceRecvEnd (fromIntegral $ BL.length buf)
                  return buf

      -- Timestamp the SDU (via 'Mx.setTimestamp' with
      -- 'Mx.timestampMicrosecondsLow32Bits' of the current monotonic time),
      -- encode it and send it under the SDU timeout.  Returns the monotonic
      -- time used for the timestamp.
      writeSocket :: Mx.TimeoutFn IO -> Mx.MuxSDU -> IO Time
      writeSocket timeout sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeMuxSDU sdu'
          traceWith tracer $ Mx.MuxTraceSendStart (Mx.msHeader sdu')
          r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
              Win32.Async.sendAll sd buf
#else
              Socket.sendAll sd buf
#endif
              `catch` Mx.handleIOException "sendAll errored"
          case r of
            Nothing -> do
              traceWith tracer Mx.MuxTraceSDUWriteTimeoutException
              throwIO $ Mx.MuxError Mx.MuxSDUWriteTimeout "Mux SDU Timeout"
            Just _ -> do
              traceWith tracer Mx.MuxTraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
              -- If it were possible to detect whether MuxTraceTCPInfo is
              -- enabled we wouldn't have to hide the getSockOpt syscall
              -- behind this ifdef; instead we would only call it when we
              -- knew that the information would be traced.
              tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
              traceWith tracer $ Mx.MuxTraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu')
#endif
              return ts