{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DerivingVia         #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | How to punish the sender of an invalid block
module Ouroboros.Consensus.Storage.ChainDB.API.Types.InvalidBlockPunishment (
    -- * opaque
    InvalidBlockPunishment
  , enact
    -- * combinators
  , Invalidity (..)
  , branch
  , mkPunishThisThread
  , mkUnlessImproved
  , noPunishment
  ) where

import qualified Control.Exception as Exn
import           Control.Monad (join, unless)
import           Data.Functor ((<&>))
import           NoThunks.Class

import           Ouroboros.Consensus.Block.Abstract (BlockProtocol)
import           Ouroboros.Consensus.Protocol.Abstract (SelectView)

import           Ouroboros.Consensus.Util.IOLike
import           Ouroboros.Consensus.Util.TentativeState

-- | Where the invalidity was found, relative to the block that was added
data Invalidity
  = BlockItself
    -- ^ The added block is itself invalid
  | BlockPrefix
    -- ^ The prefix of the added block is invalid

-- | How to handle a discovered 'Invalidity'
--
-- This type is opaque because the soundness of the punishment is subtle
-- because of where it is invoked during the chain selection. As a result,
-- arbitrary monadic actions would be foot guns. Instead, this module defines a
-- small DSL for punishment that we judge to be sound.
newtype InvalidBlockPunishment m = InvalidBlockPunishment {
    -- | Carry out the punishment for the given 'Invalidity'
    enact :: Invalidity -> m ()
  }
  -- The wrapped value is a function, so there is no further structure to
  -- inspect for thunks: only check that it is in WHNF.
  deriving NoThunks via
    OnlyCheckWhnfNamed "InvalidBlockPunishment" (InvalidBlockPunishment m)

-- | A noop punishment
--
-- Enacting it ignores the 'Invalidity' and does nothing.
noPunishment :: Applicative m => InvalidBlockPunishment m
noPunishment = InvalidBlockPunishment $ \_invalidity -> pure ()

-- | Create a punishment that kills this thread
--
-- The 'ThreadId' of the thread running this action is captured immediately.
-- Enacting the resulting punishment later, whatever the 'Invalidity',
-- throws 'PeerSentAnInvalidBlockException' to that captured thread via
-- 'throwTo'.
mkPunishThisThread :: IOLike m => m (InvalidBlockPunishment m)
mkPunishThisThread = do
    tid <- myThreadId
    pure $ InvalidBlockPunishment $ \_invalidity ->
      throwTo tid PeerSentAnInvalidBlockException

-- | Thrown asynchronously to the client thread that added the block whose
-- processing involved an invalid block.
--
-- See 'mkPunishThisThread'.
data PeerSentAnInvalidBlockException = PeerSentAnInvalidBlockException
  deriving (Show)

instance Exn.Exception PeerSentAnInvalidBlockException

-- | Allocate a stateful punishment that performs the given punishment unless
-- the given 'SelectView' is better than the one from the previous invocation
--
-- All punishments built by the returned function share a single
-- 'TentativeState' variable. Each time one of them is enacted, it:
--
-- * does nothing if no punishment built this way has been enacted before
--   (the variable still holds 'NoLastInvalidTentative');
--
-- * otherwise enacts the wrapped punishment unless the new 'SelectView' is
--   strictly greater than the recorded one.
--
-- In both cases it then records the new 'SelectView' for the next invocation.
mkUnlessImproved :: forall proxy m blk.
     ( IOLike m
     , NoThunks (SelectView (BlockProtocol blk))
     , Ord      (SelectView (BlockProtocol blk))
     )
  => proxy blk
  -> STM m (   SelectView (BlockProtocol blk)
            -> InvalidBlockPunishment m
            -> InvalidBlockPunishment m
           )
mkUnlessImproved _prx = do
    var <- newTVar (NoLastInvalidTentative :: TentativeState blk)
    pure $ \new punish -> InvalidBlockPunishment $ \invalidity ->
      -- Read and update the state atomically. The punishment is an @m@
      -- action rather than an 'STM' one, so the transaction returns it and
      -- 'join' runs it only after the transaction has committed.
      join $ atomically $ do
        io <- readTVar var <&> \case
          NoLastInvalidTentative   -> pure ()
          LastInvalidTentative old ->
            unless (new > old) $ enact punish invalidity
        writeTVar var $ LastInvalidTentative new
        pure io

-- | Punish according to the 'Invalidity'
--
-- The 'Invalidity' is used twice: first to select a punishment via @f@, then
-- as the argument when enacting that selected punishment.
branch :: (Invalidity -> InvalidBlockPunishment m) -> InvalidBlockPunishment m
branch f = InvalidBlockPunishment $ \invalidity ->
    enact (f invalidity) invalidity