{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}

{- HLINT ignore "Use newtype instead of data" -}

-- |
-- Copyright: © 2021 IOHK
-- License: Apache-2.0
--
-- This module provides a utility for ordering concurrent actions
-- via locks.
module Control.Concurrent.Concierge
    ( Concierge
    , newConcierge
    , atomicallyWith
    , atomicallyWithLifted
    )
    where

import Prelude

import Control.Monad.Class.MonadFork
    ( MonadThread, ThreadId, myThreadId )
import Control.Monad.Class.MonadSTM
    ( MonadSTM
    , TVar
    , atomically
    , modifyTVar
    , newTVarIO
    , readTVar
    , retry
    , writeTVar
    )
import Control.Monad.Class.MonadThrow
    ( MonadThrow, bracket )
import Control.Monad.IO.Class
    ( MonadIO, liftIO )
import Data.Map.Strict
    ( Map )

import qualified Data.Map.Strict as Map

{-------------------------------------------------------------------------------
    Concierge
-------------------------------------------------------------------------------}
-- | At a 'Concierge', you can obtain a lock and
-- enforce sequential execution of concurrent 'IO' actions.
--
-- Back in the old days, hotel concierges used to give out keys.
-- But after the cryptocurrency revolution, they give out locks. :)
-- (The term /lock/ is standard terminology in concurrent programming.)
data Concierge m lock = Concierge
    { -- | Locks that are currently taken, each mapped to the 'ThreadId'
      -- that was recorded when the lock was acquired.
      locks :: TVar m (Map lock (ThreadId m))
    }

-- | Create a new 'Concierge' that keeps track of locks.
newConcierge :: MonadSTM m => m (Concierge m lock)
-- Start with an empty map: no lock has been handed out yet.
newConcierge = Concierge <$> newTVarIO Map.empty

-- | Obtain a lock from a 'Concierge' and run an 'IO' action.
--
-- If the same (equal) lock is already taken at this 'Concierge',
-- the thread will be blocked until the lock becomes available.
--
-- The action may throw a synchronous or asynchronous exception.
-- In both cases, the lock is returned to the concierge.
atomicallyWith
    :: (Ord lock, MonadIO m, MonadThrow m)
    => Concierge IO lock
    -- ^ Concierge whose lock table lives in 'IO'.
    -> lock
    -- ^ Lock to hold while the action runs.
    -> m a
    -- ^ Action to run while holding the lock.
    -> m a
-- Specialisation of 'atomicallyWithLifted' where the lock bookkeeping
-- runs in 'IO' and is lifted into @m@ via 'liftIO'.
atomicallyWith = atomicallyWithLifted liftIO

-- | More polymorphic version of 'atomicallyWith'.
atomicallyWithLifted
    :: (Ord lock, MonadSTM m, MonadThread m, MonadThrow n)
    => (forall b. m b -> n b)
    -- ^ How to run an action of the lock-managing monad @m@ inside @n@.
    -> Concierge m lock
    -- ^ Concierge that keeps track of the taken locks.
    -> lock
    -- ^ Lock to hold while the action runs.
    -> n a
    -- ^ Action to run while holding the lock.
    -> n a
atomicallyWithLifted lift Concierge{locks} lock action =
    -- 'bracket' pairs the acquisition with the release, so that the lock
    -- is removed again once the action has finished or thrown.
    bracket acquire (const release) (const action)
  where
    -- Wait until the lock is absent from the map, then record it together
    -- with the 'ThreadId' of the calling thread.
    acquire = lift $ do
        tid <- myThreadId
        atomically $ do
            ls <- readTVar locks
            case Map.lookup lock ls of
                -- Lock is taken: 'retry' blocks the transaction until
                -- the 'TVar' changes.
                Just _  -> retry
                Nothing -> writeTVar locks $ Map.insert lock tid ls

    -- Remove the lock from the map; the recorded 'ThreadId' is not checked.
    release = lift $
        atomically $ modifyTVar locks $ Map.delete lock