{-# LANGUAGE BangPatterns #-}

module Data.CanonicalMaps
  ( CanonicalZero (..),
    canonicalInsert,
    canonicalMapUnion,
    canonicalMap,
    pointWise,
    Map.Map,
  )
where

import Data.Map.Internal
  ( Map (..),
    balanceL,
    balanceR,
    link,
    link2,
  )
import qualified Data.Map.Strict as Map
import Data.Map.Strict.Internal (singleton, splitLookup)

-- =====================================================================================
-- Operations on Map from keys to values that are specialised to `CanonicalZero` values.
-- A (Map k v), where (CanonicalZero v), is canonical if it never stores a zero at type v.
-- In order to do this we need to know what 'zeroC' is, and 'joinC' has to know how to
-- join together two values (e.g. two maps) when one of its arguments might be 'zeroC'.
-- This class is strictly used in the implementation, and is not observable by the user.
-- ======================================================================================

-- | Types with a distinguished zero value and a way to combine two values.
-- Maps whose values belong to this class are kept canonical by never storing
-- 'zeroC' as a value.
class (Eq t) => CanonicalZero t where
  -- | The distinguished zero; it is never stored in a canonical map.
  zeroC :: t

  -- | Combine two values; either argument may be 'zeroC'.
  joinC :: t -> t -> t

-- | Integers: the zero is @0@ and values are combined by addition.
instance CanonicalZero Integer where
  zeroC = 0
  joinC x y = x + y

-- | Maps of canonical values: the zero is the empty map, and two maps are
-- joined key-wise with 'canonicalMapUnion', combining the values of shared
-- keys with the value type's own 'joinC'.
instance (Eq k, Eq v, Ord k, CanonicalZero v) => CanonicalZero (Map k v) where
  zeroC = Map.empty
  joinC m1 m2 = canonicalMapUnion joinC m1 m2

-- Note that the class CanonicalZero and the function canonicalMapUnion are mutually recursive.

-- | Union of two canonical maps.
--
-- For a key present in both maps the values are combined as @f left right@,
-- where @left@ comes from the first map and @right@ from the second. Keys
-- present in only one map keep their value (inputs are assumed canonical).
-- Whenever a combined value equals 'zeroC' the key is dropped, so the result
-- is canonical whenever both inputs are.
--
-- This function and the 'CanonicalZero' instance for 'Map' are mutually
-- recursive.
canonicalMapUnion ::
  (Ord k, CanonicalZero a) =>
  (a -> a -> a) -> -- combining function, always applied as (f left right)
  Map k a ->
  Map k a ->
  Map k a
canonicalMapUnion _f t1 Tip = t1
canonicalMapUnion f t1 (Bin _ k x Tip Tip) =
  -- 'canonicalInsert' combines as (f existing inserted). The existing value
  -- comes from the left map here, so f already sees (left, right).
  canonicalInsert f k x t1
canonicalMapUnion f (Bin _ k x Tip Tip) t2 =
  -- Here the inserted value comes from the LEFT map and the existing one from
  -- the right map, so flip f to preserve the (f left right) argument order.
  canonicalInsert (flip f) k x t2
canonicalMapUnion _f Tip t2 = t2
canonicalMapUnion f (Bin _ k1 x1 l1 r1) t2 =
  case splitLookup k1 t2 of
    (l2, mb, r2) ->
      let -- Merge the parts strictly below and strictly above k1.
          !l1l2 = canonicalMapUnion f l1 l2
          !r1r2 = canonicalMapUnion f r1 r2
          -- Value for k1: x1 alone, or x1 combined with the right map's value.
          new = maybe x1 (f x1) mb
       in if new == zeroC
            then link2 l1l2 r1r2 -- drop k1 to keep the result canonical
            else link k1 new l1l2 r1r2
{-# INLINEABLE canonicalMapUnion #-}

-- | Insert a value into a canonical map, combining it with any value already
-- stored at that key.
--
-- If the key is absent, the value is inserted unless it is 'zeroC' (in which
-- case the map's contents are unchanged). If the key is present with value
-- @y@, the stored value becomes @f y x@ (existing value first, then the new
-- one), and the key is removed when that result is 'zeroC'.
canonicalInsert ::
  (Ord k, CanonicalZero a) =>
  (a -> a -> a) ->
  k ->
  a ->
  Map k a ->
  Map k a
canonicalInsert combine = insertAt
  where
    insertAt !key val Tip
      | val == zeroC = Tip
      | otherwise = singleton key val
    insertAt !key val (Bin size k y l r) =
      case compare key k of
        LT -> balanceL k y (insertAt key val l) r
        GT -> balanceR k y l (insertAt key val r)
        EQ
          | merged == zeroC -> link2 l r
          | otherwise -> Bin size key merged l r
          where
            -- Existing value in the tree first, then the newly inserted one.
            merged = combine y val
{-# INLINEABLE canonicalInsert #-}

-- | Apply a function to every value of a canonical map, discarding each entry
-- whose new value is 'zeroC' so that the result stays canonical.
canonicalMap :: (Ord k, CanonicalZero a) => (a -> a) -> Map k a -> Map k a
canonicalMap f m = Map.foldrWithKey step Map.empty m
  where
    step key val acc
      | fv == zeroC = acc
      | otherwise = Map.insert key fv acc
      where
        fv = f val
{-# INLINEABLE canonicalMap #-}

-- Pointwise comparison, assuming both maps are canonical (they never store 'zeroC').
-- Semantically, the value for any key not appearing in a map is 'zeroC'.

-- | Pointwise comparison of two canonical maps.
--
-- @pointWise p m1 m2@ holds when @p v1 v2@ holds for every key, where @v1@ and
-- @v2@ are the key's values in @m1@ and @m2@. A key missing from one map is
-- treated as having value 'zeroC' there. Keys absent from both maps are not
-- checked. The predicate is always applied as @p left right@.
pointWise ::
  (Ord k, CanonicalZero v) =>
  (v -> v -> Bool) ->
  Map k v ->
  Map k v ->
  Bool
pointWise _ Tip Tip = True
pointWise p Tip m@Bin {} = all (zeroC `p`) m
pointWise p m@Bin {} Tip = all (`p` zeroC) m
pointWise p m (Bin _ k v2 ls rs) =
  -- Split the LEFT map around the right map's root key, then compare the
  -- matching halves, keeping the left map in the left argument position.
  case Map.splitLookup k m of
    (lm, Just v1, rm) -> p v1 v2 && pointWise p lm ls && pointWise p rm rs
    -- A key absent from the left map implicitly maps to 'zeroC' there.
    (lm, Nothing, rm) -> p zeroC v2 && pointWise p lm ls && pointWise p rm rs
{-# INLINEABLE pointWise #-}