{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}


module Control.Concurrent.JobPool
  ( JobPool
  , Job (..)
  , withJobPool
  , forkJob
  , readSize
  , readGroupSize
  , collect
  , cancelGroup
  ) where

import           Data.Functor (($>))
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import           Control.Exception (SomeAsyncException (..))
import           Control.Monad (unless, when)
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork (MonadThread (..))
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow


-- | A pool of jobs started with 'forkJob'.  Create one with 'withJobPool'.
data JobPool group m a = JobPool
    { jobsVar         :: !(TVar m (Map (group, ThreadId m) (Async m ())))
      -- ^ jobs registered by 'forkJob', keyed by their group and the
      -- thread id of the job's thread
    , completionQueue :: !(TQueue m a)
      -- ^ results written by jobs when they finish; read with 'collect'
    }

-- | A job to be run by 'forkJob'.  Its arguments are, in order: the action
-- to run, a handler applied to synchronous exceptions thrown by the action,
-- the group the job belongs to, and the label given to the job's thread.
data Job group m a =
    Job (m a)
        (SomeException -> m a)
        group
        String

-- | Run a computation with a fresh, empty 'JobPool'.
--
-- When the computation returns or throws, every job that is still
-- registered in the pool is cancelled with 'uninterruptibleCancel'.
withJobPool :: forall group m a b.
               (MonadAsync m, MonadThrow m, MonadLabelledSTM m)
            => (JobPool group m a -> m b) -> m b
withJobPool = bracket acquire release
  where
    -- Allocate the job map (labelled "job-pool") and the completion queue
    -- in a single transaction.
    acquire :: m (JobPool group m a)
    acquire = atomically $ do
      jobs  <- labelled =<< newTVar Map.empty
      queue <- newTQueue
      return JobPool { jobsVar = jobs, completionQueue = queue }

    labelled :: TVar m x -> STM m (TVar m x)
    labelled var = labelTVar var "job-pool" $> var

    -- The release action of 'bracket' must not be interruptible, which is
    -- why jobs are cancelled with 'uninterruptibleCancel'; the 'async'
    -- library makes the same choice in 'withAsync'.  This can only hang if
    -- a job thread never receives the asynchronous exception, e.g. it is
    -- stuck in an FFI call or spinning in a tight loop that never
    -- allocates (not a deadlock as such, but a rare unfortunate condition).
    release :: JobPool group m a -> m ()
    release pool = do
      running <- atomically (readTVar (jobsVar pool))
      mapM_ uninterruptibleCancel running

-- | Fork a 'Job' in the pool.
--
-- The job's action runs in a new thread, created with 'async' and labelled
-- with the job's label, with asynchronous exceptions unmasked.  Synchronous
-- exceptions thrown by the action are passed to the job's handler;
-- asynchronous exceptions are not handled.  When the action (or the
-- handler) returns, the result is forced to WHNF and written to the pool's
-- completion queue, from which 'collect' reads it.
--
-- While the job runs, it is registered in the pool under its group and
-- thread id.  This registration is what 'readSize', 'readGroupSize',
-- 'cancelGroup' and pool shutdown see.  The entry is removed when the job
-- terminates, both on normal completion and when it is terminated by an
-- exception.
forkJob :: forall group m a.
           ( MonadAsync m, MonadMask m
           , Ord group
           )
        => JobPool group m a
        -> Job     group m a
        -> m ()
forkJob JobPool{jobsVar, completionQueue} (Job action handler group label) =
    mask $ \restore -> do
      -- The job sets this to 'True' in the same transaction in which it
      -- removes its entry from 'jobsVar'.  The job can run to completion
      -- before this thread gets to register it below.  Registering it
      -- unconditionally would then insert an entry that is never removed,
      -- and the pool would over-report its size forever.
      doneVar <- newTVarIO False

      jobAsync <- async $ do
        tid <- myThreadId
        labelThread tid label
        let unregister :: STM m ()
            unregister = do
              writeTVar doneVar True
              modifyTVar' jobsVar (Map.delete (group, tid))

            runJob :: m ()
            runJob = do
              !res <- handleJust notAsyncExceptions handler $
                        restore action
              atomically $ do
                writeTQueue completionQueue res
                unregister

        -- Also unregister when the job is killed (e.g. by 'cancelGroup'),
        -- or when the handler or forcing the result throws.  Otherwise the
        -- entry would stay in 'jobsVar'.
        -- NOTE: this assumes the job thread starts in the caller's masked
        -- state ('restore' is applied only to 'action').  Under that
        -- assumption, no asynchronous exception can arrive before this
        -- handler is installed.
        runJob `onException` atomically unregister

      let !tid = asyncThreadId jobAsync
      atomically $ do
        done <- readTVar doneVar
        unless done $
          modifyTVar' jobsVar (Map.insert (group, tid) jobAsync)
  where
    -- Select only synchronous exceptions.  Asynchronous ones (e.g.
    -- cancellation) must propagate and terminate the job.
    notAsyncExceptions :: SomeException -> Maybe SomeException
    notAsyncExceptions e
      | Just (SomeAsyncException _) <- fromException e
                  = Nothing
      | otherwise = Just e

-- | Number of jobs currently registered in the pool, across all groups.
readSize :: MonadSTM m => JobPool group m a -> STM m Int
readSize pool = do
    jobs <- readTVar (jobsVar pool)
    return (Map.size jobs)

-- | Number of jobs currently registered in the pool under the given group.
-- This is a linear scan over all registered jobs.
readGroupSize :: ( MonadSTM m
                 , Eq group
                 )
              => JobPool group m a -> group -> STM m Int
readGroupSize JobPool{jobsVar} group = do
    jobs <- readTVar jobsVar
    return $ length [ () | (group', _) <- Map.keys jobs, group' == group ]

-- | Take the next result from the pool's completion queue; the transaction
-- retries while the queue is empty.
collect :: MonadSTM m => JobPool group m a -> STM m a
collect = readTQueue . completionQueue

-- | Cancel every job registered in the pool under the given group.
--
-- The set of jobs is read in one transaction.  The matching jobs are then
-- cancelled one after another with 'cancel', in ascending key order.
cancelGroup :: ( MonadAsync m
               , Eq group
               )
            => JobPool group m a -> group -> m ()
cancelGroup JobPool { jobsVar } group = do
    jobs <- atomically (readTVar jobsVar)
    mapM_ (\((group', _), thread) -> when (group' == group) (cancel thread))
          (Map.toList jobs)