-- | Kind/type inference/checking.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies    #-}

module PlutusCore.TypeCheck
    ( ToKind
    , Typecheckable
    -- * Configuration.
    , BuiltinTypes (..)
    , TypeCheckConfig (..)
    , tccBuiltinTypes
    , builtinMeaningsToTypes
    , getDefTypeCheckConfig
    -- * Kind/type inference/checking.
    , inferKind
    , checkKind
    , inferType
    , checkType
    , inferTypeOfProgram
    , checkTypeOfProgram
    ) where

import PlutusPrelude

import PlutusCore.Builtin
import PlutusCore.Core
import PlutusCore.Error
import PlutusCore.Name
import PlutusCore.Normalize
import PlutusCore.Quote
import PlutusCore.Rename
import PlutusCore.TypeCheck.Internal

import Control.Monad.Except
import Data.Array
import Universe

-- | Bundle of class constraints demanded by 'builtinMeaningsToTypes' and
-- 'getDefTypeCheckConfig': kinds for the universe ('ToKind'), type
-- application support for normalization ('HasUniApply'), and the meanings of
-- the built-in functions ('ToBuiltinMeaning').
type Typecheckable uni fun =
    ( ToKind uni
    , HasUniApply uni
    , ToBuiltinMeaning uni fun
    )

-- | Compute the normalized type of every built-in function and collect the
-- results into 'BuiltinTypes'.
--
-- For each built-in (enumerated in index order via 'tabulateArray'), the type
-- returned by 'typeOfBuiltinFunction' is first kind-checked and then normalized.
-- Kind-checking uses a config with no built-in types (@BuiltinTypes Nothing@).
-- The @ann@ argument is attached to every node of the type before
-- kind-checking, so any kind error that is thrown carries that annotation.
-- The whole computation runs in a fresh 'QuoteT' via 'runQuoteT'.
builtinMeaningsToTypes
    :: (MonadError err m, AsTypeError err term uni fun ann, Typecheckable uni fun)
    => ann -> m (BuiltinTypes uni fun)
builtinMeaningsToTypes ann = runQuoteT $ do
    let emptyConfig = TypeCheckConfig (BuiltinTypes Nothing)
        normalizedTypeOf builtin = do
            let rawTy = typeOfBuiltinFunction builtin
            -- Only the side effect (throwing on an ill-kinded type) matters;
            -- the inferred kind itself is discarded.
            _ <- inferKind emptyConfig (ann <$ rawTy)
            normTy <- normalizeType rawTy
            pure (pure normTy)
    perBuiltin <- sequence (tabulateArray normalizedTypeOf)
    pure (BuiltinTypes (Just perBuiltin))

-- | The default type checking config: its built-in types are exactly the
-- ones produced by 'builtinMeaningsToTypes' for the given annotation.
-- Any error thrown by 'builtinMeaningsToTypes' propagates unchanged.
getDefTypeCheckConfig
    :: (MonadError err m, AsTypeError err term uni fun ann, Typecheckable uni fun)
    => ann -> m (TypeCheckConfig uni fun)
getDefTypeCheckConfig = fmap TypeCheckConfig . builtinMeaningsToTypes

-- | Infer the kind of a type by running 'inferKindM' under the given config.
inferKind
    :: (MonadQuote m, MonadError err m, AsTypeError err term uni fun ann, ToKind uni)
    => TypeCheckConfig uni fun -> Type TyName uni ann -> m (Kind ())
inferKind config ty = runTypeCheckM config (inferKindM ty)

-- | Check that a type has the expected kind.
-- The kind of the type is inferred and compared with the expected one. On a
-- mismatch, a 'TypeError' annotated with the @ann@ argument is thrown.
-- Delegates to 'checkKindM', run under the given config.
checkKind
    :: (MonadQuote m, MonadError err m, AsTypeError err term uni fun ann, ToKind uni)
    => TypeCheckConfig uni fun -> ann -> Type TyName uni ann -> Kind () -> m ()
checkKind config ann ty expectedKind =
    runTypeCheckM config (checkKindM ann ty expectedKind)

-- | Infer the (normalized) type of a term.
-- The term is passed through 'rename' first, and the renamed term is then
-- given to 'inferTypeM', run under the given config.
inferType
    :: ( MonadError err m, MonadQuote m
       , AsTypeError err (Term TyName Name uni fun ()) uni fun ann, ToKind uni, HasUniApply uni
       , GEq uni, Ix fun
       )
    => TypeCheckConfig uni fun -> Term TyName Name uni fun ann -> m (Normalized (Type TyName uni ()))
inferType config term = do
    renamed <- rename term
    runTypeCheckM config (inferTypeM renamed)

-- | Check that a term has the expected (normalized) type.
-- The type of the term is inferred and compared with the expected one. On a
-- mismatch, a 'TypeError' annotated with the @ann@ argument is thrown.
-- As in 'inferType', the term is passed through 'rename' before 'checkTypeM'
-- runs under the given config.
checkType
    :: ( MonadError err m, MonadQuote m
       , AsTypeError err (Term TyName Name uni fun ()) uni fun ann, ToKind uni, HasUniApply uni
       , GEq uni, Ix fun
       )
    => TypeCheckConfig uni fun
    -> ann
    -> Term TyName Name uni fun ann
    -> Normalized (Type TyName uni ())
    -> m ()
checkType config ann term ty =
    rename term >>= \renamed ->
        runTypeCheckM config $ checkTypeM ann renamed ty

-- | Infer the (normalized) type of a program.
-- Only the program's body term is examined; the program's annotation and
-- version fields are ignored. Delegates to 'inferType'.
inferTypeOfProgram
    :: ( MonadError err m, MonadQuote m
       , AsTypeError err (Term TyName Name uni fun ()) uni fun ann, ToKind uni, HasUniApply uni
       , GEq uni, Ix fun
       )
    => TypeCheckConfig uni fun
    -> Program TyName Name uni fun ann
    -> m (Normalized (Type TyName uni ()))
inferTypeOfProgram config program =
    case program of
        Program _ _ body -> inferType config body

-- | Check that a program has the expected (normalized) type.
-- The type of the program's body term is inferred and compared with the
-- expected one. On a mismatch, a 'TypeError' annotated with the @ann@ argument
-- is thrown. The program's own annotation and version fields are ignored.
-- Delegates to 'checkType'.
checkTypeOfProgram
    :: ( MonadError err m, MonadQuote m
       , AsTypeError err (Term TyName Name uni fun ()) uni fun ann, ToKind uni, HasUniApply uni
       , GEq uni, Ix fun
       )
    => TypeCheckConfig uni fun
    -> ann
    -> Program TyName Name uni fun ann
    -> Normalized (Type TyName uni ())
    -> m ()
checkTypeOfProgram config ann program =
    case program of
        Program _ _ body -> checkType config ann body