{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ouroboros.Network.Protocol.Handshake
( runHandshakeClient
, runHandshakeServer
, HandshakeArguments (..)
, Versions (..)
, HandshakeException (..)
, HandshakeProtocolError (..)
, RefuseReason (..)
, Accept (..)
) where
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Term as CBOR
import Control.Tracer (Tracer, contramap)
import qualified Data.ByteString.Lazy as BL
import Network.Mux.Trace
import Network.Mux.Types
import Network.TypedProtocol.Codec
import Ouroboros.Network.Channel
import Ouroboros.Network.Driver.Limits
import Ouroboros.Network.Protocol.Handshake.Client
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Server
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version
-- | Mini-protocol number reserved for the handshake.
--
-- Both 'runHandshakeClient' and 'runHandshakeServer' open their mux bearer
-- channel on this number.
handshakeProtocolNum :: MiniProtocolNum
handshakeProtocolNum = MiniProtocolNum 0
-- | Failure of either side of the handshake, as produced by 'tryHandshake'.
data HandshakeException vNumber
    = HandshakeProtocolLimit ProtocolLimitFailure
      -- ^ A 'ProtocolLimitFailure' thrown while the handshake action ran
      -- (caught by 'tryHandshake').
    | HandshakeProtocolError (HandshakeProtocolError vNumber)
      -- ^ A protocol error returned as a 'Left' by the handshake action.
  deriving (Show)
-- | Run a handshake action and fold both of its failure channels into a
-- single 'HandshakeException'.
--
-- * A 'ProtocolLimitFailure' /thrown/ by the action is caught and reported
--   as 'HandshakeProtocolLimit'.
-- * A 'HandshakeProtocolError' /returned/ by the action (as 'Left') is
--   reported as 'HandshakeProtocolError'.
-- * A successful result is passed through unchanged.
--
-- Exceptions of any other type are not caught and propagate to the caller.
--
-- Only 'MonadMask' is required: 'try' needs 'MonadCatch', which 'MonadMask'
-- implies. The former 'MonadAsync' constraint was never used.
tryHandshake :: forall m vNumber r.
                MonadMask m
             => m (Either (HandshakeProtocolError vNumber) r)
             -> m (Either (HandshakeException vNumber) r)
tryHandshake doHandshake = classify <$> try doHandshake
  where
    -- The argument type pins the exception type caught by 'try'
    -- to 'ProtocolLimitFailure'.
    classify :: Either ProtocolLimitFailure
                       (Either (HandshakeProtocolError vNumber) r)
             -> Either (HandshakeException vNumber) r
    classify (Left limitErr)         = Left (HandshakeProtocolLimit limitErr)
    classify (Right (Left protoErr)) = Left (HandshakeProtocolError protoErr)
    classify (Right (Right result))  = Right result
-- | Configuration shared by 'runHandshakeClient' and 'runHandshakeServer'.
--
-- It holds everything besides the mux bearer, the connection id and the
-- offered 'Versions'.
data HandshakeArguments connectionId vNumber vData m = HandshakeArguments
    { haHandshakeTracer
        :: Tracer m (WithMuxBearer connectionId
                       (TraceSendRecv (Handshake vNumber CBOR.Term)))
      -- ^ Tracer for the driver's 'TraceSendRecv' events. Each event is
      -- wrapped in 'WithMuxBearer' together with the connection id.
    , haHandshakeCodec
        :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m BL.ByteString
      -- ^ Codec the driver ('runPeerWithLimits') uses for handshake messages.
    , haVersionDataCodec
        :: VersionDataCodec CBOR.Term vNumber vData
      -- ^ Codec for version data, handed to the client/server handshake peer.
    , haAcceptVersion
        :: vData -> vData -> Accept vData
      -- ^ Acceptance check over two version-data values, handed to the
      -- client/server handshake peer. NOTE(review): which argument is the
      -- local and which the remote data is not visible here; confirm in
      -- Handshake.Client / Handshake.Server.
    , haTimeLimits
        :: ProtocolTimeLimits (Handshake vNumber CBOR.Term)
      -- ^ Time limits passed to 'runPeerWithLimits'.
    }
-- | Run the initiator (client) side of the handshake mini-protocol.
--
-- Protocol and channel:
--
-- * The protocol runs over the 'InitiatorDir' channel of
--   'handshakeProtocolNum' on the given bearer.
-- * It runs under the 'byteLimitsHandshake' size limits and the caller's
--   'haTimeLimits'.
-- * Trace events are tagged with the connection id before reaching
--   'haHandshakeTracer'.
-- * Any leftover bytes reported by the driver are discarded.
--
-- Result:
--
-- * On success, returns the @(application, vNumber, vData)@ triple produced
--   by 'handshakeClientPeer'.
-- * A thrown 'ProtocolLimitFailure' or a returned 'HandshakeProtocolError'
--   becomes a 'HandshakeException' (see 'tryHandshake').
runHandshakeClient
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either (HandshakeException vNumber)
                 (application, vNumber, vData))
runHandshakeClient bearer
                   connectionId
                   HandshakeArguments { haHandshakeTracer
                                      , haHandshakeCodec
                                      , haVersionDataCodec
                                      , haAcceptVersion
                                      , haTimeLimits
                                      }
                   versions =
    tryHandshake (fst <$> driver)
  where
    -- Tag every driver trace event with the connection it belongs to.
    tracer  = WithMuxBearer connectionId `contramap` haHandshakeTracer

    -- The handshake's own channel on the bearer, in the initiator direction.
    channel = fromChannel
                (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)

    peer    = handshakeClientPeer haVersionDataCodec haAcceptVersion versions

    -- Yields (result, leftover bytes); only the result is kept above.
    driver  = runPeerWithLimits tracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                haTimeLimits
                                channel
                                peer
-- | Run the responder (server) side of the handshake mini-protocol.
--
-- Protocol and channel:
--
-- * The protocol runs over the 'ResponderDir' channel of
--   'handshakeProtocolNum' on the given bearer.
-- * It runs under the 'byteLimitsHandshake' size limits and the caller's
--   'haTimeLimits'.
-- * Trace events are tagged with the connection id before reaching
--   'haHandshakeTracer'.
-- * Any leftover bytes reported by the driver are discarded.
--
-- Result:
--
-- * On success, returns the @(application, vNumber, vData)@ triple produced
--   by 'handshakeServerPeer'.
-- * A thrown 'ProtocolLimitFailure' or a returned 'HandshakeProtocolError'
--   becomes a 'HandshakeException' (see 'tryHandshake').
runHandshakeServer
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either
           (HandshakeException vNumber)
           (application, vNumber, vData))
runHandshakeServer bearer
                   connectionId
                   HandshakeArguments { haHandshakeTracer
                                      , haHandshakeCodec
                                      , haVersionDataCodec
                                      , haAcceptVersion
                                      , haTimeLimits
                                      }
                   versions =
    tryHandshake (fst <$> driver)
  where
    -- Tag every driver trace event with the connection it belongs to.
    tracer  = WithMuxBearer connectionId `contramap` haHandshakeTracer

    -- The handshake's own channel on the bearer, in the responder direction.
    channel = fromChannel
                (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)

    peer    = handshakeServerPeer haVersionDataCodec haAcceptVersion versions

    -- Yields (result, leftover bytes); only the result is kept above.
    driver  = runPeerWithLimits tracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                haTimeLimits
                                channel
                                peer