{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeOperators         #-}
module PlutusIR.Compiler.Types where

import PlutusIR qualified as PIR
import PlutusIR.Compiler.Provenance
import PlutusIR.Error

import Control.Monad.Except
import Control.Monad.Reader

import Control.Lens

import PlutusCore qualified as PLC
import PlutusCore.InlineUtils
import PlutusCore.MkPlc qualified as PLC
import PlutusCore.Pretty qualified as PLC
import PlutusCore.Quote
import PlutusCore.StdLib.Type qualified as Types
import PlutusCore.TypeCheck.Internal qualified as PLC

import Data.Text qualified as T

-- | Extra flag carried in the TypeCheckM Reader context. It signals whether the
-- PIR expression currently being typechecked sits at the top level, so that its
-- type can escape, or is nested, so that its type is not allowed to escape.
data AllowEscape
    = YesEscape -- ^ Top-level expression: its type can escape.
    | NoEscape  -- ^ Nested expression: its type is not allowed to escape.

-- | The PLC typecheck config extended with an 'AllowEscape' flag.
data PirTCConfig uni fun = PirTCConfig
    { _pirConfigTCConfig    :: PLC.TypeCheckConfig uni fun
    -- ^ The embedded PLC typecheck configuration.
    , _pirConfigAllowEscape :: AllowEscape
    -- ^ The escape flag that PIR typechecking adds on top of the PLC config.
    }

-- Template Haskell: derive one lens per underscore-prefixed field
-- ('pirConfigTCConfig', 'pirConfigAllowEscape').
makeLenses ''PirTCConfig

-- | A 'PirTCConfig' has a PLC typecheck config inside it, so it can act like
-- one: the class lens focuses on the embedded config.
instance PLC.HasTypeCheckConfig (PirTCConfig uni fun) uni fun where
    typeCheckConfig = pirConfigTCConfig

-- | Options for a compilation run. The values used by default are given by
-- 'defaultCompilationOpts'.
data CompilationOpts a = CompilationOpts
    { _coOptimize                 :: Bool
    -- ^ Read by 'runIfOpts' to decide whether a pass is run at all.
    , _coPedantic                 :: Bool
    , _coVerbose                  :: Bool
    , _coDebug                    :: Bool
    , _coMaxSimplifierIterations  :: Int
    -- Simplifier passes
    , _coDoSimplifierUnwrapCancel :: Bool
    , _coDoSimplifierBeta         :: Bool
    , _coDoSimplifierInline       :: Bool
    , _coInlineHints              :: InlineHints PLC.Name (Provenance a)
    , _coProfile                  :: Bool
    } deriving stock (Show)

-- Template Haskell: derive one lens per underscore-prefixed field
-- (e.g. 'coOptimize' for '_coOptimize').
makeLenses ''CompilationOpts

-- | The default compilation options: optimisation and all three simplifier
-- passes (unwrap-cancel, beta, inline) switched on, the simplifier iteration
-- limit set to 12, no inline hints ('mempty'), and the pedantic, verbose,
-- debug and profile flags switched off.
defaultCompilationOpts :: CompilationOpts a
defaultCompilationOpts = CompilationOpts
  { _coOptimize = True
  , _coPedantic = False
  , _coVerbose = False
  , _coDebug = False
  , _coMaxSimplifierIterations = 12
  , _coDoSimplifierUnwrapCancel = True
  , _coDoSimplifierBeta = True
  , _coDoSimplifierInline = True
  , _coInlineHints = mempty
  , _coProfile = False
  }

-- | The read-only context threaded through compilation (it is the environment
-- of the 'MonadReader' constraint used by the helpers in this module).
data CompilationCtx uni fun a = CompilationCtx
    { _ccOpts            :: CompilationOpts a
    -- ^ The compilation options in force.
    , _ccEnclosing       :: Provenance a
    -- ^ The enclosing provenance; read by 'getEnclosing' and locally modified
    -- by 'withEnclosing'.
    , _ccTypeCheckConfig :: Maybe (PirTCConfig uni fun)
    -- ^ Decide to either typecheck (passing a specific tcconfig) or not by
    -- passing 'Nothing'.
    }

-- Template Haskell: derive the lenses 'ccOpts', 'ccEnclosing' and
-- 'ccTypeCheckConfig'.
makeLenses ''CompilationCtx

-- | Build a 'CompilationCtx' from a PLC typecheck config, using
-- 'defaultCompilationOpts' and 'noProvenance' as the enclosing provenance.
-- Typechecking is enabled: the given config is wrapped in a 'PirTCConfig'
-- with 'YesEscape'.
toDefaultCompilationCtx :: PLC.TypeCheckConfig uni fun -> CompilationCtx uni fun a
toDefaultCompilationCtx configPlc =
    CompilationCtx defaultCompilationOpts noProvenance $ Just (PirTCConfig configPlc YesEscape)

-- | Read the enclosing provenance ('ccEnclosing') from the compilation context.
getEnclosing :: MonadReader (CompilationCtx uni fun a) m => m (Provenance a)
getEnclosing = view ccEnclosing

-- | Run an action with the enclosing provenance ('ccEnclosing') modified by
-- the given function. The change is scoped to that action via 'local'.
withEnclosing :: MonadReader (CompilationCtx uni fun a) m => (Provenance a -> Provenance a) -> m b -> m b
withEnclosing f = local (over ccEnclosing f)

-- | Conditionally run a pass: evaluate the monadic condition first; if it
-- yields 'True' apply the pass to the argument, otherwise return the argument
-- unchanged.
runIf
  :: MonadReader (CompilationCtx uni fun a) m
  => m Bool
  -> (b -> m b)
  -> (b -> m b)
runIf condition pass arg = do
  doPass <- condition
  if doPass then pass arg else pure arg

-- | Run a pass only when the 'coOptimize' option of the current compilation
-- context is set; otherwise the argument is returned unchanged (see 'runIf').
runIfOpts :: MonadReader (CompilationCtx uni fun a) m => (b -> m b) -> (b -> m b)
runIfOpts = runIf $ view (ccOpts . coOptimize)

-- | A PLC term over 'PLC.TyName' and 'PLC.Name', annotated with 'Provenance'.
type PLCTerm uni fun a =
    PLC.Term PLC.TyName PLC.Name uni fun (Provenance a)

-- | A PLC type over 'PLC.TyName', annotated with 'Provenance'.
type PLCType uni a =
    PLC.Type PLC.TyName uni (Provenance a)

-- | A type that may or may not be recursive.
data PLCRecType uni fun a
    -- | An ordinary type.
    = PlainType (PLCType uni a)
    -- | A recursive type, carried as a 'Types.RecursiveType' record.
    | RecursiveType (Types.RecursiveType uni fun (Provenance a))

-- | Get the actual type inside a 'PLCRecType': the wrapped type itself for a
-- 'PlainType', or the '_recursiveType' field for a 'RecursiveType'.
getType :: PLCRecType uni fun a -> PLCType uni a
getType r = case r of
    PlainType t                                                -> t
    RecursiveType Types.RecursiveType {Types._recursiveType=t} -> t

-- | Wrap a term appropriately for a possibly recursive type.
--
-- For a 'PlainType' the term is returned unchanged. For a 'RecursiveType' the
-- record's '_recursiveWrap' function is applied to the type arguments and the
-- term, and the result is passed through 'setProvenance' with the given
-- provenance.
wrap :: Provenance a -> PLCRecType uni fun a -> [PLCType uni a] -> PIRTerm uni fun a -> PIRTerm uni fun a
wrap p r tvs t = case r of
    PlainType _                                                      -> t
    RecursiveType Types.RecursiveType {Types._recursiveWrap=wrapper} -> setProvenance p $ wrapper tvs t

-- | Unwrap a term appropriately for a possibly recursive type.
--
-- For a 'PlainType' the term is returned unchanged. For a 'RecursiveType' the
-- term is placed under a 'PIR.Unwrap' node annotated with the given provenance.
unwrap :: Provenance a -> PLCRecType uni fun a -> PIRTerm uni fun a -> PIRTerm uni fun a
unwrap p r t = case r of
    PlainType _                          -> t
    RecursiveType Types.RecursiveType {} -> PIR.Unwrap p t

-- | A PIR term over 'PIR.TyName' and 'PIR.Name', annotated with 'Provenance'.
type PIRTerm uni fun a =
    PIR.Term PIR.TyName PIR.Name uni fun (Provenance a)

-- | A PIR type over 'PIR.TyName', annotated with 'Provenance'.
type PIRType uni a =
    PIR.Type PIR.TyName uni (Provenance a)

-- | The bundle of constraints shared by compilation code: a reader over
-- 'CompilationCtx', error reporting for the error type @e@, fresh-name
-- generation via 'MonadQuote', and the universe / builtin / annotation
-- instances needed for typechecking and pretty-printing.
type Compiling m e uni fun a =
    ( Monad m
    , MonadReader (CompilationCtx uni fun a) m
    -- Error reporting
    , AsTypeError e (PIR.Term PIR.TyName PIR.Name uni fun ()) uni fun (Provenance a)
    , AsTypeErrorExt e uni (Provenance a)
    , AsError e uni fun (Provenance a)
    , MonadError e m
    -- Fresh names
    , MonadQuote m
    , Ord a
    -- Typechecking
    , PLC.Typecheckable uni fun
    , PLC.GEq uni
    -- Pretty printing instances
    , PLC.Pretty fun
    , PLC.Closed uni
    , PLC.GShow uni
    , PLC.Everywhere uni PLC.PrettyConst
    , PLC.Pretty a
    )

-- | A 'PLC.Def' instantiated with a 'PLC.VarDecl' and a PIR term that share
-- the same names, universe, builtins and annotation.
type TermDef tyname name uni fun a =
    PLC.Def
        (PLC.VarDecl tyname name uni fun a)
        (PIR.Term tyname name uni fun a)

-- | We generate some shared definitions during compilation; this datatype
-- defines the "keys" for those definitions.
--
-- The derived 'Ord' instance orders every 'FixpointCombinator' before 'FixBy',
-- and 'FixpointCombinator's among themselves by their 'Integer'.
data SharedName
    = FixpointCombinator Integer
    | FixBy
    deriving stock (Show, Eq, Ord)

-- | Produce a fresh 'PLC.Name' for a shared definition. The text passed to
-- 'freshName' is @"fix"@ followed by the shown integer for a
-- 'FixpointCombinator', and @"fixBy"@ for 'FixBy'.
toProgramName :: SharedName -> Quote PLC.Name
toProgramName (FixpointCombinator n) = freshName ("fix" <> T.pack (show n))
toProgramName FixBy                  = freshName "fixBy"