{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module Control.Monad.Class.MonadFork
  ( MonadThread (..)
  , MonadFork (..)
  , labelThisThread
    -- * Deprecated API
  , fork
  , forkWithUnmask
  ) where

import qualified Control.Concurrent as IO
import           Control.Exception (AsyncException (ThreadKilled), Exception)
import           Control.Monad.Reader (ReaderT (..), lift)
import           Data.Kind (Type)
import qualified GHC.Conc.Sync as IO (labelThread)


-- | Monads that can identify and label threads.
--
-- Thread identifiers must be comparable ('Eq', 'Ord') and printable
-- ('Show').
class ( Monad m
      , Eq (ThreadId m)
      , Ord (ThreadId m)
      , Show (ThreadId m)
      ) => MonadThread m where

  -- | The type of thread identifiers used by the monad @m@.
  type ThreadId m :: Type

  -- | Identifier of the calling thread.
  myThreadId :: m (ThreadId m)

  -- | Attach a textual label to the given thread.
  labelThread :: ThreadId m -> String -> m ()


-- | Monads that can fork threads, throw exceptions to them and yield.
class MonadThread m => MonadFork m where

  -- | Run the given action in a new thread and return the new thread's
  -- identifier.
  forkIO           :: m () -> m (ThreadId m)

  -- | Like 'forkIO', but the child action is handed a function it can use
  -- to run sub-computations unmasked (for 'IO' this is
  -- 'Control.Concurrent.forkIOWithUnmask').
  forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)

  -- | Throw the exception @e@ to the target thread.
  throwTo          :: Exception e => ThreadId m -> e -> m ()

  -- | Terminate the target thread.
  --
  -- The default implementation throws 'ThreadKilled' to the thread via
  -- 'throwTo'.
  killThread       :: ThreadId m -> m ()
  killThread tid = throwTo tid ThreadKilled

  -- | Let other threads run.
  yield            :: m ()

-- | Deprecated alias for 'forkIO'.
fork :: MonadFork m => m () -> m (ThreadId m)
fork = forkIO
{-# DEPRECATED fork "use forkIO" #-}

-- | Deprecated alias for 'forkIOWithUnmask'.
forkWithUnmask :: MonadFork m => ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)
forkWithUnmask = forkIOWithUnmask
-- The replacement for this function is 'forkIOWithUnmask', not 'forkIO'.
{-# DEPRECATED forkWithUnmask "use forkIOWithUnmask" #-}


-- | 'IO' uses the runtime's own thread identifiers; both methods delegate
-- to "Control.Concurrent" / "GHC.Conc.Sync".
instance MonadThread IO where
  type ThreadId IO = IO.ThreadId
  myThreadId  = IO.myThreadId
  labelThread = IO.labelThread

-- | Every method delegates to its "Control.Concurrent" counterpart.
instance MonadFork IO where
  forkIO           = IO.forkIO
  forkIOWithUnmask = IO.forkIOWithUnmask
  throwTo          = IO.throwTo
  killThread       = IO.killThread
  yield            = IO.yield

-- | A 'ReaderT' layer reuses the thread identifiers of the underlying
-- monad; both operations are lifted unchanged.
instance MonadThread m => MonadThread (ReaderT r m) where
  type ThreadId (ReaderT r m) = ThreadId m
  myThreadId      = lift myThreadId
  labelThread t l = lift (labelThread t l)

-- | Forking under 'ReaderT' captures the current environment and runs the
-- child in the underlying monad with that same environment.
instance MonadFork m => MonadFork (ReaderT e m) where
  forkIO (ReaderT f) = ReaderT $ \env -> forkIO (f env)

  forkIOWithUnmask k = ReaderT $ \env ->
    -- The lambda is passed in parentheses (not via '$') so that 'restore'
    -- is checked directly against the rank-2 argument type.
    forkIOWithUnmask (\restore ->
      -- Lift the underlying unmask function to 'ReaderT' computations by
      -- running the wrapped reader function under 'restore'.
      let restore' :: ReaderT e m a -> ReaderT e m a
          restore' (ReaderT f) = ReaderT (restore . f)
      in runReaderT (k restore') env)

  throwTo tid ex = lift (throwTo tid ex)

  -- Delegate explicitly, like the other methods, so that a 'killThread'
  -- defined by the underlying monad is used instead of the class default.
  killThread = lift . killThread

  yield = lift yield

-- | Apply the label to the current thread.
labelThisThread :: MonadThread m => String -> m ()
labelThisThread label = myThreadId >>= \tid -> labelThread tid label