{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Pipe (
    PipeChannel (..)
  , pipeChannelFromHandles
#if defined(mingw32_HOST_OS)
  , pipeChannelFromNamedPipe
#endif
  , pipeAsMuxBearer
  ) where

import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Tracer
import qualified Data.ByteString.Lazy as BL
import           System.IO (Handle, hFlush)

#if defined(mingw32_HOST_OS)
import           Data.Foldable (traverse_)

import qualified System.Win32.Types as Win32 (HANDLE)
import qualified System.Win32.Async as Win32.Async
#endif

import qualified Network.Mux as Mx
import           Network.Mux.Types (MuxBearer)
import qualified Network.Mux.Types as Mx
import qualified Network.Mux.Trace as Mx
import qualified Network.Mux.Codec as Mx
import qualified Network.Mux.Time as Mx
import qualified Network.Mux.Timeout as Mx


-- | Abstraction over the two operations a pipe-based mux bearer needs:
-- reading a bounded number of bytes and writing a lazy 'BL.ByteString'.
--
-- Two ways of constructing a 'PipeChannel' are provided:
--
--  * 'pipeChannelFromHandles', based on 'Handle': OS independent, but does
--    not work well on Windows;
--  * @pipeChannelFromNamedPipe@, based on @Win32.HANDLE@: Windows specific.
--
data PipeChannel = PipeChannel {
    -- | Read up to the given number of bytes.  'pipeAsMuxBearer' treats an
    -- empty result as the pipe having been closed, and keeps calling this
    -- until it has accumulated the number of bytes it needs.
    readHandle  :: Int -> IO BL.ByteString,
    -- | Write the given bytes to the pipe.
    writeHandle :: BL.ByteString -> IO ()
  }

-- | Build a 'PipeChannel' from a pair of 'Handle's.
--
-- Reads are served by 'BL.hGet' on the read handle, so 'readHandle' returns
-- at most the requested number of bytes (an empty result at end of input).
-- Every write is followed by 'hFlush' on the write handle, so written data
-- is pushed out immediately rather than left in the handle's buffer.
pipeChannelFromHandles :: Handle
                       -- ^ read handle
                       -> Handle
                       -- ^ write handle
                       -> PipeChannel
pipeChannelFromHandles r w = PipeChannel {
    readHandle  = BL.hGet r,
    writeHandle = \a -> BL.hPut w a >> hFlush w
  }

#if defined(mingw32_HOST_OS)
-- | Build a 'PipeChannel' on top of a Windows named-pipe handle.  This makes
-- it possible to emulate anonymous pipes using named pipes on Windows.
--
-- Reads go through @Win32.Async.readHandle@, and the strict bytestring it
-- returns is converted to a lazy one.  A write sends the chunks of the lazy
-- bytestring one after another, each with its own
-- @Win32.Async.writeHandle@ call.
--
pipeChannelFromNamedPipe :: Win32.HANDLE
                         -> PipeChannel
pipeChannelFromNamedPipe h =
    PipeChannel { readHandle = readAsync, writeHandle = writeAsync }
  where
    -- Read up to @n@ bytes and widen the strict result to a lazy bytestring.
    readAsync n = BL.fromStrict <$> Win32.Async.readHandle h n

    -- Write every chunk of the lazy bytestring, in order.
    writeAsync bs = traverse_ (Win32.Async.writeHandle h) (BL.toChunks bs)
#endif

-- | Turn a 'PipeChannel' into a 'MuxBearer'.
--
-- Each SDU on the wire consists of an 8-byte header, decoded with
-- 'Mx.decodeMuxSDU', followed by a payload whose length is the header's
-- 'Mx.mhLength'.  The SDU size reported to the mux layer is fixed at
-- 32768 bytes.
--
-- The 'Mx.TimeoutFn' arguments passed to 'Mx.read' and 'Mx.write' are
-- ignored; no timeout is applied to pipe reads or writes.
--
-- Failure modes:
--
--  * an 'IOException' raised by the channel is handed to
--    'Mx.handleIOException' (with @"readHandle errored"@ or
--    @"writeHandle errored"@);
--  * a header that fails to decode is thrown as the 'Mx.MuxError' returned
--    by 'Mx.decodeMuxSDU';
--  * an empty read is thrown as a 'Mx.MuxBearerClosed' 'Mx.MuxError'.
pipeAsMuxBearer
  :: Tracer IO Mx.MuxTrace
  -> PipeChannel
  -> MuxBearer IO
pipeAsMuxBearer tracer channel =
      Mx.MuxBearer {
          Mx.read    = readPipe,
          Mx.write   = writePipe,
          Mx.sduSize = Mx.SDUSize 32768
        }
    where
      -- | Receive one SDU: read and decode the 8-byte header, then read
      -- exactly 'Mx.mhLength' payload bytes.  The returned 'Time' is taken
      -- after the payload has been fully received.
      readPipe :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time)
      readPipe _ = do
          traceWith tracer Mx.MuxTraceRecvHeaderStart
          hbuf <- recvLen' 8 []
          case Mx.decodeMuxSDU hbuf of
              Left e -> throwIO e
              Right header@Mx.MuxSDU { Mx.msHeader } -> do
                  traceWith tracer $ Mx.MuxTraceRecvHeaderEnd msHeader
                  blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
                  ts <- getMonotonicTime
                  traceWith tracer (Mx.MuxTraceRecvDeltaQObservation msHeader ts)
                  return (header {Mx.msBlob = blob}, ts)

      -- | Read exactly @l@ bytes from the channel, calling 'readHandle'
      -- repeatedly because a single read may return fewer bytes than asked
      -- for.  Chunks are accumulated newest-first in @bufs@ and reversed
      -- once at the end.
      recvLen' :: Int -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          traceWith tracer $ Mx.MuxTraceRecvStart l
          buf <- readHandle channel l
                    `catch` Mx.handleIOException "readHandle errored"
          if BL.null buf
              -- An empty read means the peer closed its end of the pipe.
              then throwIO $ Mx.MuxError Mx.MuxBearerClosed "Pipe closed when reading data"
              else do
                  traceWith tracer $ Mx.MuxTraceRecvEnd (fromIntegral $ BL.length buf)
                  recvLen' (l - fromIntegral (BL.length buf)) (buf : bufs)

      -- | Send one SDU.  The current monotonic time is stamped into the
      -- SDU header (the value produced by
      -- 'Mx.timestampMicrosecondsLow32Bits') before encoding, and that same
      -- 'Time' is returned to the caller.
      writePipe :: Mx.TimeoutFn IO -> Mx.MuxSDU -> IO Time
      writePipe _ sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeMuxSDU sdu'
          traceWith tracer $ Mx.MuxTraceSendStart (Mx.msHeader sdu')
          writeHandle channel buf
            `catch` Mx.handleIOException "writeHandle errored"
          traceWith tracer Mx.MuxTraceSendEnd
          return ts