{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RecursiveDo         #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

-- `accept` is shadowed, but so what?
{-# OPTIONS_GHC "-fno-warn-name-shadowing" #-}

module Ouroboros.Network.Server.Socket
  ( AcceptedConnectionsLimit (..)
  , AcceptConnectionsPolicyTrace (..)
  , BeginConnection
  , HandleConnection (..)
  , ApplicationStart
  , CompleteConnection
  , CompleteApplicationResult (..)
  , Result (..)
  , Main
  , run
  , Socket (..)
  , ioSocket
  ) where

import           Control.Concurrent.Async (Async)
import qualified Control.Concurrent.Async as Async
import           Control.Concurrent.STM (STM)
import qualified Control.Concurrent.STM as STM
import           Control.Exception (IOException, SomeException (..), finally,
                     mask, mask_, onException, try)
import           Control.Monad (forM_, join)
import           Control.Monad.Class.MonadTime (Time, getMonotonicTime)
import           Control.Monad.Class.MonadTimer (threadDelay)
import           Control.Tracer (Tracer, traceWith)
import           Data.Foldable (traverse_)
import           Data.Set (Set)
import qualified Data.Set as Set

import           Ouroboros.Network.ErrorPolicy (CompleteApplicationResult (..),
                     ErrorPolicyTrace, WithAddr)
import           Ouroboros.Network.Server.RateLimiting

-- | Abstraction of something that can provide connections.
-- A `Network.Socket` can be used to get a
-- `Socket SockAddr (Channel IO Lazy.ByteString)`
-- It's not defined in here, though, because we don't want the dependency
-- on typed-protocols or even on network.
data Socket addr channel = Socket
  { acceptConnection :: IO (addr, channel, IO (), Socket addr channel)
    -- ^ Accept one connection.  Yields the address, a channel, an 'IO'
    -- action which closes the channel, and the 'Socket' to use for the
    -- next accept.
  }

-- | Build a 'Socket' out of a plain 'IO' action.  Expected to be useful for
-- testing.
--
-- Every accept runs the given action once; the close action of the resulting
-- connection does nothing and the follow-up socket is built from the very
-- same action.
ioSocket :: IO (addr, channel) -> Socket addr channel
ioSocket io = Socket
  { acceptConnection = do
      (addr, channel) <- io
      pure (addr, channel, pure (), ioSocket io)
  }

-- | Mutable cell holding the server state @st@.  Within this module it is
-- read and written by 'spawnOne', 'acceptOne' and 'mainLoop'.
type StatusVar st = STM.TVar st


-- | What to do with a new connection: reject it and give a new state, or
-- accept it and give a new state with a continuation to run against the
-- resulting channel.
-- See also `CompleteConnection`, which is run for every connection when it
-- finishes, and can also update the state.
data HandleConnection channel st r where
  -- Refuse the connection; the carried state replaces the current one.
  Reject :: !st -> HandleConnection channel st r
  -- Take the connection: the carried state replaces the current one and the
  -- continuation is run against the accepted channel in a new thread.
  Accept :: !st -> !(channel -> IO r) -> HandleConnection channel st r

-- | What to do on a new connection: accept and run this `IO`, or reject.
--
-- 'acceptOne' calls it with the monotonic time sampled right after the accept
-- returned, the address produced by the accept and the current state, inside
-- the same STM transaction that writes the new state back.
type BeginConnection addr channel st r = Time -> addr -> st -> STM (HandleConnection channel st r)

-- | A callback which runs when the application starts.
--
-- It is needed only because 'BeginConnection' does not have access to the
-- thread which runs the application.
--
-- 'spawnOne' runs it in the freshly spawned thread, before the application
-- itself, passing the connection address, the thread's 'Async' and the
-- current state; the state it returns is written back.
--
type ApplicationStart addr st = addr -> Async () -> st -> STM st

-- | How to update state when a connection finishes. Can use `throwSTM` to
-- terminate the server.
--
-- The @tr@ parameter does not appear on the right-hand side of this synonym.
--
-- TODO: remove 'async', use `Async m ()` from 'MonadAsync'.
type CompleteConnection addr st tr r =
    Result addr r -> st -> STM (CompleteApplicationResult IO addr st)

-- | Given a current state, `retry` unless you want to stop the server.
-- When this transaction returns, any running threads spawned by the server
-- will be killed.
--
-- It's possible that a connection is accepted after the main thread
-- returns, but before the server stops. In that case, it will be killed, and
-- the `CompleteConnection` will not run against it.
--
-- NOTE(review): 'run' combines the accept loop and the main loop with
-- `Async.concurrently`; confirm that a normal return of this transaction
-- really stops the accept loop.
type Main st t = st -> STM t

-- | To avoid repeatedly blocking on the set of all running threads (a
-- potentially very large STM transaction) the results come in by way of a
-- `TQueue`. Using a queue rather than, say, a `TMVar`, also finesses a
-- potential deadlock when shutting down the server and killing spawned threads:
-- the server can stop pulling from the queue, without causing the child
-- threads to hang attempting to write to it.
--
-- It is written by the threads created in 'spawnOne' and read by 'mainLoop'.
type ResultQ addr r = STM.TQueue (Result addr r)

-- | The product of a spawned thread. We catch all (even async) exceptions.
data Result addr r = Result
  { resultThread :: !(Async ())
    -- ^ Handle of the thread which produced this result.
  , resultAddr   :: !addr
    -- ^ Address of the connection the thread was serving.
  , resultTime   :: !Time
    -- ^ Monotonic time sampled once the application had finished.
  , resultValue  :: !(Either SomeException r)
    -- ^ Either the exception which ended the application or its result.
  }

-- | The set of all spawned threads. Used for waiting or cancelling them when
-- the server shuts down.
--
-- Entries are inserted by 'spawnOne' and deleted by 'mainLoop'; the size of
-- the set also feeds the connection rate limiting in 'acceptOne'.
type ThreadsVar = STM.TVar (Set (Async ()))


-- | The action runs inside `try`, and when it finishes, puts its result
-- into the `ResultQ`. Takes care of inserting into the `ThreadsVar`; the
-- entry is deleted again by the main loop once it has seen the result.
--
-- Async exceptions are masked to ensure that if the thread is spawned, it
-- always gets into the `ThreadsVar`. Exceptions are unmasked in the
-- spawned thread, but only around the application action itself.
--
-- Before the application runs, the spawned thread applies the
-- 'ApplicationStart' callback to the current state and writes the (forced)
-- result back to the 'StatusVar'.
spawnOne
  :: addr
  -> StatusVar st
  -> ResultQ addr r
  -> ThreadsVar
  -> ApplicationStart addr st
  -> IO r
  -> IO ()
spawnOne remoteAddr statusVar resQ threadsVar applicationStart io = mask_ $ do
  -- `rec` ties the knot: the thread body needs its own 'Async' handle, both
  -- for the 'ApplicationStart' callback and for the 'Result' it reports.
  rec let threadAction = \unmask -> do
            STM.atomically $
                  STM.readTVar statusVar
              >>= applicationStart remoteAddr thread
              >>= (STM.writeTVar statusVar $!)
            val <- try (unmask io)
            t <- getMonotonicTime
            -- No matter what the exception, async or sync, this will not
            -- deadlock, since we use a `TQueue`. If the server kills its
            -- children, and stops clearing the queue, it will be collected
            -- shortly thereafter, so no problem.
            STM.atomically $ STM.writeTQueue resQ (Result thread remoteAddr t val)
      thread <- Async.asyncWithUnmask $ \unmask ->
          threadAction unmask
  -- The main loop `connectionTx` will remove this entry from the set, once
  -- it receives the result.
  STM.atomically $ STM.modifyTVar' threadsVar (Set.insert thread)


-- | The accept thread is controlled entirely by the `accept` call. To
-- stop it, whether normally or exceptionally, it must be killed by an async
-- exception, or the exception callback here must re-throw.
--
-- The loop never returns on its own: after every 'acceptOne' it carries on
-- with the socket handed back by the accept, or, when the accept failed,
-- retries the same socket after a short pause.
acceptLoop
  :: Tracer IO AcceptConnectionsPolicyTrace
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> AcceptedConnectionsLimit
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> (IOException -> IO ()) -- ^ Exception on `Socket.accept`.
  -> Socket addr channel
  -> IO ()
acceptLoop acceptPolicyTrace resQ threadsVar statusVar acceptedConnectionLimit
           beginConnection applicationStart acceptException socket = go socket
  where
    -- A single accept with everything but the socket already supplied.
    acceptNext =
      acceptOne acceptPolicyTrace resQ threadsVar statusVar
                acceptedConnectionLimit beginConnection applicationStart
                acceptException

    go current = do
      mNextSocket <- acceptNext current
      case mNextSocket of
        Nothing -> do
          -- Thread delay (half a second) to mitigate potential livelock.
          threadDelay 0.5
          go current
        Just nextSocket ->
          go nextSocket

-- | Accept once from the socket, use the 'BeginConnection' callback to make
-- a decision (accept or reject), and spawn the thread if accepted.
--
-- Returns 'Nothing' when the accept itself failed with an 'IOException' that
-- the exception callback did not re-throw, and @'Just' nextSocket@ otherwise,
-- whether the connection was accepted or rejected.
--
-- The whole action runs under 'mask'; only the accept call and the exception
-- callback are run with the caller's masking state restored.
acceptOne
  :: forall addr channel st r.
     Tracer IO AcceptConnectionsPolicyTrace
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> AcceptedConnectionsLimit
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> (IOException -> IO ()) -- ^ Exception on `Socket.accept`.
  -> Socket addr channel
  -> IO (Maybe (Socket addr channel))
acceptOne acceptPolicyTrace resQ threadsVar statusVar acceptedConnectionsLimit
          beginConnection applicationStart acceptException socket =
    mask $ \restore -> do

  -- Rate limiting of accepted connections; this might block.  The number of
  -- live connections is taken to be the size of the thread set.
  runConnectionRateLimits
    acceptPolicyTrace
    (Set.size <$> STM.readTVar threadsVar)
    acceptedConnectionsLimit

  -- mask is to assure that every socket is closed.
  outcome <- try (restore (acceptConnection socket))
  case outcome :: Either IOException (addr, channel, IO (), Socket addr channel) of
    Left ex -> do
      -- Classify the exception, if it is fatal to the node or not.
      -- If it is fatal to the node the exception will propagate.
      restore (acceptException ex)
      pure Nothing
    Right (addr, channel, close, nextSocket) -> do
      -- Decide whether to accept or reject, using the current state, and
      -- update it according to the decision.
      t <- getMonotonicTime
      let decision :: IO (Maybe (channel -> IO r))
          decision = STM.atomically $ do
            st <- STM.readTVar statusVar
            !handleConn <- beginConnection t addr st
            case handleConn of
              Reject st' -> do
                STM.writeTVar statusVar st'
                pure Nothing
              Accept st' io -> do
                STM.writeTVar statusVar st'
                pure $ Just io
      -- this could be interrupted, so we use `onException` to close the
      -- socket.
      choice <- decision `onException` close
      case choice of
        -- Rejected: nothing will ever use the channel, close it right away.
        Nothing -> close
        -- Accepted: the spawned thread owns the channel and closes it when
        -- the application finishes, however it finishes.
        Just io -> spawnOne addr statusVar resQ threadsVar applicationStart
                     (io channel `finally` close)
      pure (Just nextSocket)

-- | Main server loop, which runs alongside the `acceptLoop`. It waits for
-- the results of connection threads, as well as the `Main` action, which
-- determines when/if the server should stop.
--
-- Each iteration runs a single STM transaction which first tries the `Main`
-- action and, if that retries, handles one finished connection.  The
-- transaction yields the 'IO' action to run next: either return the result
-- of `Main`, or do the post-connection work and loop.
mainLoop
  :: forall addr st tr r t .
     Tracer IO (WithAddr addr ErrorPolicyTrace)
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> CompleteConnection addr st tr r
  -> Main st t
  -> IO t
mainLoop errorPolicyTrace resQ threadsVar statusVar complete main =
  join (STM.atomically $ mainTx `STM.orElse` connectionTx)

  where

  -- Sample the status, and run the main action. If it does not retry, then
  -- the `mainLoop` finishes with `pure t` where `t` is the main action result.
  mainTx :: STM (IO t)
  mainTx = do
    st <- STM.readTVar statusVar
    t <- main st
    pure $ pure t

  -- Wait for some connection to finish, update the state with its result,
  -- then recurse onto `mainLoop`.
  connectionTx :: STM (IO t)
  connectionTx = do
    result <- STM.readTQueue resQ
    -- Make sure we don't cleanup before spawnOne has inserted the thread
    isMember <- Set.member (resultThread result) <$> STM.readTVar threadsVar
    STM.check isMember

    st <- STM.readTVar statusVar
    CompleteApplicationResult
      { carState
      , carThreads
      , carTrace
      } <- complete result st
    -- The original note here states that the result record is strict in its
    -- state field, so an evaluated state is written to 'statusVar'.
    -- NOTE(review): 'CompleteApplicationResult' is declared in
    -- "Ouroboros.Network.ErrorPolicy"; confirm 'carState' is a strict field.
    STM.writeTVar statusVar carState
    -- It was inserted by `spawnOne`.
    STM.modifyTVar' threadsVar (Set.delete (resultThread result))
    pure $ do
      -- Outside of the transaction: cancel the threads the completion
      -- callback asked for, emit its trace (if any), then loop.
      traverse_ Async.cancel carThreads
      traverse_ (traceWith errorPolicyTrace) carTrace
      mainLoop errorPolicyTrace resQ threadsVar statusVar complete main


-- | Run a server.
--
-- Creates a fresh result queue and thread set, then runs the accept loop and
-- the main loop concurrently.  The result is that of the main loop.  However
-- the two loops end, every thread still registered in the thread set is
-- cancelled afterwards.
run
  :: Tracer IO (WithAddr addr ErrorPolicyTrace)
  -> Tracer IO AcceptConnectionsPolicyTrace
  -- TODO: extend this trace to trace server action (this might be useful for
  -- debugging)
  -> Socket addr channel
  -> AcceptedConnectionsLimit
  -> (IOException -> IO ())
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> CompleteConnection addr st tr r
  -> Main st t
  -> STM.TVar st
  -> IO t
run errorPolicyTrace acceptPolicyTrace socket acceptedConnectionLimit
    acceptException beginConnection applicationStart complete main
    statusVar = do
  resQ <- STM.newTQueueIO
  threadsVar <- STM.newTVarIO Set.empty
  let acceptLoopDo =
        acceptLoop acceptPolicyTrace resQ threadsVar statusVar
                   acceptedConnectionLimit beginConnection applicationStart
                   acceptException socket
      -- Both loops are run with `Async.concurrently`: if either of them
      -- throws, the other one is cancelled and the exception is re-thrown.
      -- NOTE(review): `concurrently` waits for both actions, so a normal
      -- return of the main loop does not by itself cancel the accept loop;
      -- confirm this is the intended shutdown behaviour.
      mainDo = mainLoop errorPolicyTrace resQ threadsVar statusVar complete main
      killChildren = do
        children <- STM.atomically $ STM.readTVar threadsVar
        forM_ (Set.toList children) Async.cancel
  -- After both the main and accept loop have finished or been killed, any
  -- remaining spawned threads are cancelled.
  (snd <$> Async.concurrently acceptLoopDo mainDo) `finally` killChildren