{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Data.BiMap where

import Cardano.Binary
  ( Decoder,
    DecoderError (DecoderErrorCustom),
    FromCBOR (..),
    ToCBOR (..),
    decodeListLen,
    decodeMapSkel,
    dropMap,
  )
import Codec.CBOR.Encoding (encodeListLen)
import Control.DeepSeq (NFData (rnf))
import Control.Monad (unless, void)
import Data.Coders (cborError, invalidKey)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import NoThunks.Class (NoThunks (..))

-- =================== Basic BiMap =====================
-- For bijection-like maps we define (BiMap v k v). Reasons we can't use (Data.Bimap k v):
-- 1) We need to enforce that the second argument `v` is in the Ord class when making it an Iter instance.
-- 2) The constructor for Data.Bimap is not exported, and it implements a true bijection
--    (each value has exactly one key), whereas here several keys may map to the same value.
-- 3) The missing operations 'restrictKeys' and 'withoutKeys' make performant versions of the operations ◁ ⋪ ▷ ⋫ hard.
-- 4) The missing operation 'union' makes performant versions of ∪ and ⨃ hard.
-- 5) So we roll our own, which is really a (Data.Map k v) with an index that maps each v to the Set{k} of keys mapping to it.

-- | A forward map from keys @a@ to values @b@, paired with a reverse index
-- from each value to the set of keys that map to it.
--
-- The constructor's @(v ~ b)@ constraint forces the first type parameter to
-- equal the value type, so every 'BiMap' built with 'MkBiMap' has the shape
-- @BiMap b a b@ (see the 'Bimap' synonym).
data BiMap v a b where
  MkBiMap ::
    (v ~ b) =>
    -- forward map: key -> value
    !(Map.Map a b) ->
    -- reverse index: value -> set of keys mapping to it
    !(Map.Map b (Set.Set a)) ->
    BiMap v a b

-- | Project out the forward map (key -> value), discarding the reverse index.
biMapToMap :: BiMap v a b -> Map a b
biMapToMap bm = case bm of
  MkBiMap forward _reverseIndex -> forward

-- | Build a 'BiMap' from an ordinary forward map, deriving the reverse index
-- (value -> set of keys) from the map's entries.
biMapFromMap ::
  (Ord k, Ord v) => Map k v -> BiMap v k v
biMapFromMap forward = MkBiMap forward reverseIndex
  where
    reverseIndex = Map.foldrWithKey (\key val acc -> addBack val key acc) Map.empty forward

-- ============== begin necessary Cardano.Binary instances ===============
instance (Ord a, Ord b, ToCBOR a, ToCBOR b) => ToCBOR (BiMap b a b) where
  -- Only the forward map is serialised; the reverse index is rebuilt when
  -- decoding. The map is wrapped in a one-element list so the decoder can
  -- tell this format apart from the previous encoding, which wrote out both
  -- maps (a two-element list).
  toCBOR bm = case bm of
    MkBiMap forward _ -> mconcat [encodeListLen 1, toCBOR forward]

instance
  forall a b.
  (Ord a, Ord b, FromCBOR a, FromCBOR b) =>
  FromCBOR (BiMap b a b)
  where
  -- The list-length token tells the current one-map format (length 1)
  -- apart from the previous two-map format (length 2); any other length is
  -- rejected with 'invalidKey'.
  fromCBOR = do
    len <- decodeListLen
    case len of
      1 -> decodeMapAsBimap
      -- Previous encoding of 'BiMap' encoded both the forward and reverse
      -- directions. In this case we skip the reverse encoding. Note that the
      -- reverse encoding was from 'b' to 'a', not the current 'b' to 'Set a',
      -- and hence the dropper reflects that.
      2 -> do
        !bimap <- decodeMapAsBimap
        dropMap (void (fromCBOR @b)) (void (fromCBOR @a))
        pure bimap
      other -> invalidKey (fromIntegral other)

-- | Decode a serialised CBOR map as a 'BiMap', rebuilding the reverse index
-- from the decoded entries.
--
-- The entries are passed to 'biMapFromAscDistinctList', which does not check
-- that keys are distinct and ascending. Both resulting maps are therefore
-- checked with 'Map.valid', and decoding fails with a custom error if either
-- is malformed.
decodeMapAsBimap ::
  (FromCBOR a, FromCBOR b, Ord a, Ord b) =>
  Decoder s (BiMap b a b)
decodeMapAsBimap = do
  decoded@(MkBiMap forward reverseIndex) <- decodeMapSkel biMapFromAscDistinctList
  if Map.valid forward && Map.valid reverseIndex
    then pure decoded
    else cborError $ DecoderErrorCustom "BiMap" "Expected distinct keys in ascending order"

instance (NoThunks a, NoThunks b) => NoThunks (BiMap v a b) where
  showTypeOf = const "BiMap"

  -- Check both the forward map and the reverse index by delegating to the
  -- 'NoThunks' instance for a pair of them.
  wNoThunks ctxt bm = case bm of
    MkBiMap forward reverseIndex -> wNoThunks ctxt (forward, reverseIndex)

-- NOTE(review): this only forces the two maps to weak head normal form; keys
-- and values are not deep-forced (there are no NFData constraints on @a@ or
-- @b@). Because both constructor fields are strict, the 'seq's add nothing
-- beyond evaluating the 'BiMap' itself.
instance NFData (BiMap v a b) where
  rnf (MkBiMap forward reverseIndex) = forward `seq` reverseIndex `seq` ()

-- ============== end Necessary Cardano.Binary instances ===================

-- | Two 'BiMap's are equal when their forward maps are equal; the reverse
-- indexes are not compared.
instance (Eq k, Eq v) => Eq (BiMap u k v) where
  lhs == rhs = case (lhs, rhs) of
    (MkBiMap fwdL _, MkBiMap fwdR _) -> fwdL == fwdR

-- | Renders only the forward map; the reverse index is omitted.
instance (Show k, Show v) => Show (BiMap u k v) where
  show bm = case bm of
    MkBiMap forward _ -> show forward

-- | Record in the reverse index that key @k@ maps to value @newv@, creating
-- a singleton key set for @newv@ if it has none yet.
addBack :: (Ord v, Ord k) => v -> k -> Map.Map v (Set.Set k) -> Map.Map v (Set.Set k)
addBack newv k = Map.alter (Just . maybe keySet (Set.union keySet)) newv
  where
    keySet = Set.singleton k

-- | Remove key @k@ from the key set recorded for value @oldv@ in the reverse
-- index.
--
-- If that leaves the key set empty, @oldv@ is removed from the index
-- altogether, so the index does not accumulate values that no key maps to.
-- If @oldv@ is not in the index, the index is returned unchanged.
retract :: (Ord v, Ord k) => v -> k -> Map.Map v (Set.Set k) -> Map.Map v (Set.Set k)
retract oldv k = Map.update dropKey oldv
  where
    dropKey ks =
      let remaining = Set.delete k ks
       in if Set.null remaining then Nothing else Just remaining

-- | Move key @k@ in the reverse index from value @oldv@ to value @newv@:
-- first remove it from @oldv@'s key set ('retract'), then add it to
-- @newv@'s key set ('addBack').
insertBackwards :: (Ord k, Ord v) => v -> v -> k -> Map.Map v (Set.Set k) -> Map.Map v (Set.Set k)
insertBackwards oldv newv k = addBack newv k . retract oldv k

-- | Insert value @v@ at key @k@.
--
-- If @k@ is already present with value @old@, the stored value becomes
-- @comb old v@ (existing value first, new value second). The reverse index
-- is updated to move @k@ from its previous value to the stored one.
insertWithBiMap :: (Ord k, Ord v) => (v -> v -> v) -> k -> v -> BiMap v k v -> BiMap v k v
insertWithBiMap comb k v (MkBiMap forward reverseIndex) =
  MkBiMap
    (Map.insertWith (mapflip comb) k v forward)
    (insertBackwards previous combined k reverseIndex)
  where
    -- (value currently at k, value to be stored at k); both are v when k is new.
    (previous, combined) =
      maybe (v, v) (\old -> (old, comb old v)) (Map.lookup k forward)

-- | The 'BiMap' with no entries: an empty forward map and an empty reverse
-- index.
biMapEmpty :: BiMap v k v
biMapEmpty = MkBiMap noEntries noIndex
  where
    noEntries = Map.empty
    noIndex = Map.empty

-- | Make a 'BiMap' from a list of key/value pairs.
--
-- Pairs are inserted from left to right. When a key occurs more than once,
-- the combining function is called as @comb earlier later@. Here @earlier@
-- is the value accumulated so far for that key, and @later@ is the value
-- from the current, later pair. So @comb = (\\ earlier later -> later)@ lets
-- elements later in the list override earlier ones, and
-- @comb = (\\ earlier later -> earlier)@ keeps the value that appears first
-- in the list.
biMapFromList :: (Ord k, Ord v) => (v -> v -> v) -> [(k, v)] -> BiMap v k v
biMapFromList comb = go biMapEmpty
  where
    -- Strict left fold, so the value already stored for a key always comes
    -- from EARLIER in the list (a right fold would reverse the roles of
    -- comb's arguments). The bang keeps the accumulator evaluated, and the
    -- BiMap's strict fields then keep both maps evaluated too.
    go !acc [] = acc
    go !acc ((k, v) : rest) = go (insertWithBiMap comb k v acc) rest

-- | Data.Map calls combining functions as @f new old@, while our convention
-- is @comb old new@. @mapflip comb@ adapts a combining function written in
-- our convention for use with Data.Map functions such as 'Map.insertWith'.
-- (We also use this in the Basic instance for BiMap, which uses Data.Map.)
mapflip :: (v -> v -> v) -> (v -> v -> v)
mapflip comb new old = comb old new

-- | Build a 'BiMap' from key/value pairs whose keys are distinct and in
-- strictly ascending order.
--
-- /Warning/ - this invariant is not checked. If it is violated, the forward
-- map built by 'Map.fromDistinctAscList' is malformed; callers that cannot
-- trust their input should check the result with 'Map.valid'.
biMapFromAscDistinctList ::
  (Ord k, Ord v) => [(k, v)] -> BiMap v k v
biMapFromAscDistinctList xs = MkBiMap forward reverseIndex
  where
    forward = Map.fromDistinctAscList xs
    reverseIndex = foldr (\(key, val) acc -> addBack val key acc) Map.empty xs

-- | This synonym makes @BiMap v k v@ look like an ordinary binary type
-- constructor: @Bimap k v@.
type Bimap key val = BiMap val key val

-- | Remove every key that maps to value @v@, updating both the forward map
-- and the reverse index. Returns the 'BiMap' unchanged if no key maps to @v@.
--
-- The reverse index makes this cheap on a 'BiMap': the affected keys are
-- found with a single O(log n) lookup rather than a scan of the forward map,
-- and are then removed from the forward map in bulk with 'Map.withoutKeys'.
-- Collections without such an index would need a full traversal.
removeval :: (Ord k, Ord v) => v -> BiMap v k v -> BiMap v k v
removeval v m@(MkBiMap forward reverseIndex) =
  case Map.lookup v reverseIndex of
    Nothing -> m
    Just keys -> MkBiMap (Map.withoutKeys forward keys) (Map.delete v reverseIndex)