{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An extension of "Network.TypedProtocol.Channel", with additional 'Channel'
-- implementations.
--
module Network.Mux.Channel
  ( Channel (..)
  , createBufferConnectedChannels
  , createPipeConnectedChannels
#if !defined(mingw32_HOST_OS)
  , createSocketConnectedChannels
#endif
  , withFifosAsChannel
  , socketAsChannel
  , channelEffect
  , delayChannel
  , loggingChannel
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Internal as LBS (smallChunkSize)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as Socket
import qualified System.IO as IO (Handle, IOMode (..), hFlush, hIsEOF,
                                  hSetBinaryMode, withFile)
import qualified System.Process as IO (createPipe)

import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadSay
import           Control.Monad.Class.MonadTimer


-- | A bidirectional byte channel in the monad @m@, exchanging lazy
-- 'LBS.ByteString's.
data Channel m = Channel
    { -- | Write bytes to the channel.
      --
      -- It may raise exceptions.
      send :: LBS.ByteString -> m ()

      -- | Read some input from the channel, or @Nothing@ to indicate EOF.
      --
      -- Note that having received EOF it is still possible to send.
      -- The EOF condition is however monotonic.
      --
      -- It may raise exceptions (as appropriate for the monad and kind of
      -- channel).
    , recv :: m (Maybe LBS.ByteString)
    }


-- | Make a 'Channel' from a pair of IO 'IO.Handle's, one for reading and one
-- for writing.
--
-- The handles should be open in the appropriate read or write mode, and in
-- binary mode. Every 'send' flushes the write handle after writing, so it is
-- safe to use a buffering mode.
--
-- For bidirectional handles it is safe to pass the same handle for both.
--
-- 'recv' returns @Nothing@ once 'IO.hIsEOF' reports end of input on the read
-- handle; otherwise it returns one chunk obtained with 'BS.hGetSome', capped
-- at 'LBS.smallChunkSize' bytes.
--
handlesAsChannel :: IO.Handle -- ^ Read handle
                 -> IO.Handle -- ^ Write handle
                 -> Channel IO
handlesAsChannel hndRead hndWrite =
    Channel { send = writeAndFlush, recv = readChunk }
  where
    -- Write the whole payload, then flush so nothing lingers in the buffer.
    writeAndFlush :: LBS.ByteString -> IO ()
    writeAndFlush bytes = LBS.hPut hndWrite bytes >> IO.hFlush hndWrite

    -- Check for EOF first so that an exhausted handle yields 'Nothing'
    -- rather than an empty chunk.
    readChunk :: IO (Maybe LBS.ByteString)
    readChunk = IO.hIsEOF hndRead >>= \atEof ->
      if atEof
        then return Nothing
        else Just . LBS.fromStrict <$> BS.hGetSome hndRead LBS.smallChunkSize

-- | Create a pair of 'Channel's that are connected internally.
--
-- This is intended for inter-thread communication, such as between a
-- multiplexing thread and a thread running a peer.
--
-- It uses lazy 'ByteString's but it ensures that data written to the channel
-- is /fully evaluated/ first. This ensures that any work to serialise the data
-- takes place on the /writer side and not the reader side/.
--
-- Each direction is backed by one 'TMVar'. 'send' puts the strict chunks of
-- its payload into it one at a time, each in its own transaction, so a writer
-- waits while the previous chunk is still unread. 'recv' takes exactly one
-- chunk per call and never reports EOF.
--
createBufferConnectedChannels :: forall m. MonadSTM m
                              => m (Channel m,
                                    Channel m)
createBufferConnectedChannels = do
    bufA <- newEmptyTMVarIO
    bufB <- newEmptyTMVarIO
    -- The first end reads what the second writes, and vice versa.
    return (connect bufB bufA, connect bufA bufB)
  where
    -- One end of the pair: incoming chunks are taken from @inbox@, outgoing
    -- chunks are put into @outbox@.
    connect :: TMVar m BS.ByteString -> TMVar m BS.ByteString -> Channel m
    connect inbox outbox = Channel { send = push, recv = pull }
      where
        -- The bang forces each chunk /before/ the STM transaction that
        -- writes it to the buffer.
        push :: LBS.ByteString -> m ()
        push = mapM_ (\ !c -> atomically (putTMVar outbox c)) . LBS.toChunks

        pull :: m (Maybe LBS.ByteString)
        pull = do
          chunk <- atomically (takeTMVar inbox)
          return (Just (LBS.fromStrict chunk))


-- | Create a local pipe, with both ends in this process, and expose that as
-- a pair of 'Channel's, one for each end.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createPipeConnectedChannels :: IO (Channel IO,
                                   Channel IO)
createPipeConnectedChannels = do
    -- A single pipe only carries data one way, so two are needed for a
    -- bidirectional channel: the first carries B -> A, the second A -> B.
    (readA, writeB) <- IO.createPipe
    (readB, writeA) <- IO.createPipe
    pure ( handlesAsChannel readA writeA
         , handlesAsChannel readB writeB
         )

-- | Open a pair of Unix FIFOs, and expose that as a 'Channel'.
--
-- The peer process needs to open the same files but the other way around,
-- for writing and reading.
--
-- This is primarily for the purpose of demonstrations that use communication
-- between multiple local processes. It is Unix specific.
--
-- The read FIFO is opened before the write FIFO. Both handles are put into
-- binary mode before the 'Channel' is built, and both are closed by
-- 'IO.withFile' when @action@ returns or throws.
--
withFifosAsChannel :: FilePath -- ^ FIFO for reading
                   -> FilePath -- ^ FIFO for writing
                   -> (Channel IO -> IO a) -> IO a
withFifosAsChannel fifoPathRead fifoPathWrite action =
    IO.withFile fifoPathRead  IO.ReadMode  $ \hndRead  ->
    IO.withFile fifoPathWrite IO.WriteMode $ \hndWrite -> do
      -- 'IO.withFile' opens handles in text mode, while 'handlesAsChannel'
      -- documents that its handles should be in binary mode.
      IO.hSetBinaryMode hndRead  True
      IO.hSetBinaryMode hndWrite True
      action (handlesAsChannel hndRead hndWrite)


-- | Make a 'Channel' from a 'Socket.Socket'. The socket must be a stream
-- socket type and status connected.
--
-- 'send' hands all chunks of the payload to 'Socket.sendMany'. 'recv' reads
-- at most 'LBS.smallChunkSize' bytes and treats a zero-length read as EOF.
--
socketAsChannel :: Socket.Socket -> Channel IO
socketAsChannel socket =
    Channel { send = sendChunks, recv = recvChunk }
  where
    -- Use vectored writes.
    -- TODO: limit write sizes, or break them into multiple sends.
    sendChunks :: LBS.ByteString -> IO ()
    sendChunks = Socket.sendMany socket . LBS.toChunks

    -- We rely on the behaviour of stream sockets that a zero length chunk
    -- indicates EOF.
    recvChunk :: IO (Maybe LBS.ByteString)
    recvChunk = do
      bytes <- Socket.recv socket LBS.smallChunkSize
      if BS.null bytes
        then return Nothing
        else return . Just $ LBS.fromStrict bytes

#if !defined(mingw32_HOST_OS)
-- | Create a local socket, with both ends in this process, and expose that as
-- a pair of 'Channel's, one for each end.
--
-- Both ends come from a single 'Socket.socketPair' call using a stream socket
-- and the default protocol, and each end is wrapped with 'socketAsChannel'.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createSocketConnectedChannels :: Socket.Family -- ^ Usually AF_UNIX or AF_INET
                              -> IO (Channel IO,
                                     Channel IO)
createSocketConnectedChannels family =
    -- A socket pair already provides both ends of a bidirectional channel.
    fmap (\(sockA, sockB) -> (socketAsChannel sockA, socketAsChannel sockB))
         (Socket.socketPair family Socket.Stream Socket.defaultProtocol)
#endif

-- | Decorate a 'Channel' with extra effects.
--
-- @beforeSend@ runs on each payload just before it is passed to the
-- underlying 'send'. @afterRecv@ runs on every result of the underlying
-- 'recv', including @Nothing@ for EOF, before that result is returned
-- unchanged.
channelEffect :: forall m.
                 Monad m
              => (LBS.ByteString -> m ())       -- ^ Action before 'send'
              -> (Maybe LBS.ByteString -> m ()) -- ^ Action after 'recv'
              -> Channel m
              -> Channel m
channelEffect beforeSend afterRecv Channel{send, recv} =
    Channel { send = \payload -> beforeSend payload >> send payload
            , recv = recv >>= \result -> afterRecv result >> return result
            }

-- | Delay a channel on the receiver end.
--
-- Every 'recv' result, including EOF, is followed by a 'threadDelay' of the
-- given duration before it is handed back. 'send' is passed through with no
-- added effect.
--
-- This is intended for testing, as a crude approximation of network delays.
-- More accurate models along these lines are of course possible.
--
delayChannel :: ( MonadSTM m
                , MonadTimer m
                )
             => DiffTime
             -> Channel m
             -> Channel m
delayChannel delay =
    channelEffect (const (return ()))
                  (const (threadDelay delay))

-- | Channel which logs sent and received messages.
--
-- Each payload is reported with 'say' as @show ident ++ ":send:" ++ show a@
-- before it is sent. Each received payload is reported as
-- @show ident ++ ":recv:" ++ show a@ after it is received. EOF (@Nothing@
-- from 'recv') is passed through without a log line.
--
loggingChannel :: ( MonadSay m
                  , Show id
                  )
               => id
               -> Channel m
               -> Channel m
loggingChannel ident Channel{send, recv} =
    Channel { send = sendLogged, recv = recvLogged }
  where
    -- Prefix a rendered payload with the channel identifier and direction.
    tagged direction payload = show ident ++ direction ++ show payload

    sendLogged payload = say (tagged ":send:" payload) >> send payload

    recvLogged = do
      result <- recv
      -- 'mapM_' over 'Maybe' logs only when a payload was received.
      mapM_ (say . tagged ":recv:") result
      return result