-- | An internal module that defines functions for deciding equality of values of data types
-- that encode things with binders.

{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE UndecidableInstances  #-}

module PlutusCore.Eq
    ( LR (..)
    , RL (..)
    , EqRename
    , ScopedEqRename
    , runEqRename
    , withTwinBindings
    , eqNameM
    , eqM
    ) where

import PlutusPrelude

import PlutusCore.Name
import PlutusCore.Rename.Monad

import Control.Lens

{- Note [Modulo alpha]
The implemented algorithm for checking equality modulo alpha works as follows
(taking types as an example):

- traverse both types simultaneously
- if the outermost constructors differ, return 'False'
- otherwise if the constructors are binders, then record that the two possibly distinct bound
    variables map to each other
- otherwise if the constructors are variables, look them up in the current scope
    * if both are in scope, then those are bound variables, so check that they map to each other
        (i.e. are introduced by twin binders)
    * if neither is in scope, then those are free variables, so check that their uniques are equal
    * if one is in scope and the other one is not, then return 'False'
- otherwise check equality of non-recursive constituents and recurse for recursive occurrences
-}

{- Note [Scope tracking]
Scopes (term level vs type level) are resolved automatically in a type-driven way, much like we do
in the renamer. This means we do not have to think about scopes when implementing the actual
equality checks, and it makes it impossible to confuse the scopes and e.g. insert a type-level name
into a term-level scope. It also allows us to define a single function that records bindings and a
single function that checks equality of two names, each of which works for both scopes.
-}

{- Note [Side tracking]
The simplest way to track that twin variables map to each other is to have two contexts:

- left-hand side variables and what they map to
- right-hand side variables and what they map to

(we refer to the first argument as being on the left-hand side and to the second argument as being on the right-hand side)

E.g. when checking equality of these two types:

    all (x_3 :: *) (x_3 :: *). x_3 -> x_3
    all (y_4 :: *) (z_5 :: *). y_4 -> z_5

we first record that @x_3@ maps to @y_4@ and vice versa, then record that @x_3@ maps to @z_5@ and
vice versa. This way when we later check equality of @x_3@ and @y_4@ we know that it doesn't hold,
because even though @y_4@ maps to @x_3@, @x_3@ doesn't map to @y_4@, because that mapping was
overwritten by the @x_3@-to-@z_5@ one.

For storing the left-to-right and right-to-left mappings separately we use the 'Bilateral' data
type. Given that we track not only sides, but also scopes, we instantiate 'Bilateral' with either
type-level-only renamings (for checking equality of types) or scoped ones (for checking equality
of terms and programs). This amounts to the following generic monad:

    RenameT (Bilateral ren) m

i.e. regardless of what the underlying renaming is, it has to be bilateral.

We zoom into the sides of a bilateral renaming with the 'LR' and 'RL' newtype wrappers, via the
same 'HasRenaming' machinery that we use for zooming into the scopes of a scoped renaming:

- the 'LR' wrapper allows retrieving a lens focusing on the left  renaming
- the 'RL' wrapper allows retrieving a lens focusing on the right renaming

I.e. you wrap a name in either 'LR' or 'RL', and depending on that you get focused on either the
left-to-right or the right-to-left part of a 'Bilateral' renaming.

So e.g.

    withRenamedName (LR name1) (LR name2)

reads as "record that @name1@ maps to @name2@ from left to right"

and

    lookupNameM $ RL name2

reads as "look up the right-to-left mapping of @name2@".

I.e. we first resolve sides using explicit wrappers and then scope resolution happens automatically
on the basis of existing type information (e.g. 'TyName' is a type-level name, hence we need the
type-level renaming).
-}

-- | Marks a value as belonging to the left-to-right direction.
--
-- Wrapping a name in 'LR' makes the 'HasRenaming' machinery focus on the left component of a
-- 'Bilateral' renaming. See Note [Side tracking].
newtype LR a = LR { unLR :: a }
    deriving stock (Generic)

-- | Marks a value as belonging to the right-to-left direction.
--
-- Wrapping a name in 'RL' makes the 'HasRenaming' machinery focus on the right component of a
-- 'Bilateral' renaming. See Note [Side tracking].
newtype RL a = RL { unRL :: a }
    deriving stock (Generic)

-- | Two values of the same type, one per side of an equality check: the left-hand side one and
-- the right-hand side one. See Note [Side tracking].
data Bilateral a = Bilateral
    { _bilateralL :: a  -- ^ The component for the left-hand side.
    , _bilateralR :: a  -- ^ The component for the right-hand side.
    }

-- Generates the 'bilateralL' and 'bilateralR' lenses used by the 'HasRenaming' instances.
makeLenses ''Bilateral

-- The instance bodies are empty, so every method comes from the class's default implementation.
instance Wrapped (LR a)
instance HasUnique name unique => HasUnique (LR name) (LR unique)

instance Wrapped (RL a)
instance HasUnique name unique => HasUnique (RL name) (RL unique)

-- | Combines two bilateral values side by side: left with left, right with right.
instance Semigroup a => Semigroup (Bilateral a) where
    Bilateral left1 right1 <> Bilateral left2 right2 =
        Bilateral left right
      where
        left  = left1  <> left2
        right = right1 <> right2

-- | The neutral bilateral value has the neutral element on both sides.
instance Monoid a => Monoid (Bilateral a) where
    mempty = Bilateral neutral neutral
      where
        neutral = mempty

-- Renaming from left to right means updating the left renaming. The lens first focuses on the
-- left component of the 'Bilateral', then selects its @unique@-indexed renaming, and finally
-- coerces between @Renaming unique@ and @Renaming (LR unique)@.
instance HasRenaming ren unique => HasRenaming (Bilateral ren) (LR unique) where
    renaming = bilateralL . renaming @ren @unique . coerced

-- Renaming from right to left means updating the right renaming. The lens first focuses on the
-- right component of the 'Bilateral', then selects its @unique@-indexed renaming, and finally
-- coerces between @Renaming unique@ and @Renaming (RL unique)@.
instance HasRenaming ren unique => HasRenaming (Bilateral ren) (RL unique) where
    renaming = bilateralR . renaming @ren @unique . coerced

-- | A runnable equality check over a bilateral renaming.
--
-- The result is @Maybe ()@ rather than 'Bool'. The two are isomorphic, but 'Maybe' gives us
-- readable do-notation and automatic short-circuiting on the first mismatch. Both would be
-- tedious to get with @Rename (Bilateral ren) Bool@.
type EqRename ren = RenameT (Bilateral ren) Maybe ()

-- | An 'EqRename' over a 'ScopedRenaming', as used for checking equality of terms and programs
-- (see Note [Side tracking]).
type ScopedEqRename = EqRename ScopedRenaming

-- | Run an 'EqRename' computation via 'runRenameT'.
--
-- Returns 'True' when the check succeeded (produced 'Just') and 'False' when it short-circuited
-- with 'Nothing'.
runEqRename :: Monoid ren => EqRename ren -> Bool
runEqRename check = case runRenameT check of
    Just _  -> True
    Nothing -> False

-- | Record that two names map to each other, then run the given computation with both mappings
-- recorded. See Note [Modulo alpha].
--
-- The first name is mapped to the second one from left to right ('LR'), and the second name is
-- mapped to the first one from right to left ('RL').
withTwinBindings
    :: (HasRenaming ren unique, HasUnique name unique, Monad m)
    => name -> name -> RenameT (Bilateral ren) m c -> RenameT (Bilateral ren) m c
withTwinBindings nameL nameR =
    withRenamedName (LR nameL) (LR nameR) . withRenamedName (RL nameR) (RL nameL)

-- | Check that two names are equal modulo alpha. See Note [Modulo alpha].
--
-- * If neither name has a recorded mapping, both are free and their uniques must be equal.
-- * If both have recorded mappings, each must map to the other: the left-to-right mapping of the
--   first name must be the second name's unique, and the right-to-left mapping of the second name
--   must be the first name's unique.
-- * If exactly one of them has a mapping, the check fails.
eqNameM
    :: (HasRenaming ren unique, HasUnique name unique, Eq unique)
    => name -> name -> EqRename ren
eqNameM name1 name2 = do
    let uniqL = name1 ^. unique
        uniqR = name2 ^. unique
    -- What the left name maps to (left to right) and what the right name maps to (right to left).
    mayTargetL <- lookupNameM (LR name1)
    mayTargetR <- lookupNameM (RL name2)
    guard $ case (mayTargetL, mayTargetR) of
        (Nothing           , Nothing           ) -> uniqL == uniqR
        (Just (LR targetL), Just (RL targetR)) -> uniqL == targetR && uniqR == targetL
        _                                        -> False

-- | Check equality of two values via their 'Eq' instance. The check fails (short-circuits) when
-- they differ.
eqM :: Eq a => a -> a -> EqRename ren
eqM x y = guard (x == y)