{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE GADTSyntax                #-}
{-# LANGUAGE KindSignatures            #-}
{-# LANGUAGE NamedFieldPuns            #-}
{-# LANGUAGE RankNTypes                #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeFamilies              #-}

module Network.Mux.Compat
  ( muxStart
    -- * Mux bearers
  , MuxBearer
    -- * Defining 'MuxApplication's
  , MuxMode (..)
  , HasInitiator
  , HasResponder
  , MuxApplication (..)
  , MuxMiniProtocol (..)
  , RunMiniProtocol (..)
  , MiniProtocolNum (..)
  , MiniProtocolLimits (..)
  , MiniProtocolDir (..)
    -- * Errors
  , MuxError (..)
  , MuxErrorType (..)
    -- * Tracing
  , traceMuxBearerState
  , MuxBearerState (..)
  , MuxTrace (..)
  , WithMuxBearer (..)
  ) where

import qualified Data.ByteString.Lazy as BL
import           Data.Void (Void)

import           Control.Applicative ((<|>))
import           Control.Monad
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM.Strict
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer
import           Control.Tracer

import           Network.Mux (StartOnDemandOrEagerly (..), newMux,
                     runMiniProtocol, runMux, stopMux, traceMuxBearerState)
import           Network.Mux.Channel
import           Network.Mux.Trace
import           Network.Mux.Types hiding (MiniProtocolInfo (..))
import qualified Network.Mux.Types as Types


-- | An application that 'muxStart' runs over a bearer: simply the list of
-- its mini-protocols.
newtype MuxApplication (mode :: MuxMode) m a b
  = MuxApplication [MuxMiniProtocol mode m a b]

-- | A single mini-protocol of a 'MuxApplication': the number it is run on,
-- the limits registered for it with the mux, and the action(s) that
-- implement it.
data MuxMiniProtocol (mode :: MuxMode) m a b =
     MuxMiniProtocol {
       -- | Mini-protocol number on which the protocol is run.
       miniProtocolNum    :: !MiniProtocolNum,
       -- | Limits passed to the mux via the 'MiniProtocolBundle'.
       miniProtocolLimits :: !MiniProtocolLimits,
       -- | Initiator and/or responder side(s) of the protocol.
       miniProtocolRun    :: !(RunMiniProtocol mode m a b)
     }

-- | The action(s) implementing one mini-protocol.  The constructor determines
-- which side(s) of the protocol are run, and fixes the 'MuxMode' accordingly.
--
-- Each action receives the 'Channel' of the mini-protocol and returns a
-- result paired with a @'Maybe' 'BL.ByteString'@ (presumably unconsumed
-- input; confirm against the mux implementation).  'muxStart' discards both
-- components.
data RunMiniProtocol (mode :: MuxMode) m a b where
  -- | Initiator application.  The simplest application will be
  -- @'runPeer'@ or @'runPipelinedPeer'@ supplied with a codec and a
  -- @'Peer'@ for each @ptcl@, but the function form also allows handling
  -- resources when just applying @'runPeer'@ is not enough.  It is run in
  -- the 'InitiatorDirectionOnly' direction.
  InitiatorProtocolOnly
    :: (Channel m -> m (a, Maybe BL.ByteString))
    -> RunMiniProtocol InitiatorMode m a Void

  -- | Responder application; similar to 'InitiatorProtocolOnly', but it is
  -- run in the 'ResponderDirectionOnly' direction.
  ResponderProtocolOnly
    :: (Channel m -> m (b, Maybe BL.ByteString))
    -> RunMiniProtocol ResponderMode m Void b

  -- | Initiator and responder applications, run in the 'InitiatorDirection'
  -- and 'ResponderDirection' directions respectively.
  InitiatorAndResponderProtocol
    :: (Channel m -> m (a, Maybe BL.ByteString))
    -> (Channel m -> m (b, Maybe BL.ByteString))
    -> RunMiniProtocol InitiatorResponderMode m a b


-- | Run a 'MuxApplication' over a 'MuxBearer'.
--
-- Every mini-protocol of the application is started eagerly, once for each
-- direction selected by its 'RunMiniProtocol' constructor, and the mux itself
-- runs in a separate thread.  As soon as the first mini-protocol completion
-- action fires (with a result or an exception), or the mux thread itself
-- terminates, the mux is stopped.  'muxStart' then waits for the mux thread,
-- re-throwing any exception that thread terminated with.
--
-- Mini-protocol results, including exceptions reported through their
-- completion actions, are discarded.  This compatibility interface never
-- restarts a mini-protocol.
muxStart
    :: forall m mode a b.
       ( MonadAsync m
       , MonadFork m
       , MonadLabelledSTM m
       , MonadThrow (STM m)
       , MonadTime  m
       , MonadTimer m
       , MonadMask m
       )
    => Tracer m MuxTrace
    -> MuxApplication mode m a b
    -> MuxBearer m
    -> m ()
muxStart tracer muxapp bearer = do
    mux <- newMux (toMiniProtocolBundle muxapp)

    resOps <- sequence
      [ runMiniProtocol
          mux
          miniProtocolNum
          ptclDir
          StartEagerly
          -- 'Nothing': the compat interface doesn't do restarts.
          (\chan -> (\r -> (r, Nothing)) <$> action chan)
      | let MuxApplication ptcls = muxapp
      , MuxMiniProtocol{miniProtocolNum, miniProtocolRun} <- ptcls
      , (ptclDir, action) <- selectRunner miniProtocolRun
      ]

    -- Wait for the first mini-protocol to finish (or for the mux thread to
    -- exit), then stop the mux.
    withAsync (runMux tracer mux bearer) $ \aid -> do
      waitOnAny resOps aid
      stopMux mux
      -- Re-throws the mux thread's exception, if it failed.
      wait aid

  where
    -- Block until any completion action fires, or the mux thread terminates.
    -- Without the second alternative, an empty list reduces to a bare 'retry'
    -- and blocks forever.  A mux failure not reported through any completion
    -- action would also go unnoticed: 'withAsync' does not propagate the
    -- child's exception to a parent blocked in STM.
    waitOnAny :: [STM m (Either SomeException ())] -> Async m () -> m ()
    waitOnAny resOps muxThread =
      atomically $
            void (foldr (<|>) retry resOps)
        <|> void (waitCatchSTM muxThread)

    -- Register one 'Types.MiniProtocolInfo' per (protocol, direction) pair.
    -- The directions are taken from 'selectRunner', so the registered
    -- directions always match the ones actually started in 'muxStart'.
    toMiniProtocolBundle :: MuxApplication mode m a b -> MiniProtocolBundle mode
    toMiniProtocolBundle (MuxApplication ptcls) =
      MiniProtocolBundle
        [ Types.MiniProtocolInfo {
            Types.miniProtocolNum,
            Types.miniProtocolDir,
            Types.miniProtocolLimits
          }
        | MuxMiniProtocol {
            miniProtocolNum,
            miniProtocolLimits,
            miniProtocolRun
          } <- ptcls
        , (miniProtocolDir, _) <- selectRunner miniProtocolRun
        ]

    -- Map a 'RunMiniProtocol' to its direction(s) and the corresponding
    -- channel action(s).  Each action's result is discarded with 'void'.
    selectRunner :: RunMiniProtocol mode m a b
                 -> [(MiniProtocolDirection mode, Channel m -> m ())]
    selectRunner (InitiatorProtocolOnly initiator) =
      [(InitiatorDirectionOnly, void . initiator)]
    selectRunner (ResponderProtocolOnly responder) =
      [(ResponderDirectionOnly, void . responder)]
    selectRunner (InitiatorAndResponderProtocol initiator responder) =
      [ (InitiatorDirection, void . initiator)
      , (ResponderDirection, void . responder)
      ]