{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}

module Network.Mux.Ingress
  ( -- $ingress
    demuxer
  ) where

import           Data.Array
import qualified Data.ByteString.Lazy as BL
import           Data.List (nub)
import           Text.Printf

import           Control.Monad
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM.Strict
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer hiding (timeout)

import           Network.Mux.Timeout
import           Network.Mux.Trace
import           Network.Mux.Types


-- | Swap 'InitiatorDir' and 'ResponderDir'.
--
-- 'demuxer' applies this to the direction carried by an incoming SDU before
-- looking up the local ingress queue. Data the remote side sends in
-- 'ResponderDir' is delivered to the local 'InitiatorDir', and vice versa.
flipMiniProtocolDir :: MiniProtocolDir -> MiniProtocolDir
flipMiniProtocolDir dir =
  case dir of
    InitiatorDir -> ResponderDir
    ResponderDir -> InitiatorDir

-- $ingress
-- = Ingress Path
--
-- >                  ●
-- >                  │
-- >                  │ ByteStrings
-- >                  │
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│ Bearer.read() │░ Mux Bearer implementation (Socket, Pipes, etc.)
-- >         ░└───────────────┘░
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░         MuxSDUs
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│     demux     │░ For a given Mux Bearer there is a single demux
-- >         ░└───────┬───────┘░ thread reading from the underlying bearer.
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░
-- >        ░░░░░░░░░░▼░░░░░░░░░░
-- >        ░ ╭────┬────┬─────╮ ░ There is a limited queue (in bytes) for each mode
-- >        ░ │    │    │     │ ░ (responder/initiator) per miniprotocol. Overflowing
-- >        ░ ▼    ▼    ▼     ▼ ░ a queue is a protocol violation and a
-- >        ░│  │ │  │ │  │ │  │░ MuxIngressQueueOverRun exception is thrown
-- >        ░│ci│ │  │ │bi│ │br│░ and the bearer torn down.
-- >        ░│ci│ │cr│ │bi│ │br│░
-- >        ░└──┘ └──┘ └──┘ └──┘░ Every ingress queue has a dedicated thread which will
-- >        ░░│░░░░│░░░░│░░░░│░░░ read application encoded data from its queue.
-- >          │    │    │    │
-- >           application data
-- >          │    │    │    │
-- >          ▼    │    │    ▼
-- > ┌───────────┐ │    │  ┌───────────┐
-- > │ muxDuplex │ │    │  │ muxDuplex │
-- > │ Initiator │ │    │  │ Responder │
-- > │ ChainSync │ │    │  │ BlockFetch│
-- > └───────────┘ │    │  └───────────┘
-- >               ▼    ▼
-- >    ┌───────────┐  ┌───────────┐
-- >    │ muxDuplex │  │ muxDuplex │
-- >    │ Responder │  │ Initiator │
-- >    │ ChainSync │  │ BlockFetch│
-- >    └───────────┘  └───────────┘

-- | The demultiplexer's dispatch table. It maps the mini-protocol number and
-- direction of an incoming SDU to the 'MiniProtocolDispatchInfo' (ingress
-- queue and its byte limit) the SDU's payload is delivered to.
--
-- Built by 'setupDispatchTable' and queried with 'lookupMiniProtocol'.
--
-- NOTE(review): an earlier version of this comment said the table is shared
-- between the muxIngress and bearerIngress processes. Within this module only
-- 'demuxer' uses it; confirm whether that remark is still accurate.
--
data MiniProtocolDispatch m =
     MiniProtocolDispatch
       !(Array MiniProtocolNum (Maybe MiniProtocolIx))
       -- ^ Sparse protocol number to dense intermediate index; 'Nothing'
       -- for numbers inside the bounds that no mini-protocol uses.
       !(Array (MiniProtocolIx, MiniProtocolDir)
               (MiniProtocolDispatchInfo m))
       -- ^ Dense index plus direction to the dispatch information.

-- | Where 'demuxer' delivers the payload of an SDU for one
-- (mini-protocol number, direction) pair.
data MiniProtocolDispatchInfo m =
     MiniProtocolDispatchInfo
       !(IngressQueue m)
       -- ^ Queue that SDU payloads are appended to.
       !Int
       -- ^ Maximum number of bytes the queue may hold. An SDU that would
       -- push the queue past this makes 'demuxer' throw
       -- 'MuxIngressQueueOverRun'.
   | MiniProtocolDirUnused
       -- ^ No mini-protocol is registered for this direction. An SDU
       -- addressed to it makes 'demuxer' throw 'MuxInitiatorOnly'.


-- | Demultiplex a bearer.
--
-- This is a single loop, intended to run in its own thread. It reads
-- complete 'MuxSDU's from the underlying 'MuxBearer' and appends each SDU's
-- payload to the ingress queue of the matching mini-protocol.
--
-- The lookup uses the SDU's protocol number and its direction /flipped/
-- (see 'flipMiniProtocolDir'). The loop never returns normally; it ends
-- only by an exception, including:
--
-- * 'MuxError' 'MuxUnknownMiniProtocol': the protocol number is not in the
--   dispatch table.
-- * 'MuxError' 'MuxInitiatorOnly': the number is known but the target
--   direction is 'MiniProtocolDirUnused'.
-- * 'MuxError' 'MuxIngressQueueOverRun': appending the payload would make
--   the queue exceed its byte limit. This is thrown inside the STM
--   transaction, so the queue is left unchanged.
demuxer :: (MonadAsync m, MonadFork m, MonadMask m, MonadThrow (STM m),
            MonadTimer m, MonadTime m)
      => [MiniProtocolState mode m]
      -> MuxBearer m
      -> m void
demuxer ptcls bearer =
    -- Build and force the dispatch table before entering the loop.
    dispatchTable `seq` withTimeoutSerial (forever . demuxOne)
  where
    dispatchTable = setupDispatchTable ptcls

    -- Read one SDU and hand its payload to the right ingress queue.
    demuxOne timeoutFn = do
      (sdu, _) <- Network.Mux.Types.read bearer timeoutFn
      let pnum    = msNum sdu
          payload = msBlob sdu
          idMsg   = "id = " ++ show pnum
          -- Mode reversal: an SDU sent in ResponderDir is delivered to our
          -- InitiatorDir and vice versa.
          target  = flipMiniProtocolDir (msDir sdu)
      case lookupMiniProtocol dispatchTable pnum target of
        Nothing ->
          throwIO (MuxError MuxUnknownMiniProtocol idMsg)
        Just MiniProtocolDirUnused ->
          throwIO (MuxError MuxInitiatorOnly idMsg)
        Just (MiniProtocolDispatchInfo queue limit) -> atomically $ do
          pending <- readTVar queue
          when (BL.length pending + BL.length payload > fromIntegral limit) $
            throwSTM $ MuxError MuxIngressQueueOverRun $
              printf "Ingress Queue overrun on %s %s"
                     (show pnum) (show (msDir sdu))
          writeTVar queue (BL.append pending payload)

-- | Find the dispatch information for a mini-protocol number and direction.
--
-- Returns 'Nothing' in two cases:
--
-- * the number lies outside the table's bounds;
-- * the number lies inside the bounds but no mini-protocol uses it.
--
-- A used number paired with a direction that no mini-protocol runs yields
-- @'Just' 'MiniProtocolDirUnused'@.
lookupMiniProtocol :: MiniProtocolDispatch m
                   -> MiniProtocolNum
                   -> MiniProtocolDir
                   -> Maybe (MiniProtocolDispatchInfo m)
lookupMiniProtocol (MiniProtocolDispatch numToIx ixDirToInfo) num dir = do
  -- Check the bounds first: '!' throws on an out-of-range index.
  guard (inRange (bounds numToIx) num)
  ix <- numToIx ! num
  pure (ixDirToInfo ! (ix, dir))

-- | Construct the table that maps 'MiniProtocolNum' and 'MiniProtocolDir' to
-- 'MiniProtocolDispatchInfo'. Use 'lookupMiniProtocol' to index it.
--
-- For every used protocol number, any direction with no entry in @ptcls@
-- maps to 'MiniProtocolDirUnused'. An empty @ptcls@ yields an empty table,
-- on which every 'lookupMiniProtocol' returns 'Nothing'.
--
setupDispatchTable :: forall mode m.
                      [MiniProtocolState mode m] -> MiniProtocolDispatch m
setupDispatchTable ptcls =
    MiniProtocolDispatch pnumArray ptclArray
  where
    -- The 'MiniProtocolNum' space is sparse, but we don't want a huge single
    -- table if we use large protocol numbers. So we use a two-level mapping.
    --
    -- The first array maps 'MiniProtocolNum' to a dense space of intermediate
    -- integer indexes. These indexes are meaningless outside of the context of
    -- this table. Then we use the index and the 'MiniProtocolDir' for the
    -- second table.
    --
    -- Both arrays are built by listing a default for every index, followed by
    -- the real entries. GHC's 'array' keeps the last association given for a
    -- repeated index, so the real entries override the defaults.
    --
    pnumArray :: Array MiniProtocolNum (Maybe MiniProtocolIx)
    pnumArray =
      array (minpnum, maxpnum) $
            -- Fill in Nothing first to cover any unused ones.
            [ (pnum, Nothing)    | pnum <- [minpnum..maxpnum] ]

            -- And override with the ones actually used.
         ++ [ (pnum, Just pix)   | (pix, pnum) <- zip [0..] pnums ]

    ptclArray :: Array (MiniProtocolIx, MiniProtocolDir)
                       (MiniProtocolDispatchInfo m)
    ptclArray =
      array ((minpix, InitiatorDir), (maxpix, ResponderDir)) $
            -- Fill in MiniProtocolDirUnused first to cover any unused ones.
            [ ((pix, dir), MiniProtocolDirUnused)
            | (pix, dir) <- range ((minpix, InitiatorDir),
                                   (maxpix, ResponderDir)) ]

             -- And override with the ones actually used.
         ++ [ ((pix, dir), MiniProtocolDispatchInfo q qMax)
            | MiniProtocolState {
                miniProtocolInfo =
                  MiniProtocolInfo {
                    miniProtocolNum,
                    miniProtocolDir,
                    miniProtocolLimits
                  },
                miniProtocolIngressQueue = q
              } <- ptcls
            , let pix  =
                   case pnumArray ! miniProtocolNum of
                     Just a  -> a
                     Nothing -> error ("setupDispatchTable: missing "
                                       ++ show miniProtocolNum)
                  dir      = protocolDirEnum miniProtocolDir
                  qMax     = maximumIngressQueue miniProtocolLimits
            ]

    -- The protocol numbers actually used, in the order of their first use
    -- within the 'ptcls' list. The order does not matter, provided we use it
    -- consistently between the two arrays.
    pnums :: [MiniProtocolNum]
    pnums   = nub $ map (miniProtocolNum . miniProtocolInfo) ptcls

    -- The dense range of indexes of used protocol numbers. For an empty
    -- 'pnums' this is the empty range (0, -1), which assumes 'MiniProtocolIx'
    -- is a signed type such as 'Int'. NOTE(review): confirm.
    minpix, maxpix :: MiniProtocolIx
    minpix  = 0
    maxpix  = fromIntegral (length pnums - 1)

    -- The sparse range of protocol numbers. 'minimum' and 'maximum' are
    -- partial, so an empty 'pnums' gets an empty range (lower bound above the
    -- upper bound) instead. Every lookup then falls outside the bounds.
    -- This assumes the Enum and Ord instances of 'MiniProtocolNum' follow its
    -- underlying integer.
    minpnum, maxpnum :: MiniProtocolNum
    (minpnum, maxpnum) =
      case pnums of
        [] -> (toEnum 1, toEnum 0)
        _  -> (minimum pnums, maximum pnums)