{-# LANGUAGE DeriveAnyClass             #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}

module Ouroboros.Consensus.Util.STM (
    -- * 'Watcher'
    Watcher (..)
  , forkLinkedWatcher
  , withWatcher
    -- * Misc
  , Fingerprint (..)
  , WithFingerprint (..)
  , blockUntilAllJust
  , blockUntilChanged
  , blockUntilJust
  , runWhenJust
    -- * Simulate various monad stacks in STM
  , Sim (..)
  , simId
  , simStateT
  ) where

import           Control.Monad.State
import           Data.Void
import           Data.Word (Word64)
import           GHC.Generics (Generic)
import           GHC.Stack

import           Ouroboros.Consensus.Util.IOLike
import           Ouroboros.Consensus.Util.ResourceRegistry

{-------------------------------------------------------------------------------
  Misc
-------------------------------------------------------------------------------}

-- | Block until the value read by an STM action has changed, as observed
-- through a projection.
--
-- Reads @getA@, computes @f a@, and calls 'retry' while the result is still
-- equal to the previously observed projection @b@. Once it differs, returns
-- the value that was read together with its new projection, so the caller
-- can feed that projection into the next call.
blockUntilChanged :: forall m a b. (MonadSTM m, Eq b)
                  => (a -> b)  -- ^ Projection (e.g. a fingerprint function)
                  -> b         -- ^ Previously observed projection
                  -> STM m a   -- ^ Action reading the monitored value
                  -> STM m (a, b)
blockUntilChanged f b getA = do
    a <- getA
    let b' :: b
        b' = f a
    -- 'retry' blocks the enclosing transaction until one of the variables
    -- it read is written, and then re-runs it.
    if b' == b
      then retry
      else return (a, b')

-- | Spawn a new thread that waits for an STM value to become 'Just', and
-- then runs the given action on the value inside it.
--
-- The thread will be linked to the registry. The forked 'Thread' handle is
-- discarded, so this function does not wait for the action to run.
runWhenJust :: IOLike m
            => ResourceRegistry m
            -> String  -- ^ Label for the thread
            -> STM m (Maybe a)
            -> (a -> m ())
            -> m ()
runWhenJust registry label getMaybeA action =
    void $ forkLinkedThread registry label $
      action =<< atomically (blockUntilJust getMaybeA)

-- | Block until the given STM action yields 'Just', and return its contents.
--
-- A 'Nothing' result triggers 'retry', so the enclosing transaction is
-- re-run once any variable it read has been written.
blockUntilJust :: MonadSTM m => STM m (Maybe a) -> STM m a
blockUntilJust getMaybeA = getMaybeA >>= maybe retry return

-- | Block until every STM action in the list yields 'Just', and return the
-- contents in list order.
--
-- The actions are composed into a single STM action, so a 'Nothing' from
-- any one of them retries the whole thing. An empty list yields @[]@
-- without blocking.
blockUntilAllJust :: MonadSTM m => [STM m (Maybe a)] -> STM m [a]
blockUntilAllJust = traverse blockUntilJust

-- | Simple type that can be used to indicate that something in a @TVar@ has
-- changed.
--
-- 'Enum' is derived from the underlying 'Word64', so 'succ' yields the next
-- fingerprint (and, as for 'Word64', is an error on 'maxBound').
newtype Fingerprint = Fingerprint Word64
  deriving stock    (Show, Eq, Generic)
  deriving newtype  (Enum)
  deriving anyclass (NoThunks)

-- | Store a value together with its fingerprint.
--
-- Both fields are strict.
data WithFingerprint a = WithFingerprint
  { forgetFingerprint :: !a
  , getFingerprint    :: !Fingerprint
  }
  deriving stock    (Show, Eq, Functor, Generic)
  deriving anyclass (NoThunks)

{-------------------------------------------------------------------------------
  Simulate monad stacks
-------------------------------------------------------------------------------}

-- | Run computations of the monad @n@ as 'STM' actions of @m@.
--
-- The wrapped function is rank-2: it must work for every result type @a@.
newtype Sim n m = Sim { runSim :: forall a. n a -> STM m a }

-- | The trivial simulation: 'STM' actions are run as themselves.
simId :: Sim (STM m) m
simId = Sim id

-- | Simulate a @'StateT' st n@ computation in STM, keeping the state in a
-- 'StrictTVar'.
--
-- The current state is read from @stVar@, and the underlying @n@ computation
-- is run through the given 'Sim'. The resulting state is written back to
-- @stVar@ and the result is returned. All of these steps form one STM action.
simStateT :: IOLike m => StrictTVar m st -> Sim n m -> Sim (StateT st n) m
simStateT stVar (Sim k) = Sim $ \(StateT f) -> do
    st       <- readTVar stVar
    (a, st') <- k (f st)
    writeTVar stVar st'
    return a

{-------------------------------------------------------------------------------
  Watchers
-------------------------------------------------------------------------------}

-- | Specification for a thread that watches a variable, and reports interesting
-- changes.
--
-- NOTE: STM does not guarantee that 'wNotify' will /literally/ be called on
-- /every/ change: when the system is under heavy load, some updates may be
-- missed. The watcher only compares the value current at the time its
-- transaction runs against the last fingerprint, so intermediate values can
-- be skipped.
data Watcher m a fp = Watcher {
    -- | Obtain a fingerprint from a value of the monitored variable.
    wFingerprint :: a -> fp
    -- | The initial fingerprint
    --
    -- If 'Nothing', the variable is read once immediately, 'wNotify' is
    -- called on that value, and its fingerprint is used as the initial one.
  , wInitial     :: Maybe fp
    -- | An action executed each time the fingerprint changes.
  , wNotify      :: a -> m ()
    -- | The variable to monitor.
  , wReader      :: STM m a
  }

-- | Execute a 'Watcher'
--
-- First establishes the initial fingerprint: 'wInitial' if given, otherwise
-- the variable is read once and 'wNotify' is called on it. It then loops
-- forever, blocking until the fingerprint changes and calling 'wNotify' with
-- the new value each time. It never returns normally, hence 'Void'.
--
-- NOT EXPORTED
runWatcher :: forall m a fp. (IOLike m, Eq fp, HasCallStack)
           => Watcher m a fp
           -> m Void
runWatcher watcher = do
    initB <- case mbInitFP of
      Just fp -> return fp
      Nothing -> do
        a <- atomically getA
        notify a
        return $ f a
    loop initB
  where
    Watcher {
        wFingerprint = f
      , wInitial     = mbInitFP
      , wNotify      = notify
      , wReader      = getA
      } = watcher

    loop :: fp -> m Void
    loop fp = do
      -- 'notify' runs outside the transaction, so it may perform arbitrary
      -- effects in @m@.
      (a, fp') <- atomically $ blockUntilChanged f fp getA
      notify a
      loop fp'

-- | Spawn a new thread that runs a 'Watcher'
--
-- The thread will be linked to the registry. A watcher never terminates
-- normally, so the returned 'Thread' has result type 'Void'.
forkLinkedWatcher :: forall m a fp. (IOLike m, Eq fp, HasCallStack)
                  => ResourceRegistry m
                  -> String    -- ^ Label for the thread
                  -> Watcher m a fp
                  -> m (Thread m Void)
forkLinkedWatcher registry label watcher =
    forkLinkedThread registry label $ runWatcher watcher

-- | Spawn a new thread that runs a 'Watcher' while the given action runs.
--
-- The thread labels itself with the given label and is bracketed via
-- 'withAsync' around @k@. Its 'Async' handle is 'link'ed before @k@ starts.
--
-- We do not provide the 'Async' handle only because our anticipated use cases
-- don't need it.
withWatcher :: forall m a fp r. (IOLike m, Eq fp, HasCallStack)
            => String    -- ^ Label for the thread
            -> Watcher m a fp
            -> m r
            -> m r
withWatcher label watcher k =
    withAsync
      (labelThisThread label >> runWatcher watcher)
      (\h -> link h >> k)