{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}

module Control.Monad.Freer.Extras.Modify (
    -- * change the list of effects
    mapEffs

    -- * under functions
    , UnderN(..)
    , under

    -- * weaken functions
    , CanWeakenEnd(..)
    , weakenUnder
    , weakenNUnder
    , weakenMUnderN

    -- * raise functions
    , raiseEnd
    , raiseUnder
    , raiseUnder2
    , raise2Under
    , raiseNUnder
    , raiseMUnderN

    -- * zoom functions
    , handleZoomedState
    , handleZoomedError
    , handleZoomedWriter
    , handleZoomedReader

    -- * manipulation
    , writeIntoState
    , stateToMonadState
    , monadStateToState
    , errorToMonadError
    , wrapError
    ) where

import Control.Lens hiding (under)
import Control.Monad.Except qualified as MTL
import Control.Monad.Freer
import Control.Monad.Freer.Error
import Control.Monad.Freer.Internal
import Control.Monad.Freer.Reader
import Control.Monad.Freer.State
import Control.Monad.Freer.Writer
import Control.Monad.State qualified as MTL

-- | Rewrite every effect request of a computation with a natural
-- transformation on the open union, moving it to a different effect list.
--
-- Pure results ('Val') are returned unchanged. For each request ('E') the
-- union value is mapped with @f@, and the continuation queue is composed with
-- the same rewriting, so that every later request is transformed as well.
mapEffs :: forall effs effs'. (Union effs ~> Union effs') -> Eff effs ~> Eff effs'
mapEffs f = loop where
    loop :: Eff effs ~> Eff effs'
    loop = \case
        Val a -> pure a
        -- Re-wrap the continuation so whatever it produces is also mapped.
        E u q -> E (f u) $ tsingleton $ qComp q loop


-- | Lift a transformation of the tail of an effect list to the whole list,
-- leaving requests for the head effect @a@ untouched.
--
-- Requests for the tail are mapped with @f@ and then weakened back under
-- @a@. Requests for @a@ itself are re-injected unchanged.
under :: (Union effs ~> Union effs') -> Union (a ': effs) ~> Union (a ': effs')
under f u = case decomp u of
    Left u' -> weaken (f u')
    Right t -> inj t

-- | Lift a transformation of an effect list under a fixed prefix @as@ of
-- effects, by applying 'under' once per element of @as@.
--
-- @as@ only appears under the type family ':++:', so it is passed with a
-- type application (e.g. @underN \@as f@).
class UnderN as where
    underN :: (Union effs ~> Union effs') -> Union (as :++: effs) ~> Union (as :++: effs')

-- | Empty prefix: apply the transformation directly.
instance UnderN '[] where
    underN f = f

-- | Non-empty prefix: lift under the rest of the prefix, then under @a@.
instance UnderN as => UnderN (a ': as) where
    underN f = under (underN @as f)


{- Note [Various raising helpers]
These all help with the situation where you have something of type

Eff effs a

where effs is some *fixed* list of effects. You may then need to insert
more effects *under* effs, so that the effects in effs can be interpreted
in terms of the new ones. Inserting effects at the *end* of the list turns
out to be tricky.

The author is not fully confident in these: they are partly adapted from
freer-simple/polysemy, with a fair amount of ad-hoc modification.

The first instance of CanWeakenEnd handles the case where the fixed list has
length 1. The second instance handles fixed lists of length 2 or more; the
double cons in its head keeps it from overlapping with the first instance.
-}
-- | Embed a union over the fixed list @as@ into @effs@, where @as@ is a
-- prefix of @effs@. See Note [Various raising helpers].
class CanWeakenEnd as effs where
    weakenEnd :: Union as ~> Union effs

-- | Single-effect list: extract the sole effect and inject it into @effs@,
-- whose head is that same effect.
instance effs ~ (a ': effs') => CanWeakenEnd '[a] effs where
    weakenEnd u = inj (extract u)

-- | Two or more effects: keep the head @a@ fixed and recurse on the tail.
instance (effs ~ (a ': effs'), CanWeakenEnd (b ': as) effs') => CanWeakenEnd (a ': b ': as) effs where
    weakenEnd = under weakenEnd

-- | Insert a single effect @b@ directly under the head effect @a@.
weakenUnder :: forall effs a b . Union (a ': effs) ~> Union (a ': b ': effs)
weakenUnder = under weaken

-- | Insert the effects @effs'@ directly under the head effect @a@.
weakenNUnder :: forall effs' effs a . Weakens effs' => Union (a ': effs) ~> Union (a ': (effs' :++: effs))
weakenNUnder = under (weakens @effs' @effs)

-- | Insert the effects @effs'@ under the prefix @as@.
--
-- This applies 'under' n times to 'weakens', which weakens m times, where
-- n = length @as@ and m = length @effs'@.
weakenMUnderN :: forall effs' as effs . (UnderN as, Weakens effs') => Union (as :++: effs) ~> Union (as :++: (effs' :++: effs))
weakenMUnderN = underN @as (weakens @effs' @effs)


-- | Run a computation over the fixed effect list @as@ in an effect list
-- @effs@ that has @as@ as a prefix (see 'CanWeakenEnd').
raiseEnd :: forall effs as. CanWeakenEnd as effs => Eff as ~> Eff effs
raiseEnd = mapEffs weakenEnd

-- | Insert a single effect @b@ directly under the head effect of a
-- computation.
raiseUnder :: forall effs a b . Eff (a ': effs) ~> Eff (a ': b ': effs)
raiseUnder = mapEffs weakenUnder

-- | Insert a single effect @c@ under the top two effects @a@ and @b@ of a
-- computation.
raiseUnder2 :: forall effs a b c . Eff (a ': b ': effs) ~> Eff (a ': b ': c ': effs)
raiseUnder2 = mapEffs (under (under weaken))

-- | Insert two effects @b@ and @c@ directly under the head effect @a@ of a
-- computation.
raise2Under :: forall effs a b c . Eff (a ': effs) ~> Eff (a ': b ': c ': effs)
raise2Under = mapEffs (under (weaken . weaken))

-- | Insert the effects @effs'@ directly under the head effect @a@ of a
-- computation.
raiseNUnder :: forall effs' effs a . Weakens effs' => Eff (a ': effs) ~> Eff (a ': (effs' :++: effs))
raiseNUnder = mapEffs (weakenNUnder @effs' @effs @a)

-- | Raise m effects under the top n effects.
--
-- Inserts @effs'@ (m = its length) under the prefix @as@ (n = its length).
raiseMUnderN :: forall effs' as effs . (UnderN as, Weakens effs') => Eff (as :++: effs) ~> Eff (as :++: (effs' :++: effs))
raiseMUnderN = mapEffs (weakenMUnderN @effs' @as @effs)


-- | Handle a 'State' effect in terms of a "larger" 'State' effect from which we have a lens.
--
-- 'Get' reads the larger state and views it through the lens. 'Put' replaces
-- only the part of the larger state that the lens focuses on.
handleZoomedState :: Member (State s2) effs => Lens' s2 s1 -> (State s1 ~> Eff effs)
handleZoomedState l = \case
    Get   -> view l <$> get
    Put v -> modify (set l v)

-- | Handle a 'Writer' effect in terms of a "larger" 'Writer' effect from which we have a review.
--
-- Each told value is converted with 'review' and told to the larger 'Writer'.
handleZoomedWriter :: Member (Writer s2) effs => AReview s2 s1 -> (Writer s1 ~> Eff effs)
handleZoomedWriter p = \case
    Tell w -> tell (review p w)

-- | Handle an 'Error' effect in terms of a "larger" 'Error' effect from which we have a review.
--
-- Each thrown error is converted with 'review' and rethrown in the larger
-- 'Error' effect.
handleZoomedError :: Member (Error s2) effs => AReview s2 s1 -> (Error s1 ~> Eff effs)
handleZoomedError p = \case
    Error e -> throwError (review p e)

-- | Handle a 'Reader' effect in terms of a "larger" 'Reader' effect from which we have a getter.
--
-- 'Ask' reads the larger environment and views it through the getter.
handleZoomedReader :: Member (Reader s2) effs => Getter s2 s1 -> (Reader s1 ~> Eff effs)
handleZoomedReader g = \case
    Ask -> view g <$> ask

-- | Handle a 'Writer' effect in terms of a "larger" 'State' effect from which we have a setter.
--
-- Each told value @w@ is appended with '<>~' to whatever the setter focuses
-- on in the larger state.
writeIntoState
    :: (Monoid s1, Member (State s2) effs)
    => Setter' s2 s1
    -> (Writer s1 ~> Eff effs)
writeIntoState s = \case
    Tell w -> modify (s <>~ w)

-- | Handle a 'State' effect in terms of a monadic effect which has a 'MTL.MonadState' instance.
--
-- 'Get' maps to 'MTL.get' and 'Put' maps to 'MTL.put'.
stateToMonadState
    :: (MTL.MonadState s m)
    => (State s ~> m)
stateToMonadState = \case
    Get   -> MTL.get
    Put v -> MTL.put v

-- | Run a pure mtl 'MTL.State' computation against a freer 'State' effect.
--
-- The current state is read with 'get', and the computation is run from it
-- with 'MTL.runState'. The final state is written back with 'put', and the
-- computation's result is returned.
monadStateToState
    :: (Member (State s) effs)
    => MTL.State s a
    -> Eff effs a
monadStateToState action = do
    initial <- get
    let (result, final) = MTL.runState action initial
    put final
    pure result

-- | Handle an 'Error' effect in terms of a monadic effect which has a 'MTL.MonadError' instance.
--
-- Each thrown error is rethrown unchanged with 'MTL.throwError'.
errorToMonadError
    :: (MTL.MonadError e m)
    => (Error e ~> m)
errorToMonadError = \case
    Error e -> MTL.throwError e

-- | Transform an error type.
--
-- Handles the outermost 'Error' @e@ effect. Each thrown @e@ is converted to
-- an @f@ with the given function and rethrown as an 'Error' @f@ in @effs@.
wrapError
    :: forall e f effs. Member (Error f) effs
    => (e -> f)
    -> Eff (Error e ': effs)
    ~> Eff effs
wrapError toOuter action = handleError action (throwError @f . toOuter)