{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.KeepAlive
  ( KeepAliveInterval (..)
  , keepAliveClient
  , keepAliveServer
  , TraceKeepAliveClient (..)
  ) where

import           Control.Exception (assert)
import qualified Control.Monad.Class.MonadSTM as Lazy
import           Control.Monad.Class.MonadSTM.Strict
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer
import           Control.Tracer (Tracer, traceWith)
import qualified Data.Map.Strict as M
import           Data.Maybe (fromJust)
import           System.Random (StdGen, random)

import           Ouroboros.Network.DeltaQ
import           Ouroboros.Network.Mux (ControlMessage (..), ControlMessageSTM)
import           Ouroboros.Network.Protocol.KeepAlive.Client
import           Ouroboros.Network.Protocol.KeepAlive.Server
import           Ouroboros.Network.Protocol.KeepAlive.Type


-- | Delay that 'keepAliveClient' waits, after receiving a keep-alive reply,
-- before sending the next keep-alive message. (The message that follows the
-- very first, untimed reply is sent without this delay.)
newtype KeepAliveInterval = KeepAliveInterval { keepAliveInterval :: DiffTime }

-- | Events emitted by 'keepAliveClient' through its 'Tracer'.
data TraceKeepAliveClient peer where
    -- | A timed keep-alive round trip for the given peer: the measured
    -- round-trip time, and the peer's 'PeerGSV' after the new sample was
    -- merged into it.
    AddSample :: peer -> DiffTime -> PeerGSV -> TraceKeepAliveClient peer

-- | Renders an event as a single trace line of the form
-- @AddSample \<peer\> sample: \<rtt\> gsv: \<gsv\>@.
instance Show peer => Show (TraceKeepAliveClient peer) where
    show (AddSample peer rtt gsv) =
        "AddSample " ++ show peer
          ++ " sample: " ++ show rtt
          ++ " gsv: " ++ show gsv

-- | Client side of the keep-alive mini-protocol.
--
-- The first keep-alive message is sent immediately and its reply is not
-- timed. The next message is sent as soon as that reply arrives. Every later
-- message waits 'keepAliveInterval' after the preceding reply. Each timed
-- round trip is merged into the 'PeerGSV' stored for @peer@ in @dqCtx@ and
-- traced as 'AddSample'. While waiting to send, the client watches
-- @controlMessageSTM@; on 'Terminate' it stops immediately with
-- 'SendMsgDone'.
keepAliveClient
    :: forall m peer.
       ( MonadSTM   m
       , MonadMonotonicTime m
       , MonadTimer m
       , Ord peer
       )
    => Tracer m (TraceKeepAliveClient peer)
       -- ^ receives an 'AddSample' event for every timed round trip
    -> StdGen
       -- ^ source of the random cookie attached to each message
    -> ControlMessageSTM m
       -- ^ checked before every message after the first; 'Terminate' ends
       -- the protocol
    -> peer
       -- ^ key of this connection in the @dqCtx@ map
    -> (StrictTVar m (M.Map peer PeerGSV))
       -- ^ shared per-peer 'PeerGSV' estimates; @peer@ must already be present
    -> KeepAliveInterval
       -- ^ delay between a reply and the next request
    -> KeepAliveClient m ()
keepAliveClient tracer inRng controlMessageSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
    let (cookie, rng) = random inRng in
    SendMsgKeepAlive (Cookie cookie) (go rng Nothing)
  where
    -- Payload size passed to 'fromSample' for every timed exchange.
    payloadSize = 2

    -- Block until the next step is known. A 'Terminate' control message is
    -- returned at once, without waiting for the delay. For any other control
    -- message, wait until @delayVar@ becomes 'True' and then return
    -- 'Continue'. 'Quiesce' therefore behaves like 'Continue', and this
    -- function never returns 'Quiesce'.
    decisionSTM :: Lazy.TVar m Bool
                -> STM  m ControlMessage
    decisionSTM delayVar = do
       controlMessage <- controlMessageSTM
       case controlMessage of
            Terminate -> return Terminate

            -- Continue (or Quiesce): wait for the delay to expire.
            _  -> do
              done <- Lazy.readTVar delayVar
              if done
                 then return Continue
                 else retry

    -- Runs each time a reply arrives. @startTime_m@ is the time the matching
    -- request was sent, or 'Nothing' for the very first request, whose round
    -- trip is not measured.
    go :: StdGen -> Maybe Time -> m (KeepAliveClient m ())
    go rng startTime_m = do
      endTime <- getMonotonicTime
      case startTime_m of
           Just startTime -> do
               let rtt = diffTime endTime startTime
                   sample = fromSample startTime endTime payloadSize
               gsv' <- atomically $ do
                   m <- readTVar dqCtx
                   -- NOTE: @peer@ must already have an entry in @dqCtx@.
                   -- When asserts are compiled out and it is missing, the map
                   -- is left unchanged, and 'fromJust' fails only if the
                   -- traced 'PeerGSV' is ever forced.
                   assert (peer `M.member` m) $ do
                     let (gsv', m') = M.updateLookupWithKey
                             (\_ a -> if sampleTime a == Time 0 -- Replace the initial dummy value instead of combining with it
                                         then Just sample
                                         else Just $ sample <> a
                             ) peer m
                     writeTVar dqCtx m'
                     return $ fromJust gsv'
               traceWith tracer $ AddSample peer rtt gsv'

           Nothing        -> return ()

      let keepAliveInterval' = case startTime_m of
                                    Just _  -> keepAliveInterval
                                    -- After the first (untimed) reply, send the
                                    -- next message without waiting.
                                    Nothing -> 0

      delayVar <- registerDelay keepAliveInterval'
      decision <- atomically (decisionSTM delayVar)
      -- Start time of the next request, used to time its round trip.
      now <- getMonotonicTime
      case decision of
        -- 'decisionSTM' above cannot return 'Quiesce'
        Quiesce   -> error "keepAlive: impossible happened"
        Continue  ->
            let (cookie, rng') = random rng in
            pure (SendMsgKeepAlive (Cookie cookie) $ go rng' $ Just now)
        Terminate -> pure (SendMsgDone (pure ()))


-- | Server side of the keep-alive mini-protocol. Its handler for a
-- keep-alive message just continues with this same server. Its handler for
-- done finishes with @()@. It keeps no state.
keepAliveServer
  :: forall m.  Applicative m
  => KeepAliveServer m ()
keepAliveServer = KeepAliveServer {
    recvMsgKeepAlive = pure keepAliveServer,
    recvMsgDone      = pure ()
  }