{-# LANGUAGE ConstraintKinds #-}
module PlutusTx.Lattice where

import PlutusTx.Bool
import PlutusTx.Monoid
import PlutusTx.Semigroup

-- | A join semi-lattice: a type @l@ with a binary operation '(\/)' that
-- combines two elements into their join (least upper bound).
--
-- Law-abiding instances are expected to make '(\/)' associative,
-- commutative and idempotent, as the mathematical definition requires.
-- Unlike that definition, no ordering superclass is demanded here, which is
-- what makes instances for types such as '(->)' possible.
class JoinSemiLattice l where
    (\/)
        :: l  -- ^ left operand
        -> l  -- ^ right operand
        -> l  -- ^ the join of both operands

-- | A meet semi-lattice: a type @l@ with a binary operation '(/\)' that
-- combines two elements into their meet (greatest lower bound).
--
-- Law-abiding instances are expected to make '(/\)' associative,
-- commutative and idempotent, as the mathematical definition requires.
-- Unlike that definition, no ordering superclass is demanded here, which is
-- what makes instances for types such as '(->)' possible.
class MeetSemiLattice l where
    (/\)
        :: l  -- ^ left operand
        -> l  -- ^ right operand
        -> l  -- ^ the meet of both operands

-- | Constraint synonym for a lattice: a type that is both a
-- 'JoinSemiLattice' and a 'MeetSemiLattice'.
--
-- Only the two operations are required by this synonym; any interaction
-- laws between them (such as absorption) are not enforced.
type Lattice l = (JoinSemiLattice l, MeetSemiLattice l)

-- | A join semi-lattice that additionally has a least element, 'bottom'.
--
-- 'bottom' is meant to be the identity of '(\/)': joining any value with
-- 'bottom' gives that value back.
class JoinSemiLattice l => BoundedJoinSemiLattice l where
    -- | The distinguished least element, identity of '(\/)'.
    bottom :: l

-- | A meet semi-lattice that additionally has a greatest element, 'top'.
--
-- 'top' is meant to be the identity of '(/\)': meeting any value with
-- 'top' gives that value back.
class MeetSemiLattice l => BoundedMeetSemiLattice l where
    -- | The distinguished greatest element, identity of '(/\)'.
    top :: l

-- | Constraint synonym for a bounded lattice: a type that is both a
-- 'BoundedJoinSemiLattice' and a 'BoundedMeetSemiLattice', so it has
-- 'bottom', 'top', '(\/)' and '(/\)'.
type BoundedLattice l = (BoundedJoinSemiLattice l, BoundedMeetSemiLattice l)

-- Wrappers

-- | Newtype wrapper that views a join semi-lattice as a 'Semigroup'
-- combining with '(\/)' and, when the lattice is bounded, as a 'Monoid'
-- whose 'mempty' is 'bottom'.
newtype Join j = Join j

-- | Appending two 'Join'-wrapped values takes the join '(\/)' of their
-- contents.
instance JoinSemiLattice a => Semigroup (Join a) where
    -- INLINABLE so the definition is available across modules, matching the
    -- other instances in this file.
    {-# INLINABLE (<>) #-}
    Join l <> Join r = Join (l \/ r)

-- | The identity for 'Join' is the wrapped 'bottom' element, the unit of
-- '(\/)'.
instance BoundedJoinSemiLattice a => Monoid (Join a) where
    {-# INLINABLE mempty #-}
    mempty = Join bottom

-- | Newtype wrapper that views a meet semi-lattice as a 'Semigroup'
-- combining with '(/\)' and, when the lattice is bounded, as a 'Monoid'
-- whose 'mempty' is 'top'.
newtype Meet m = Meet m

-- | Appending two 'Meet'-wrapped values takes the meet '(/\)' of their
-- contents.
instance MeetSemiLattice a => Semigroup (Meet a) where
    -- INLINABLE so the definition is available across modules, matching the
    -- other instances in this file.
    {-# INLINABLE (<>) #-}
    Meet l <> Meet r = Meet (l /\ r)

-- | The identity for 'Meet' is the wrapped 'top' element, the unit of
-- '(/\)'.
instance BoundedMeetSemiLattice a => Monoid (Meet a) where
    {-# INLINABLE mempty #-}
    mempty = Meet top

-- Instances

-- | On booleans, join is logical disjunction.
instance JoinSemiLattice Bool where
    {-# INLINABLE (\/) #-}
    (\/) = (||)

-- | 'False' is the unit of disjunction, so it is the bottom element.
instance BoundedJoinSemiLattice Bool where
    {-# INLINABLE bottom #-}
    bottom = False

-- | On booleans, meet is logical conjunction.
instance MeetSemiLattice Bool where
    {-# INLINABLE (/\) #-}
    (/\) = (&&)

-- | 'True' is the unit of conjunction, so it is the top element.
instance BoundedMeetSemiLattice Bool where
    {-# INLINABLE top #-}
    top = True

-- | Pairs are joined componentwise.
instance (JoinSemiLattice a, JoinSemiLattice b) => JoinSemiLattice (a, b) where
    {-# INLINABLE (\/) #-}
    (a1, b1) \/ (a2, b2) = (a1 \/ a2, b1 \/ b2)

-- | The bottom of a pair is the pair of the components' bottoms.
instance (BoundedJoinSemiLattice a, BoundedJoinSemiLattice b) => BoundedJoinSemiLattice (a, b) where
    {-# INLINABLE bottom #-}
    bottom = (bottom, bottom)

-- | Pairs are met componentwise.
instance (MeetSemiLattice a, MeetSemiLattice b) => MeetSemiLattice (a, b) where
    {-# INLINABLE (/\) #-}
    (a1, b1) /\ (a2, b2) = (a1 /\ a2, b1 /\ b2)

-- | The top of a pair is the pair of the components' tops.
instance (BoundedMeetSemiLattice a, BoundedMeetSemiLattice b) => BoundedMeetSemiLattice (a, b) where
    {-# INLINABLE top #-}
    top = (top, top)

-- | Functions into a join semi-lattice are joined pointwise: the join of
-- @f@ and @g@ maps each argument to the join of the two results.
instance JoinSemiLattice b => JoinSemiLattice (a -> b) where
    {-# INLINABLE (\/) #-}
    (f \/ g) x = f x \/ g x

-- | The bottom function ignores its argument and returns the codomain's
-- 'bottom'.
instance BoundedJoinSemiLattice b => BoundedJoinSemiLattice (a -> b) where
    {-# INLINABLE bottom #-}
    bottom _ = bottom

-- | Functions into a meet semi-lattice are met pointwise: the meet of
-- @f@ and @g@ maps each argument to the meet of the two results.
instance MeetSemiLattice b => MeetSemiLattice (a -> b) where
    {-# INLINABLE (/\) #-}
    (f /\ g) x = f x /\ g x

-- | The top function ignores its argument and returns the codomain's
-- 'top'.
instance BoundedMeetSemiLattice b => BoundedMeetSemiLattice (a -> b) where
    {-# INLINABLE top #-}
    top _ = top