{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE DerivingVia         #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- it is useful to have 'HasInitiator' constraint on 'connectToNode' & friends.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- For Hashable SockAddr
{-# OPTIONS_GHC -Wno-orphans #-}


-- |
-- Module exports interface for running a node over a socket over TCP \/ IP.
--
module Ouroboros.Network.Socket
  ( -- * High level socket interface
    ConnectionTable
  , ConnectionTableRef (..)
  , ValencyCounter
  , NetworkMutableState (..)
  , SomeResponderApplication (..)
  , newNetworkMutableState
  , newNetworkMutableStateSTM
  , cleanNetworkMutableState
  , AcceptedConnectionsLimit (..)
  , ConnectionId (..)
  , withServerNode
  , withServerNode'
  , connectToNode
  , connectToNodeSocket
  , connectToNode'
    -- * Traces
  , NetworkConnectTracers (..)
  , nullNetworkConnectTracers
  , debuggingNetworkConnectTracers
  , NetworkServerTracers (..)
  , nullNetworkServerTracers
  , debuggingNetworkServerTracers
  , AcceptConnectionsPolicyTrace (..)
    -- * Helper function for creating servers
  , fromSnocket
  , beginConnection
    -- * Re-export of PeerStates
  , PeerStates
    -- * Re-export connection table functions
  , newConnectionTable
  , refConnection
  , addConnection
  , removeConnection
  , newValencyCounter
  , addValencyCounter
  , remValencyCounter
  , waitValencyCounter
  , readValencyCounter
    -- * Auxiliary functions
  , sockAddrFamily
  ) where

import           Control.Concurrent.Async
import           Control.Exception (SomeException (..))
-- TODO: remove this, it will not be needed when `orElse` PR will be merged.
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Term as CBOR
import           Control.Monad.Class.MonadSTM.Strict
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import qualified Control.Monad.STM as STM
import qualified Data.ByteString.Lazy as BL
import           Data.Hashable
import           Data.Proxy (Proxy (..))
import           Data.Typeable (Typeable)
import           Data.Void
import           Data.Word (Word16)
import           GHC.IO.Exception
#if !defined(mingw32_HOST_OS)
import           Foreign.C.Error
#endif

import qualified Network.Socket as Socket

import           Control.Tracer

import qualified Network.Mux.Compat as Mx
import           Network.Mux.DeltaQ.TraceTransformer
import           Network.TypedProtocol.Codec hiding (decode, encode)

import           Ouroboros.Network.ConnectionId
import           Ouroboros.Network.Driver.Limits
import           Ouroboros.Network.ErrorPolicy
import           Ouroboros.Network.IOManager (IOManager)
import           Ouroboros.Network.Mux
import           Ouroboros.Network.Protocol.Handshake
import           Ouroboros.Network.Protocol.Handshake.Codec
import           Ouroboros.Network.Protocol.Handshake.Type
import           Ouroboros.Network.Server.ConnectionTable
import           Ouroboros.Network.Server.Socket
                     (AcceptConnectionsPolicyTrace (..),
                     AcceptedConnectionsLimit (..))
import qualified Ouroboros.Network.Server.Socket as Server
import           Ouroboros.Network.Snocket (Snocket)
import qualified Ouroboros.Network.Snocket as Snocket
import           Ouroboros.Network.Subscription.PeerState


-- | Tracers used by 'connectToNode' and its derivatives, such as
-- 'Ouroboros.Network.NodeToNode.connectTo' and
-- 'Ouroboros.Network.NodeToClient.connectTo'.
--
data NetworkConnectTracers addr vNumber = NetworkConnectTracers
    { nctMuxTracer
        :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace)
      -- ^ Low level mux-network tracer; it logs mux SDUs (sent and received)
      -- and other low level multiplexing events.
    , nctHandshakeTracer
        :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr)
                                       (TraceSendRecv (Handshake vNumber CBOR.Term)))
      -- ^ Handshake protocol tracer; it is important for analysing version
      -- negotiation mismatches.
    }

-- | Connect tracers which discard every event.
--
nullNetworkConnectTracers :: NetworkConnectTracers addr vNumber
nullNetworkConnectTracers =
    NetworkConnectTracers
      { nctMuxTracer       = nullTracer
      , nctHandshakeTracer = nullTracer
      }


-- | Connect tracers which 'show' every event and print it to @stdout@;
-- meant for debugging.
--
debuggingNetworkConnectTracers :: (Show addr, Show vNumber)
                               => NetworkConnectTracers addr vNumber
debuggingNetworkConnectTracers =
    NetworkConnectTracers
      { nctMuxTracer       = printShown
      , nctHandshakeTracer = printShown
      }
  where
    -- Needs an explicit polymorphic signature: it is used at two different
    -- event types.
    printShown :: Show x => Tracer IO x
    printShown = showTracing stdoutTracer

-- | The socket 'Socket.Family' (IPv4, IPv6 or Unix domain) of an address.
sockAddrFamily
    :: Socket.SockAddr
    -> Socket.Family
sockAddrFamily addr =
    case addr of
      Socket.SockAddrInet  {} -> Socket.AF_INET
      Socket.SockAddrInet6 {} -> Socket.AF_INET6
      Socket.SockAddrUnix  {} -> Socket.AF_UNIX

-- Orphan instance (see the -Wno-orphans pragma at the top of the module).
-- IPv6 flow info and scope id are left out of the hash; addresses that are
-- equal still hash equally, which is all 'Hashable' requires.
instance Hashable Socket.SockAddr where
  hashWithSalt salt sockAddr =
      case sockAddr of
        Socket.SockAddrInet port host ->
          hashWithSalt salt (portToWord16 port, host)
        Socket.SockAddrInet6 port _flowInfo host6 _scopeId ->
          hashWithSalt salt (portToWord16 port, host6)
        Socket.SockAddrUnix path ->
          hashWithSalt salt path
    where
      portToWord16 :: Socket.PortNumber -> Word16
      portToWord16 = fromIntegral

-- | Upper limit of @30s@ on the time we wait while receiving a single SDU.
--
-- There is no upper bound on the time we wait for a new SDU to arrive, so
-- mini-protocols can still use timeouts larger than 30s, or wait forever.
-- Allowing @30s@ to receive one SDU corresponds to a minimum speed limit of
-- about 17kbps:
--
-- @
-- ( 8      -- mux header length
-- + 0xffff -- maximum SDU payload
-- ) * 8
-- = 524_344 -- maximum bits in an SDU
--
-- 524_344 \/ 30 \/ 1024 = 17kbps
-- @
--
sduTimeout :: DiffTime
sduTimeout = 30

-- | During the handshake, sending or receiving a single @MuxSDU@ is limited
-- to @10s@.
--
sduHandshakeTimeout :: DiffTime
sduHandshakeTimeout = 10


-- |
-- Connect to a remote node.  The socket is acquired with 'bracket', so it is
-- closed as soon as the continuation (the connection) exits, normally or via
-- an exception.
--
-- When a local address is given, the socket is bound to it before
-- connecting.  The connection then starts with the handshake protocol sending
-- @Versions@ to the remote peer.  It must fit into @'maxTransmissionUnit'@
-- (~5k bytes).
--
-- Exceptions thrown by @'MuxApplication'@ are rethrown by @'connectToNode'@.
connectToNode
  :: forall appType vNumber vData fd addr a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => Snocket IO fd addr
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers addr vNumber
  -> (vData -> vData -> Accept vData)
  -> Versions vNumber vData (OuroborosApplication appType addr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> Maybe addr
  -- ^ local address; the created socket will bind to it
  -> addr
  -- ^ remote address
  -> IO ()
connectToNode sn handshakeCodec handshakeTimeLimits versionDataCodec tracers
              acceptVersion versions localAddr remoteAddr =
    bracket (Snocket.openToConnect sn remoteAddr)
            (Snocket.close sn)
            runOnSocket
  where
    -- Optionally bind, connect, then hand the socket to 'connectToNode''.
    runOnSocket :: fd -> IO ()
    runOnSocket sock = do
      -- 'mapM_' over 'Maybe': binds only when a local address was supplied.
      mapM_ (Snocket.bind sn sock) localAddr
      Snocket.connect sn sock remoteAddr
      connectToNode' sn handshakeCodec handshakeTimeLimits versionDataCodec
                     tracers acceptVersion versions sock

-- |
-- Connect to a remote node using an existing socket. It is up to the caller to
-- ensure that the socket is closed in case of an exception.
--
-- The connection will start with handshake protocol sending @Versions@ to the
-- remote peer.  It must fit into @'maxTransmissionUnit'@ (~5k bytes).
--
-- The handshake runs over a bearer created with 'sduHandshakeTimeout'.  On
-- success, the negotiated application is started with 'Mx.muxStart' over a
-- new bearer created with 'sduTimeout'.  A failed handshake is traced as
-- 'Mx.MuxTraceHandshakeClientError' and its cause is rethrown.
--
-- Exceptions thrown by @'MuxApplication'@ are rethrown by @'connectToNode''@.
connectToNode'
  :: forall appType vNumber vData fd addr a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => Snocket IO fd addr
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers addr vNumber
  -> (vData -> vData -> Accept vData)
  -> Versions vNumber vData (OuroborosApplication appType addr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> fd
  -> IO ()
connectToNode' sn handshakeCodec handshakeTimeLimits versionDataCodec
               NetworkConnectTracers { nctMuxTracer, nctHandshakeTracer }
               acceptVersion versions sd = do
    connectionId <- ConnectionId <$> Snocket.getLocalAddr  sn sd
                                 <*> Snocket.getRemoteAddr sn sd
    muxTracer <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` nctMuxTracer

    -- Announce the start of the handshake, as the server side
    -- ('beginConnection') does, so that the client end/error events below
    -- have a matching start event.
    traceWith muxTracer Mx.MuxTraceHandshakeStart
    ts_start <- getMonotonicTime

    handshakeBearer <- Snocket.toBearer sn sduHandshakeTimeout muxTracer sd
    app_e <-
      runHandshakeClient
        handshakeBearer
        connectionId
        -- TODO: push 'HandshakeArguments' up the call stack.
        HandshakeArguments {
          haHandshakeTracer  = nctHandshakeTracer,
          haHandshakeCodec   = handshakeCodec,
          haVersionDataCodec = versionDataCodec,
          haAcceptVersion    = acceptVersion,
          haTimeLimits       = handshakeTimeLimits
        }
        versions
    ts_end <- getMonotonicTime

    let handshakeDuration :: DiffTime
        handshakeDuration = diffTime ts_end ts_start

        -- Trace a failed handshake together with its duration, then rethrow
        -- the failure.  The explicit signature is required: 'GADTs' implies
        -- 'MonoLocalBinds', and this helper is used at two exception types.
        handshakeFailed :: Exception e => e -> IO ()
        handshakeFailed err = do
          traceWith muxTracer $
            Mx.MuxTraceHandshakeClientError err handshakeDuration
          throwIO err

    case app_e of
      Left (HandshakeProtocolLimit err) -> handshakeFailed err
      Left (HandshakeProtocolError err) -> handshakeFailed err

      Right (app, _versionNumber, _agreedOptions) -> do
        traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd handshakeDuration
        bearer <- Snocket.toBearer sn sduTimeout muxTracer sd
        Mx.muxStart
          muxTracer
          (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app)
          bearer


-- | Wraps a 'Socket.Socket' in a 'Snocket' (via 'Snocket.socketSnocket') and
-- calls 'connectToNode''.  As with 'connectToNode'', the caller is
-- responsible for closing the socket.
connectToNodeSocket
  :: forall appType vNumber vData a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => IOManager
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers Socket.SockAddr vNumber
  -> (vData -> vData -> Accept vData)
  -> Versions vNumber vData (OuroborosApplication appType Socket.SockAddr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> Socket.Socket
  -> IO ()
connectToNodeSocket iocp = connectToNode' (Snocket.socketSnocket iocp)

-- |
-- Existential wrapper around an 'OuroborosApplication' which has a responder
-- side, i.e. a responder-only or an initiator-and-responder application.  The
-- application's mode (@appType@) and initiator result type (@a@) are hidden.
--
data SomeResponderApplication addr bytes m b where
     SomeResponderApplication
       :: forall appType addr bytes m a b.
          Mx.HasResponder appType ~ True
       => OuroborosApplication appType addr bytes m a b
       -> SomeResponderApplication addr bytes m b

-- |
-- Decision to accept or reject an incoming connection.  Both outcomes carry
-- the new state after the decision and the connection's 'ConnectionId'.
-- Accepting also requires the versions of a mux application which
-- necessarily has the server side, and optionally has the client side.
--
-- TODO:
-- If the other side does not allow us to run the client side on the incoming
-- connection, the whole connection will terminate.  We might want to be more
-- permissive in this scenario: leave the server thread running and let only
-- the client thread die.
data AcceptConnection st vNumber vData peerid m bytes where

    -- | Accept the connection, offering the given application versions.
    AcceptConnection
      :: forall st vNumber vData peerid bytes m b.
         !st
      -> !(ConnectionId peerid)
      -> Versions vNumber vData (SomeResponderApplication peerid bytes m b)
      -> AcceptConnection st vNumber vData peerid m bytes

    -- | Reject the connection.
    RejectConnection
      :: !st
      -> !(ConnectionId peerid)
      -> AcceptConnection st vNumber vData peerid m bytes


-- |
-- Accept or reject incoming connection based on the current state and address
-- of the incoming connection.
--
beginConnection
    :: forall vNumber vData addr st fd.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       )
    => Snocket IO fd addr
    -- ^ snocket used to turn the accepted socket into mux bearers
    -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace)
    -- ^ mux tracer; events are tagged with the connection's 'ConnectionId'
    -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) (TraceSendRecv (Handshake vNumber CBOR.Term)))
    -- ^ handshake protocol tracer
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -- ^ handshake codec
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -- ^ handshake protocol time limits
    -> VersionDataCodec CBOR.Term vNumber vData
    -- ^ codec for the version data exchanged during the handshake
    -> (vData -> vData -> Accept vData)
    -- ^ passed to the handshake server as 'haAcceptVersion'
    -> (Time -> addr -> st -> STM.STM (AcceptConnection st vNumber vData addr IO BL.ByteString))
    -- ^ either accept or reject a connection.
    -> Server.BeginConnection addr fd st ()
beginConnection sn muxTracer handshakeTracer handshakeCodec handshakeTimeLimits
                versionDataCodec acceptVersion fn t addr st = do
    accept <- fn t addr st
    case accept of
      AcceptConnection st' connectionId versions ->
        pure $ Server.Accept st' $ \sd -> do
          muxTracer' <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` muxTracer
          traceWith muxTracer' Mx.MuxTraceHandshakeStart

          -- The handshake runs over its own bearer, built with the handshake
          -- SDU timeout rather than the regular one.
          handshakeBearer <- Snocket.toBearer sn sduHandshakeTimeout muxTracer' sd
          app_e <-
            runHandshakeServer
              handshakeBearer
              connectionId
              HandshakeArguments {
                haHandshakeTracer  = handshakeTracer,
                haHandshakeCodec   = handshakeCodec,
                haVersionDataCodec = versionDataCodec,
                haAcceptVersion    = acceptVersion,
                haTimeLimits       = handshakeTimeLimits
              }
              versions

          case app_e of
            -- Handshake failures are traced and then rethrown, so the
            -- connection thread terminates with the handshake error.
            Left (HandshakeProtocolLimit err) -> do
              traceWith muxTracer' $ Mx.MuxTraceHandshakeServerError err
              throwIO err

            Left (HandshakeProtocolError err) -> do
              traceWith muxTracer' $ Mx.MuxTraceHandshakeServerError err
              throwIO err

            Right (SomeResponderApplication app, _versionNumber, _agreedOptions) -> do
              traceWith muxTracer' Mx.MuxTraceHandshakeServerEnd
              -- A fresh bearer for the negotiated application, using the
              -- regular SDU timeout.
              bearer <- Snocket.toBearer sn sduTimeout muxTracer' sd
              Mx.muxStart
                muxTracer'
                (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app)
                bearer

      RejectConnection st' _peerid ->
        pure $ Server.Reject st'

-- | Open a socket of the given address family.
--
-- When an address is supplied, the socket is also bound to it and put into
-- listening mode.  With 'Nothing' the socket is returned as opened: neither
-- bound nor listening.
--
-- NOTE(review): if 'Snocket.bind' or 'Snocket.listen' throws, the freshly
-- opened socket is not closed here.  Wrapping this call as the acquire action
-- of a @bracket@ does not release it either, because acquisition never
-- completed.  TODO: consider closing it on exception (e.g. @onException@).
mkListeningSocket
    :: Snocket IO fd addr
    -> Maybe addr
    -> Snocket.AddressFamily addr
    -> IO fd
mkListeningSocket sn addr family_ = do
    sd <- Snocket.open sn family_
    case addr of
      Nothing    -> pure ()
      Just addr_ -> do
        Snocket.bind   sn sd addr_
        Snocket.listen sn sd
    pure sd

-- |
-- Make a server-compatible socket from a network socket.
--
-- Every connection accepted through the returned 'Server.Socket' is added to
-- the 'ConnectionTable'.  The close action returned alongside each accepted
-- socket removes that entry again and then closes the socket.
--
fromSnocket
    :: forall fd addr. Ord addr
    => ConnectionTable IO addr
    -> Snocket IO fd addr
    -> fd -- ^ socket or handle
    -> IO (Server.Socket addr fd)
fromSnocket tblVar sn sd = go <$> Snocket.accept sn sd
  where
    -- Wrap one step of the snocket accept loop.  The continuation @next@ is
    -- wrapped recursively, so every later accept goes through the same logic.
    go :: Snocket.Accept IO fd addr -> Server.Socket addr fd
    go (Snocket.Accept accept) = Server.Socket $ do
      (result, next) <- accept
      case result of
        Snocket.Accepted sd' remoteAddr -> do
          -- TODO: we don't need to do that on each accept
          localAddr <- Snocket.getLocalAddr sn sd'
          atomically $ addConnection tblVar remoteAddr localAddr Nothing
          pure (remoteAddr, sd', close remoteAddr localAddr sd', go next)
        Snocket.AcceptFailure err ->
          -- There is no way to construct a 'Server.Socket' here, so the
          -- failure is rethrown; this will be removed in a later commit!
          throwIO err

    -- Unregister the connection from the table, then close its socket.
    close :: addr -> addr -> fd -> IO ()
    close remoteAddr localAddr sd' = do
      removeConnection tblVar remoteAddr localAddr
      Snocket.close sn sd'


-- | Tracers required by a server which handles inbound connections.
--
data NetworkServerTracers addr vNumber = NetworkServerTracers {
      nstMuxTracer          :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace),
      -- ^ low level mux-network tracer, which logs mux sdu (send and received)
      -- and other low level multiplexing events.
      nstHandshakeTracer    :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr)
                                           (TraceSendRecv (Handshake vNumber CBOR.Term))),
      -- ^ handshake protocol tracer; it is important for analysing version
      -- negotiation mismatches.
      nstErrorPolicyTracer  :: Tracer IO (WithAddr addr ErrorPolicyTrace),
      -- ^ error policy tracer; must not be 'nullTracer', otherwise all the
      -- exceptions which are not matched by any error policy will be caught
      -- and not logged or rethrown.
      nstAcceptPolicyTracer :: Tracer IO AcceptConnectionsPolicyTrace
      -- ^ tracing rate limiting of accepting connections.
    }

-- | Server tracers which discard every event: all four fields are
-- 'nullTracer'.
--
-- Beware: this includes 'nstErrorPolicyTracer', whose documentation says it
-- must not be 'nullTracer'.  Otherwise, exceptions not matched by any error
-- policy are caught and neither logged nor rethrown.
nullNetworkServerTracers :: NetworkServerTracers addr vNumber
nullNetworkServerTracers = NetworkServerTracers {
      nstMuxTracer          = nullTracer,
      nstHandshakeTracer    = nullTracer,
      nstErrorPolicyTracer  = nullTracer,
      nstAcceptPolicyTracer = nullTracer
    }

-- | Server tracers for debugging: every event is rendered with 'show' and
-- written to 'stdoutTracer'.
debuggingNetworkServerTracers :: (Show addr, Show vNumber)
                              => NetworkServerTracers addr vNumber
debuggingNetworkServerTracers = NetworkServerTracers {
      nstMuxTracer          = showTracing stdoutTracer,
      nstHandshakeTracer    = showTracing stdoutTracer,
      nstErrorPolicyTracer  = showTracing stdoutTracer,
      nstAcceptPolicyTracer = showTracing stdoutTracer
    }


-- | Mutable state maintained by the network component.
--
data NetworkMutableState addr = NetworkMutableState {
    nmsConnectionTable :: ConnectionTable IO addr,
    -- ^ 'ConnectionTable' which maintains information about current upstream and
    -- downstream connections.
    nmsPeerStates      :: StrictTVar IO (PeerStates IO addr)
    -- ^ 'PeerStates' which maintains state of each downstream / upstream peer
    -- that errored, misbehaved or was not interesting to us.
  }

-- | Create a fresh 'NetworkMutableState' (a new connection table and a new
-- peer-states variable) inside a single STM transaction.
newNetworkMutableStateSTM :: STM.STM (NetworkMutableState addr)
newNetworkMutableStateSTM =
    NetworkMutableState <$> newConnectionTableSTM
                        <*> newPeerStatesVarSTM

-- | 'IO' wrapper around 'newNetworkMutableStateSTM': runs it with
-- 'atomically'.
newNetworkMutableState :: IO (NetworkMutableState addr)
newNetworkMutableState = atomically newNetworkMutableStateSTM

-- | Clean 'PeerStates' within 'NetworkMutableState' every 200s
--
-- Only 'nmsPeerStates' is touched; the 200 is passed to 'cleanPeerStates' as
-- a 'DiffTime', i.e. 200 seconds.
cleanNetworkMutableState :: NetworkMutableState addr
                         -> IO ()
cleanNetworkMutableState NetworkMutableState {nmsPeerStates} =
    cleanPeerStates 200 nmsPeerStates

-- |
-- Thin wrapper around @'Server.run'@.
--
runServerThread
    :: forall vNumber vData fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> Snocket IO fd addr
    -> fd
    -> AcceptedConnectionsLimit
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> (vData -> vData -> Accept vData)
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -> ErrorPolicies
    -> IO Void
runServerThread :: NetworkServerTracers addr vNumber
-> NetworkMutableState addr
-> Snocket IO fd addr
-> fd
-> AcceptedConnectionsLimit
-> Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
-> ProtocolTimeLimits (Handshake vNumber Term)
-> VersionDataCodec Term vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions
     vNumber vData (SomeResponderApplication addr ByteString IO b)
-> ErrorPolicies
-> IO Void
runServerThread NetworkServerTracers { Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer :: Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer
                                     , Tracer
  IO
  (WithMuxBearer
     (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer :: Tracer
  IO
  (WithMuxBearer
     (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer
     IO
     (WithMuxBearer
        (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer
                                     , Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer :: Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer
                                     , Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer :: Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer
                                     }
                NetworkMutableState { ConnectionTable IO addr
nmsConnectionTable :: ConnectionTable IO addr
nmsConnectionTable :: forall addr. NetworkMutableState addr -> ConnectionTable IO addr
nmsConnectionTable
                                    , StrictTVar IO (PeerStates IO addr)
nmsPeerStates :: StrictTVar IO (PeerStates IO addr)
nmsPeerStates :: forall addr.
NetworkMutableState addr -> StrictTVar IO (PeerStates IO addr)
nmsPeerStates }
                Snocket IO fd addr
sn
                fd
sd
                AcceptedConnectionsLimit
acceptedConnectionsLimit
                Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
handshakeCodec
                ProtocolTimeLimits (Handshake vNumber Term)
handshakeTimeLimits
                VersionDataCodec Term vNumber vData
versionDataCodec
                vData -> vData -> Accept vData
acceptVersion
                Versions
  vNumber vData (SomeResponderApplication addr ByteString IO b)
versions
                ErrorPolicies
errorPolicies = do
    addr
sockAddr <- Snocket IO fd addr -> fd -> IO addr
forall (m :: * -> *) fd addr. Snocket m fd addr -> fd -> m addr
Snocket.getLocalAddr Snocket IO fd addr
sn fd
sd
    Socket addr fd
serverSocket <- ConnectionTable IO addr
-> Snocket IO fd addr -> fd -> IO (Socket addr fd)
forall fd addr.
Ord addr =>
ConnectionTable IO addr
-> Snocket IO fd addr -> fd -> IO (Socket addr fd)
fromSnocket ConnectionTable IO addr
nmsConnectionTable Snocket IO fd addr
sn fd
sd
    Tracer IO (WithAddr addr ErrorPolicyTrace)
-> Tracer IO AcceptConnectionsPolicyTrace
-> Socket addr fd
-> AcceptedConnectionsLimit
-> (IOException -> IO ())
-> BeginConnection addr fd (PeerStates IO addr) ()
-> ApplicationStart addr (PeerStates IO addr)
-> CompleteConnection addr (PeerStates IO addr) Any ()
-> Main (PeerStates IO addr) Void
-> TVar (PeerStates IO addr)
-> IO Void
forall addr channel st r tr t.
Tracer IO (WithAddr addr ErrorPolicyTrace)
-> Tracer IO AcceptConnectionsPolicyTrace
-> Socket addr channel
-> AcceptedConnectionsLimit
-> (IOException -> IO ())
-> BeginConnection addr channel st r
-> ApplicationStart addr st
-> CompleteConnection addr st tr r
-> Main st t
-> TVar st
-> IO t
Server.run
        Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer
        Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer
        Socket addr fd
serverSocket
        AcceptedConnectionsLimit
acceptedConnectionsLimit
        (addr -> IOException -> IO ()
acceptException addr
sockAddr)
        (Snocket IO fd addr
-> Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
-> Tracer
     IO
     (WithMuxBearer
        (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
-> Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
-> ProtocolTimeLimits (Handshake vNumber Term)
-> VersionDataCodec Term vNumber vData
-> (vData -> vData -> Accept vData)
-> (Time
    -> addr
    -> PeerStates IO addr
    -> STM
         (AcceptConnection
            (PeerStates IO addr) vNumber vData addr IO ByteString))
-> BeginConnection addr fd (PeerStates IO addr) ()
forall vNumber vData addr st fd.
(Ord vNumber, Typeable vNumber, Show vNumber) =>
Snocket IO fd addr
-> Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
-> Tracer
     IO
     (WithMuxBearer
        (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
-> Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
-> ProtocolTimeLimits (Handshake vNumber Term)
-> VersionDataCodec Term vNumber vData
-> (vData -> vData -> Accept vData)
-> (Time
    -> addr
    -> st
    -> STM (AcceptConnection st vNumber vData addr IO ByteString))
-> BeginConnection addr fd st ()
beginConnection Snocket IO fd addr
sn Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer Tracer
  IO
  (WithMuxBearer
     (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
handshakeCodec ProtocolTimeLimits (Handshake vNumber Term)
handshakeTimeLimits VersionDataCodec Term vNumber vData
versionDataCodec vData -> vData -> Accept vData
acceptVersion (addr
-> Time
-> addr
-> PeerStates IO addr
-> STM
     (AcceptConnection
        (PeerStates IO addr) vNumber vData addr IO ByteString)
acceptConnectionTx addr
sockAddr))
        -- register producer when application starts, it will be unregistered
        -- using 'CompleteConnection'
        (\addr
remoteAddr Async ()
thread PeerStates IO addr
st -> PeerStates IO addr -> STM (PeerStates IO addr)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PeerStates IO addr -> STM (PeerStates IO addr))
-> PeerStates IO addr -> STM (PeerStates IO addr)
forall a b. (a -> b) -> a -> b
$ addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
registerProducer addr
remoteAddr Async ()
Async IO ()
thread
        PeerStates IO addr
st)
        CompleteConnection addr (PeerStates IO addr) Any ()
completeTx Main (PeerStates IO addr) Void
mainTx (StrictTVar IO (PeerStates IO addr)
-> LazyTVar IO (PeerStates IO addr)
forall (m :: * -> *) a. StrictTVar m a -> LazyTVar m a
toLazyTVar StrictTVar IO (PeerStates IO addr)
nmsPeerStates)
  where
    mainTx :: Server.Main (PeerStates IO addr) Void
    mainTx :: Main (PeerStates IO addr) Void
mainTx (ThrowException e
e) = e -> STM Void
forall (m :: * -> *) e a. (MonadThrow m, Exception e) => e -> m a
throwIO e
e
    mainTx PeerStates{}       = STM Void
forall (m :: * -> *) a. MonadSTM m => STM m a
retry

    -- When a connection completes, we do nothing. State is ().
    -- Crucially: we don't re-throw exceptions, because doing so would
    -- bring down the server.
    completeTx :: Server.CompleteConnection
                    addr
                    (PeerStates IO addr)
                    (WithAddr addr ErrorPolicyTrace)
                    ()
    completeTx :: CompleteConnection addr (PeerStates IO addr) Any ()
completeTx Result addr ()
result PeerStates IO addr
st = case Result addr ()
result of

      Server.Result Async ()
thread addr
remoteAddr Time
t (Left (SomeException e
e)) ->
        (PeerStates IO addr -> PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
unregisterProducer addr
remoteAddr Async ()
Async IO ()
thread)
          (CompleteApplicationResult IO addr (PeerStates IO addr)
 -> CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ErrorPolicies
-> CompleteApplication IO (PeerStates IO addr) addr Any
forall (m :: * -> *) addr a.
(MonadAsync m, Ord addr, Ord (Async m ())) =>
ErrorPolicies -> CompleteApplication m (PeerStates m addr) addr a
completeApplicationTx ErrorPolicies
errorPolicies (Time -> addr -> e -> Result addr Any
forall e addr r. Exception e => Time -> addr -> e -> Result addr r
ApplicationError Time
t addr
remoteAddr e
e) PeerStates IO addr
st

      Server.Result Async ()
thread addr
remoteAddr Time
t (Right ()
r) ->
        (PeerStates IO addr -> PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
unregisterProducer addr
remoteAddr Async ()
Async IO ()
thread)
          (CompleteApplicationResult IO addr (PeerStates IO addr)
 -> CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ErrorPolicies
-> CompleteApplication IO (PeerStates IO addr) addr ()
forall (m :: * -> *) addr a.
(MonadAsync m, Ord addr, Ord (Async m ())) =>
ErrorPolicies -> CompleteApplication m (PeerStates m addr) addr a
completeApplicationTx ErrorPolicies
errorPolicies (Time -> addr -> () -> Result addr ()
forall addr r. Time -> addr -> r -> Result addr r
ApplicationResult Time
t addr
remoteAddr ()
r) PeerStates IO addr
st

    -- Classify an exception raised while accepting a connection: 'True' when
    -- it indicates that the remote end aborted the connection before we
    -- could accept it, 'False' for anything else (which the caller rethrows).
    iseCONNABORTED :: IOError -> Bool
#if defined(mingw32_HOST_OS)
    -- On Windows the network package classifies all errors
    -- as OtherError. This means that we're forced to match
    -- on the error string. The text string comes from
    -- the network package's winSockErr.c, and if it ever
    -- changes we must update our text string too.
    iseCONNABORTED (IOError _ _ _ "Software caused connection abort (WSAECONNABORTED)" _ _) = True
    iseCONNABORTED _ = False
#else
    -- Compare the C errno carried by the exception against ECONNABORTED.
    iseCONNABORTED (IOError _ _ _ _ (Just cerrno) _) = eCONNABORTED == Errno cerrno
#if defined(darwin_HOST_OS)
    -- There is a bug in accept for IPv6 sockets. Instead of returning -1
    -- and setting errno to ECONNABORTED an invalid (>= 0) file descriptor
    -- is returned, with the client address left unchanged. The uninitialized
    -- client address causes the network package to throw the user error below.
    iseCONNABORTED (IOError _ UserError _ "Network.Socket.Types.peekSockAddr: address family '0' not supported." _ _) = True
#endif
    iseCONNABORTED _ = False
#endif


    -- Handler for an 'IOException' raised while accepting a connection.
    -- The exception is always traced (tagged with the address @a@) on
    -- 'nstErrorPolicyTracer'.  It is swallowed when 'iseCONNABORTED'
    -- recognises it and rethrown otherwise.
    acceptException :: addr -> IOException -> IO ()
    acceptException a e = do
      traceWith (WithAddr a `contramap` nstErrorPolicyTracer) $
        ErrorPolicyAcceptException e

      -- Try to determine whether the connection was aborted by the remote
      -- end before we could process the accept, or whether it was a resource
      -- exhaustion problem.
      -- NB. This piece of code is fragile and depends on specific
      -- strings/mappings in the network and base libraries.
      if iseCONNABORTED e
        then return ()
        else throwIO e

    -- STM decision on whether to accept a connection.
    --
    -- It consults 'beforeConnectTx' with the time @t@, the address @connAddr@
    -- and the current peer states @st@.  It then either accepts, offering
    -- 'versions', or rejects; both outcomes carry the updated peer states.
    -- NOTE(review): @sockAddr@ is presumably the local (listening) address and
    -- @connAddr@ the remote peer's address; confirm at the call site.
    acceptConnectionTx sockAddr t connAddr st = do
      d <- beforeConnectTx t connAddr st
      let connId = ConnectionId sockAddr connAddr
      case d of
        AllowConnection st'    -> pure $ AcceptConnection st' connId versions
        DisallowConnection st' -> pure $ RejectConnection st' connId

-- | Run a server application.  It listens on the given address for incoming
-- connections; otherwise it behaves like 'withServerNode''.
--
-- The listening socket is created with 'mkListeningSocket' in the calling
-- thread and closed with 'Snocket.close' (via 'bracket') when the callback
-- returns or throws.
withServerNode
    :: forall vNumber vData t fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => Snocket IO fd addr
    -> NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> AcceptedConnectionsLimit
    -> addr
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> (vData -> vData -> Accept vData)
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -- ^ The mux application that will be run on each incoming connection from
    -- a given address.  Note that if @'MuxClientAndServerApplication'@ is
    -- returned, the connection will run a full duplex set of mini-protocols.
    -> ErrorPolicies
    -> (addr -> Async Void -> IO t)
    -- ^ callback which takes the @Async@ of the thread that is running the server.
    -- Note: the server thread will terminate when the callback returns or
    -- throws an exception.
    -> IO t
withServerNode sn
               tracers
               networkState
               acceptedConnectionsLimit
               addr
               handshakeCodec
               handshakeTimeLimits
               versionDataCodec
               acceptVersion
               versions
               errorPolicies
               k =
    bracket (mkListeningSocket sn (Just addr) (Snocket.addrFamily sn addr))
            (Snocket.close sn) $ \sd ->
      withServerNode'
        sn
        tracers
        networkState
        acceptedConnectionsLimit
        sd
        handshakeCodec
        handshakeTimeLimits
        versionDataCodec
        acceptVersion
        versions
        errorPolicies
        k

-- |
-- Run a server application on the provided socket.  The socket must be ready
-- to accept connections.  The server thread runs using @withAsync@, which
-- means that it will terminate when the callback terminates or throws an
-- exception.
--
-- TODO: we should track connections in the state and refuse connections from
-- peers we are already connected to.  This is also the right place to ban
-- connections from peers which misbehaved.
--
-- The server will run the handshake protocol on each incoming connection.  We
-- assume that each version negotiation message should fit into
-- @'maxTransmissionUnit'@ (~5k bytes).
--
-- Note: the socket is opened by the caller in the current thread (as
-- 'withServerNode' does) and passed to the spawned thread which runs the
-- server.  This makes it useful for testing, where we need to guarantee that
-- a socket is open before we try to connect to it.
withServerNode'
    :: forall vNumber vData t fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => Snocket IO fd addr
    -> NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> AcceptedConnectionsLimit
    -> fd
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> (vData -> vData -> Accept vData)
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -- ^ The mux application that will be run on each incoming connection from
    -- a given address.  Note that if @'MuxClientAndServerApplication'@ is
    -- returned, the connection will run a full duplex set of mini-protocols.
    -> ErrorPolicies
    -> (addr -> Async Void -> IO t)
    -- ^ callback which takes the @Async@ of the thread that is running the server.
    -- Note: the server thread will terminate when the callback returns or
    -- throws an exception.
    -> IO t
withServerNode' sn
                tracers
                networkState
                acceptedConnectionsLimit
                sd
                handshakeCodec
                handshakeTimeLimits
                versionDataCodec
                acceptVersion
                versions
                errorPolicies
                k = do
      -- The address the socket is bound to; it is handed to the callback
      -- together with the server thread's 'Async'.
      addr' <- Snocket.getLocalAddr sn sd
      withAsync
        (runServerThread
          tracers
          networkState
          sn
          sd
          acceptedConnectionsLimit
          handshakeCodec
          handshakeTimeLimits
          versionDataCodec
          acceptVersion
          versions
          errorPolicies)
        (k addr')