{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | API for running the 'Handshake' protocol.
--
module Ouroboros.Network.Protocol.Handshake
  ( runHandshakeClient
  , runHandshakeServer
  , HandshakeArguments (..)
  , Versions (..)
  , HandshakeException (..)
  , HandshakeProtocolError (..)
  , RefuseReason (..)
  , Accept (..)
  ) where

import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer

import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Term as CBOR
import           Control.Tracer (Tracer, contramap)
import qualified Data.ByteString.Lazy as BL

import           Network.Mux.Trace
import           Network.Mux.Types
import           Network.TypedProtocol.Codec

import           Ouroboros.Network.Channel
import           Ouroboros.Network.Driver.Limits

import           Ouroboros.Network.Protocol.Handshake.Client
import           Ouroboros.Network.Protocol.Handshake.Codec
import           Ouroboros.Network.Protocol.Handshake.Server
import           Ouroboros.Network.Protocol.Handshake.Type
import           Ouroboros.Network.Protocol.Handshake.Version


-- | The mini-protocol number on which the 'Handshake' protocol runs.
--
-- Both 'runHandshakeClient' and 'runHandshakeServer' open their channel on
-- the bearer using this number.
--
handshakeProtocolNum :: MiniProtocolNum
handshakeProtocolNum = MiniProtocolNum 0

-- | Wrapper around the initiator and responder errors that 'tryHandshake'
-- can report.
--
data HandshakeException vNumber =
    HandshakeProtocolLimit ProtocolLimitFailure
    -- ^ A size or time limit enforced by the protocol driver was violated
    -- while running the handshake.
  | HandshakeProtocolError (HandshakeProtocolError vNumber)
    -- ^ The handshake ran to completion but reported a protocol-level error.
  deriving Show


-- | Run either the initiator or the responder side of the 'Handshake'
-- protocol and merge its two failure modes into a 'HandshakeException':
--
-- * a 'ProtocolLimitFailure' thrown by the action is caught and returned as
--   'HandshakeProtocolLimit';
-- * a 'HandshakeProtocolError' returned by the action is wrapped as
--   'HandshakeProtocolError'.
--
-- Any other exception propagates unchanged. No timeout is applied here: size
-- and time limits are enforced by whatever driver runs @doHandshake@ (in this
-- module, 'runPeerWithLimits').
--
tryHandshake :: forall m vNumber r.
                MonadCatch m
             => m (Either (HandshakeProtocolError vNumber) r)
             -> m (Either (HandshakeException vNumber)     r)
tryHandshake doHandshake =
    -- The 'HandshakeProtocolLimit' wrapper fixes the exception type that
    -- 'try' catches to 'ProtocolLimitFailure'.
    either (Left . HandshakeProtocolLimit)
           (either (Left . HandshakeProtocolError) Right)
      <$> try doHandshake


--
-- Record arguments
--

-- | Common arguments for both the 'Handshake' client and server.
--
data HandshakeArguments connectionId vNumber vData m = HandshakeArguments {
      -- | Tracer for the messages exchanged by the 'Handshake' protocol.
      -- Events are wrapped in 'WithMuxBearer' together with the connection
      -- identifier.
      --
      haHandshakeTracer
        :: Tracer m (WithMuxBearer connectionId
                       (TraceSendRecv (Handshake vNumber CBOR.Term))),

      -- | Codec for 'Handshake' protocol messages, encoding to and decoding
      -- from lazy bytestrings.
      --
      haHandshakeCodec
        :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m BL.ByteString,

      -- | Codec for the version data (@vData@), represented as CBOR terms
      -- inside the handshake messages.
      --
      haVersionDataCodec
        :: VersionDataCodec CBOR.Term vNumber vData,

      -- | Decide whether to accept a version. The first argument is our
      -- version data; the second argument is the remote version data.
      --
      haAcceptVersion
        :: vData -> vData -> Accept vData,

      -- | 'Driver' time limits for the 'Handshake' protocol.
      --
      haTimeLimits
        :: ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    }


-- | Run the client (initiator) side of the 'Handshake' protocol on the given
-- bearer.
--
-- The handshake runs on 'handshakeProtocolNum' in the 'InitiatorDir'
-- direction. It uses 'byteLimitsHandshake' as size limits and
-- 'haTimeLimits' as time limits. On success, it returns the
-- @(application, vNumber, vData)@ triple produced by 'handshakeClientPeer'
-- from @versions@. Failures are reported as a 'HandshakeException' (see
-- 'tryHandshake').
--
-- Any bytes left over on the channel after the handshake completes are
-- discarded (NOTE(review): presumably nothing else is sent on this
-- mini-protocol after the handshake — confirm).
--
runHandshakeClient
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either (HandshakeException vNumber)
                 (application, vNumber, vData))
runHandshakeClient bearer
                   connectionId
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haAcceptVersion,
                     haTimeLimits
                   }
                   versions =
    tryHandshake $
      fst <$> runPeerWithLimits tracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                haTimeLimits
                                channel
                                peer
  where
    -- Tag every trace event with the connection it belongs to.
    tracer  = WithMuxBearer connectionId `contramap` haHandshakeTracer

    -- The initiator end of the handshake mini-protocol on this bearer.
    channel = fromChannel
                (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)

    peer    = handshakeClientPeer haVersionDataCodec haAcceptVersion versions


-- | Run the server (responder) side of the 'Handshake' protocol on the given
-- bearer.
--
-- The handshake runs on 'handshakeProtocolNum' in the 'ResponderDir'
-- direction. It uses 'byteLimitsHandshake' as size limits and
-- 'haTimeLimits' as time limits. On success, it returns the
-- @(application, vNumber, vData)@ triple produced by 'handshakeServerPeer'
-- from @versions@. Failures are reported as a 'HandshakeException' (see
-- 'tryHandshake').
--
-- Any bytes left over on the channel after the handshake completes are
-- discarded (NOTE(review): presumably nothing else is sent on this
-- mini-protocol after the handshake — confirm).
--
runHandshakeServer
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either
           (HandshakeException vNumber)
           (application, vNumber, vData))
runHandshakeServer bearer
                   connectionId
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haAcceptVersion,
                     haTimeLimits
                   }
                   versions =
    tryHandshake $
      fst <$> runPeerWithLimits tracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                haTimeLimits
                                channel
                                peer
  where
    -- Tag every trace event with the connection it belongs to.
    tracer  = WithMuxBearer connectionId `contramap` haHandshakeTracer

    -- The responder end of the handshake mini-protocol on this bearer.
    channel = fromChannel
                (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)

    peer    = handshakeServerPeer haVersionDataCodec haAcceptVersion versions