{-# LANGUAGE DefaultSignatures     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}

module Ouroboros.Consensus.Util.DepPair (
    -- * Dependent pairs
    DepPair
  , GenDepPair (GenDepPair, DepPair)
  , depPairFirst
    -- * Compare indices
  , SameDepIndex (..)
    -- * Trivial dependency
  , TrivialDependency (..)
  , fromTrivialDependency
  , toTrivialDependency
    -- * Convenience re-exports
  , Proxy (..)
  , (:~:) (..)
  ) where

import           Data.Kind (Type)
import           Data.Proxy
import           Data.SOP.Strict (I (..))
import           Data.Type.Equality ((:~:) (..))

{-------------------------------------------------------------------------------
  Dependent pairs
-------------------------------------------------------------------------------}

-- | A dependent pair whose second component is wrapped in an extra functor
-- @g@.
--
-- 'DepPair' is the special case @g = 'I'@. Both fields are strict, and they
-- share an existentially quantified index type that does not appear in the
-- result type.
data GenDepPair g f where
  GenDepPair
    :: !(f a)  -- first component: the index
    -> !(g a)  -- second component: its type is selected by the index
    -> GenDepPair g f

-- | Dependent pair
--
-- A dependent pair is a pair of values where the type of the second value
-- depends on the first value.
--
-- This is 'GenDepPair' with the identity functor 'I' around the second
-- component.
type DepPair = GenDepPair I

-- | Construct or match a 'DepPair' without mentioning the 'I' wrapper.
--
-- The synonym is implicitly bidirectional, so it works both as a pattern and
-- as a constructor expression.
{-# COMPLETE DepPair #-}
pattern DepPair :: f a -> a -> DepPair f
pattern DepPair fa a = GenDepPair fa (I a)

-- | Map over the first component (the index) of a generalized dependent pair,
-- leaving the second component untouched.
--
-- The mapping function must be polymorphic in the index type because that
-- type is existentially hidden inside 'GenDepPair'.
depPairFirst :: (forall a. f a -> f' a) -> GenDepPair g f -> GenDepPair g f'
depPairFirst f (GenDepPair ix a) = GenDepPair (f ix) a

{-------------------------------------------------------------------------------
  Compare indices
-------------------------------------------------------------------------------}

-- | Indices that can be compared for type equality at runtime.
--
-- 'sameDepIndex' yields @Just@ a type-equality proof when it can show that the
-- two indices refer to the same type.
class SameDepIndex f where
  sameDepIndex :: f a -> f b -> Maybe (a :~: b)

  -- For a trivial dependency every index denotes the same type, so the
  -- default always succeeds, using the proof supplied by 'hasSingleIndex'.
  default sameDepIndex :: TrivialDependency f => f a -> f b -> Maybe (a :~: b)
  sameDepIndex ix ix' = Just $ hasSingleIndex ix ix'

{-------------------------------------------------------------------------------
  Trivial dependencies
-------------------------------------------------------------------------------}

-- | A dependency is trivial if every index maps to one and the same type,
-- namely @'TrivialIndex' f@.
class TrivialDependency f where
  -- | The single type that every index of @f@ maps to.
  type TrivialIndex f :: Type

  -- | A value of @f@ indexed by the trivial index type itself.
  indexIsTrivial :: f (TrivialIndex f)

  -- | Proof that any two indices denote the same type.
  hasSingleIndex :: f a -> f b -> a :~: b

-- | Convert a value at index @ix@ into the trivial index type.
--
-- 'hasSingleIndex' applied to 'indexIsTrivial' and @ix@ proves that
-- @TrivialIndex f ~ a@. After matching 'Refl', the conversion is just 'id'.
-- NOTE(review): whether @ix@ itself gets forced depends on the instance's
-- 'hasSingleIndex'; this function does not force it directly.
fromTrivialDependency :: TrivialDependency f => f a -> a -> TrivialIndex f
fromTrivialDependency ix =
    case hasSingleIndex indexIsTrivial ix of
      Refl -> id

-- | Convert a value of the trivial index type into the type selected by
-- @ix@. This is the inverse direction of 'fromTrivialDependency'.
--
-- 'hasSingleIndex' applied to 'indexIsTrivial' and @ix@ proves that
-- @TrivialIndex f ~ a@. After matching 'Refl', the conversion is just 'id'.
toTrivialDependency :: TrivialDependency f => f a -> TrivialIndex f -> a
toTrivialDependency ix =
    case hasSingleIndex indexIsTrivial ix of
      Refl -> id