{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
module Data.GADT.Compare.TH
( DeriveGEQ(..)
, DeriveGCompare(..)
, GComparing, runGComparing, geq', compare'
) where
import Control.Applicative
import Control.Monad
import Data.Dependent.Sum
import Data.Dependent.Sum.TH.Internal
import Data.Functor.Identity
import Data.GADT.Compare
import Data.Traversable (for)
import Data.Type.Equality ((:~:) (..))
import Language.Haskell.TH
import Language.Haskell.TH.Extras
-- | Overloading class for 'deriveGEq', so one splice entry point accepts a
-- type constructor 'Name', a 'Dec', a single-element list of either, or a
-- 'Q' action producing one.
class DeriveGEQ t where
  -- | Produce the declarations for a 'GEq' instance from the given input.
  deriveGEq :: t -> Q [Dec]
-- | Look the name up with 'reify'. Only a plain type constructor ('TyConI')
-- is accepted; its declaration is handed on to the 'Dec' instance.
instance DeriveGEQ Name where
  deriveGEq typeName = do
    typeInfo <- reify typeName
    case typeInfo of
      TyConI dec -> deriveGEq dec
      _          -> fail "deriveGEq: the name of a type constructor is required"
-- | Build the instance from a declaration via 'deriveForDec'. 'geqFunction'
-- supplies the body of the 'geq' method.
instance DeriveGEQ Dec where
  deriveGEq = deriveForDec ''GEq (\t -> [t| GEq $t |]) geqFunction
-- | Accept a list containing exactly one item. This lets the @[Dec]@ produced
-- by a declaration quote (through the 'Q' instance) be passed in directly.
-- Any other length fails the splice.
instance DeriveGEQ t => DeriveGEQ [t] where
  deriveGEq [it] = deriveGEq it
  deriveGEq _    = fail "deriveGEq: [] instance only applies to single-element lists"
-- | Run the 'Q' action, then derive from whatever it produced.
instance DeriveGEQ t => DeriveGEQ (Q t) where
  deriveGEq = (>>= deriveGEq)
-- | Generate the 'geq' method. Each constructor gets one equation (see
-- 'geqClause'). A final @_ _ -> Nothing@ equation handles mismatched
-- constructors; it is omitted when there is exactly one constructor, because
-- the constructor's own equation already covers every case.
--
-- No type signature on purpose: the binder type is left to inference, so this
-- compiles against whichever binder type 'deriveForDec' supplies
-- (@TyVarBndr@ gained a flag parameter in template-haskell 2.17).
geqFunction bndrs cons =
  funD 'geq (map (geqClause bndrs) cons ++ mismatch)
  where
    mismatch =
      [ clause [wildP, wildP] (normalB [| Nothing |]) []
      | length cons /= 1
      ]
-- | Generate the 'geq' equation for two values built with constructor @con@.
--
-- The body is a @do@ block in 'Maybe' that compares corresponding fields
-- left to right:
--
-- * When a field's type mentions a type variable bound by the declaration
--   (@bndrs@) or by the constructor itself, the fields are compared with
--   'geq' and the resulting 'Refl' is pattern-bound. That lets the type
--   checker learn the indices agree.
-- * Otherwise the fields are compared with '==' under 'guard'.
--
-- If every comparison succeeds, the block ends in @return Refl@.
geqClause bndrs con = do
  lArgNames <- replicateM nArgs (newName "x")
  rArgNames <- replicateM nArgs (newName "y")
  let compareField (lArg, rArg, argType)
        | needsGEq argType =
            bindS (conP 'Refl []) [| geq $(varE lArg) $(varE rArg) |]
        | otherwise =
            noBindS [| guard ($(varE lArg) == $(varE rArg)) |]
      stmts = map compareField (zip3 lArgNames rArgNames argTypes)
              ++ [noBindS [| return Refl |]]
  clause [ conP conName (map varP lArgNames)
         , conP conName (map varP rArgNames)
         ]
         (normalB (doE stmts))
         []
  where
    conName  = nameOfCon con
    argTypes = argTypesOfCon con
    nArgs    = length argTypes
    -- A field needs 'geq' (not '==') when its type mentions any type
    -- variable in scope for this constructor.
    needsGEq argType =
      any ((`occursInType` argType) . nameOfBinder) (bndrs ++ varsBoundInCon con)
-- | Short-circuiting comparison monad used by the derived 'gcompare' code.
-- A @Left@ holds an ordering that has already been decided and stops the
-- comparison. A @Right@ means "equal so far" and carries the value to
-- continue with.
newtype GComparing l r v
  = GComparing (Either (GOrdering l r) v)
-- | Map over the "equal so far" payload. A decided ordering passes through
-- unchanged.
instance Functor (GComparing a b) where
  fmap f (GComparing x) = GComparing (fmap f x)
-- | Sequencing stops at the first @Left@ (a decided ordering). A @Right@
-- passes its value to the continuation.
instance Monad (GComparing a b) where
  return = pure
  GComparing (Left x)  >>= _ = GComparing (Left x)
  GComparing (Right x) >>= f = f x
-- | 'pure' means "equal so far". '<*>' is defined through the 'Monad'
-- instance, so it short-circuits in the same way.
instance Applicative (GComparing a b) where
  pure = GComparing . Right
  (<*>) = ap
-- | Compare two values of a 'GCompare' type inside 'GComparing'. On 'GEQ' the
-- comparison continues with the proof @a :~: b@. On 'GLT' or 'GGT' it stops
-- with that ordering.
--
-- The ordering is rebuilt rather than reused, because the stopping result
-- is indexed by @x y@ and not by @a b@.
geq' :: GCompare t => t a -> t b -> GComparing x y (a :~: b)
geq' x y = GComparing $ case gcompare x y of
  GLT -> Left GLT
  GEQ -> Right Refl
  GGT -> Left GGT
-- | Compare two 'Ord' values inside 'GComparing'. On 'EQ' the comparison
-- continues with @()@. On 'LT' or 'GT' it stops with 'GLT' or 'GGT'.
compare' :: Ord t => t -> t -> GComparing a b ()
compare' x y = GComparing $ case compare x y of
  LT -> Left GLT
  EQ -> Right ()
  GT -> Left GGT
-- | Collapse a finished comparison. The result is either the ordering that
-- stopped it early or the final ordering it produced.
runGComparing :: GComparing a b (GOrdering a b) -> GOrdering a b
runGComparing (GComparing x) = either id id x
-- | Overloading class for 'deriveGCompare', so one splice entry point accepts
-- a type constructor 'Name', a 'Dec', a single-element list of either, or a
-- 'Q' action producing one.
class DeriveGCompare t where
  -- | Produce the declarations for a 'GCompare' instance from the given input.
  deriveGCompare :: t -> Q [Dec]
-- | Look the name up with 'reify'. Only a plain type constructor ('TyConI')
-- is accepted; its declaration is handed on to the 'Dec' instance.
instance DeriveGCompare Name where
  deriveGCompare typeName = do
    typeInfo <- reify typeName
    case typeInfo of
      TyConI dec -> deriveGCompare dec
      _          -> fail "deriveGCompare: the name of a type constructor is required"
-- | Build the instance from a declaration via 'deriveForDec'.
-- 'gcompareFunction' supplies the body of the 'gcompare' method.
instance DeriveGCompare Dec where
  deriveGCompare = deriveForDec ''GCompare (\t -> [t| GCompare $t |]) gcompareFunction
-- | Accept a list containing exactly one item. This lets the @[Dec]@ produced
-- by a declaration quote (through the 'Q' instance) be passed in directly.
-- Any other length fails the splice.
instance DeriveGCompare t => DeriveGCompare [t] where
  deriveGCompare [it] = deriveGCompare it
  deriveGCompare _    = fail "deriveGCompare: [] instance only applies to single-element lists"
-- | Run the 'Q' action, then derive from whatever it produced.
instance DeriveGCompare t => DeriveGCompare (Q t) where
  deriveGCompare = (>>= deriveGCompare)
-- | Generate the 'gcompare' method.
--
-- For a declaration with no constructors, the single equation forces both
-- arguments and then evaluates 'undefined'. Otherwise each constructor adds
-- three equations, in declaration order (see @gcompareClauses@). As a result,
-- a value built with an earlier constructor compares 'GLT' against one built
-- with a later constructor.
--
-- No type signature on purpose: the binder type is left to inference, so this
-- compiles against whichever binder type 'deriveForDec' supplies
-- (@TyVarBndr@ gained a flag parameter in template-haskell 2.17).
gcompareFunction boundVars cons
  | null cons =
      funD 'gcompare [clause [] (normalB [| \x y -> seq x (seq y undefined) |]) []]
  | otherwise =
      funD 'gcompare (concatMap gcompareClauses cons)
  where
    -- Equations for one constructor. First comes the main clause, where both
    -- sides use this constructor. Then two catch-alls: this constructor on
    -- the left is 'GLT', and on the right is 'GGT', against any constructor
    -- not matched by an earlier equation.
    gcompareClauses con =
      [ mainClause con
      , clause [recP conName [], wildP] (normalB [| GLT |]) []
      , clause [wildP, recP conName []] (normalB [| GGT |]) []
      ]
      where
        conName = nameOfCon con

    -- A field needs @geq'@ (rather than @compare'@) when its type mentions a
    -- type variable bound by the declaration or by the constructor itself.
    needsGCompare argType con =
      any ((`occursInType` argType) . nameOfBinder) (boundVars ++ varsBoundInCon con)

    -- Main clause: inside the 'GComparing' monad, compare the fields left to
    -- right. Index-carrying fields use @geq'@ and bind 'Refl' so the type
    -- checker can unify the indices. All other fields use @compare'@. The
    -- first non-equal field decides the result; if none differs, the result
    -- is 'GEQ'.
    mainClause con = do
      lArgNames <- replicateM nArgs (newName "x")
      rArgNames <- replicateM nArgs (newName "y")
      let compareField (lArg, rArg, argType)
            | needsGCompare argType con =
                bindS (conP 'Refl []) [| geq' $(varE lArg) $(varE rArg) |]
            | otherwise =
                noBindS [| compare' $(varE lArg) $(varE rArg) |]
          stmts = map compareField (zip3 lArgNames rArgNames argTypes)
                  ++ [noBindS [| return GEQ |]]
      clause [ conP conName (map varP lArgNames)
             , conP conName (map varP rArgNames)
             ]
             (normalB [| runGComparing $ $(doE stmts) |])
             []
      where
        conName  = nameOfCon con
        argTypes = argTypesOfCon con
        nArgs    = length argTypes