{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

-- | Sometimes we need to write our own version of functions over `Map.Map` that
-- do not appear in the "containers" library. This module is for such functions.
--
-- For example:
--
-- 1. Version of `Map.withoutKeys` where both arguments are `Map.Map`
-- 2. Checking whether two maps have exactly the same set of keys
-- 3. The intersection of two maps guarded by a predicate.
--
--    > ((dom stkcred) ◁ deleg) ▷ (dom stpool) ==>
--    > intersectDomP (\ k v -> Map.member v stpool) stkcred deleg
module Data.MapExtras
  ( StrictTriple (..),
    extract,
    noKeys,
    keysEqual,
    splitMemberMap,
    splitMemberSet,
    intersectDomP,
    intersectDomPLeft,
    intersectMapSetFold,
    disjointMapSetFold,
    extractKeys,
    extractKeysSmallSet,
  )
where

import Data.Map.Internal (Map (..), balanceL, balanceR, glue, link, link2)
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Set.Internal as Set
import GHC.Exts (isTrue#, reallyUnsafePtrEquality#, (==#))

-- | A three-component product whose fields are all strict: building a
-- 'StrictTriple' evaluates each component to weak head normal form. The
-- @split*@ helpers in this module return it instead of a lazy tuple.
data StrictTriple a b c = StrictTriple !a !b !c
  deriving (Eq, Show)

-- | @noKeys m d@ removes from @m@ every key that occurs in @d@; the values of
-- @d@ are ignored. The result has the same entries as @'Map.difference' m d@
-- (i.e. @'Map.withoutKeys' m ('Map.keysSet' d)@) without building a key set.
--
-- Whenever no key is removed from a sub-map, that sub-map is returned as-is,
-- so the result shares as much structure with @m@ as possible.
noKeys :: Ord k => Map k a -> Map k b -> Map k a
noKeys Tip _ = Tip
noKeys m Tip = m
noKeys m (Bin _ k _ ls rs) =
  case Map.split k m of
    (lm, rm)
      -- 'Map.split' drops @k@ when present, so the sizes only add up to
      -- @size m@ if @k@ was absent and neither half lost a key: reuse @m@.
      -- 'Map.size' is O(1).
      | Map.size lm' + Map.size rm' == Map.size m -> m
      | otherwise -> link2 lm' rm' -- We know `k` is not in either `lm` or `rm`
      where
        !lm' = noKeys lm ls
        !rm' = noKeys rm rs
{-# INLINEABLE noKeys #-}

-- | Cheap test of whether two values are the very same heap object.
--
-- A 'True' answer is definitive. A 'False' answer only means "maybe
-- different" (for instance, one side may still be an unevaluated thunk), so
-- force both arguments to at least WHNF beforehand to get moderately reliable
-- results.
ptrEq :: a -> a -> Bool
ptrEq lhs rhs = isTrue# (reallyUnsafePtrEquality# lhs rhs ==# 1#)
{-# INLINE ptrEq #-}

-- | @keysEqual m1 m2@ is 'True' exactly when both maps contain the same set of
-- keys; the values are never inspected. It agrees with
-- @'Map.keysSet' m1 == 'Map.keysSet' m2@ but builds no intermediate sets.
keysEqual :: Ord k => Map k v1 -> Map k v2 -> Bool
keysEqual m1 m2
  -- 'Map.size' is O(1): maps of different sizes cannot share a key set. The
  -- check is repeated for every pair of sub-maps, pruning early.
  | Map.size m1 /= Map.size m2 = False
keysEqual Tip Tip = True
keysEqual Tip Bin {} = False
keysEqual Bin {} Tip = False
keysEqual m (Bin _ k _ ls rs) =
  case splitMemberMap k m of
    StrictTriple lm True rm -> keysEqual ls lm && keysEqual rs rm
    _ -> False
{-# INLINEABLE keysEqual #-}

-- | A variant of 'Map.splitLookup' that only reports whether the key was
-- present instead of producing its value. @splitMemberMap k m@ returns the
-- entries of @m@ with keys below @k@, whether @k@ is a key of @m@, and the
-- entries with keys above @k@. It is used by 'keysEqual' and 'intersectDomP'
-- to avoid allocating unnecessary 'Just' constructors.
--
-- /Note/ - this is a copy pasted internal function from "containers" package
-- adjusted to return `StrictTriple`
splitMemberMap :: Ord k => k -> Map k a -> StrictTriple (Map k a) Bool (Map k a)
splitMemberMap = goSplit
  where
    goSplit :: Ord k => k -> Map k a -> StrictTriple (Map k a) Bool (Map k a)
    goSplit !_ Tip = StrictTriple Tip False Tip
    goSplit !key (Bin _ kx x l r) =
      case compare key kx of
        EQ -> StrictTriple l True r
        LT ->
          case goSplit key l of
            -- The fields of 'StrictTriple' are strict, so the re-linked
            -- right-hand map is evaluated as soon as the triple is built.
            StrictTriple lt found gt -> StrictTriple lt found (link kx x gt r)
        GT ->
          case goSplit key r of
            StrictTriple lt found gt -> StrictTriple (link kx x l lt) found gt
{-# INLINEABLE splitMemberMap #-}

-- | /O(log n)/. Performs a 'Set.split' but also returns whether the pivot
-- element was found in the original set.
--
-- This is a modified version of `Set.splitMember`, where `StrictTriple` is used
-- instead of a lazy tuple for a minor performance gain.
splitMemberSet :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a)
splitMemberSet _ Set.Tip = StrictTriple Set.Tip False Set.Tip
splitMemberSet pivot (Set.Bin _ y l r) =
  case compare pivot y of
    EQ -> StrictTriple l True r
    LT ->
      case splitMemberSet pivot l of
        -- Strict fields force the re-linked set when the triple is built.
        StrictTriple below found above -> StrictTriple below found (Set.link y above r)
    GT ->
      case splitMemberSet pivot r of
        StrictTriple below found above -> StrictTriple (Set.link y l below) found above
{-# INLINEABLE splitMemberSet #-}

-- | @intersectDomP p m1 m2@ keeps an entry @(key, value)@ of @m2@ iff @key@ is
-- in the domain of @m1@ and @p key value@ holds. The values of @m1@ are never
-- inspected, and the result's keys and values both come from @m2@.
intersectDomP :: Ord k => (k -> v2 -> Bool) -> Map k v1 -> Map k v2 -> Map k v2
intersectDomP _ Tip _ = Tip
intersectDomP _ _ Tip = Tip
intersectDomP p t1 (Bin _ k v l2 r2)
  | keep = link k v lower upper
  | otherwise = link2 lower upper
  where
    !(StrictTriple l1 present r1) = splitMemberMap k t1
    keep = present && p k v
    !lower = intersectDomP p l1 l2
    !upper = intersectDomP p r1 r2
{-# INLINEABLE intersectDomP #-}

-- | Similar to 'intersectDomP', except that the result keeps the keys and
-- values of the first map: @intersectDomPLeft p m1 m2@ keeps an entry
-- @(k, v1)@ of @m1@ iff @k@ is also a key of @m2@ and @p k v2@ holds, where
-- @v2@ is @m2@'s value at @k@.
intersectDomPLeft :: Ord k => (k -> v2 -> Bool) -> Map k v1 -> Map k v2 -> Map k v1
intersectDomPLeft _ Tip _ = Tip
intersectDomPLeft _ _ Tip = Tip
intersectDomPLeft p (Bin _ k v1 l1 r1) t2
  | Just v2 <- mb, p k v2 = link k v1 lower upper
  | otherwise = link2 lower upper
  where
    !(l2, mb, r2) = Map.splitLookup k t2
    !lower = intersectDomPLeft p l1 l2
    !upper = intersectDomPLeft p r1 r2
{-# INLINEABLE intersectDomPLeft #-}

-- | Fold with @accum@ over those entries of the map whose keys are in the set.
-- Entries are combined starting from the largest key (the right-hand sub-map
-- is folded first), and the accumulator is kept in WHNF throughout.
intersectMapSetFold :: Ord k => (k -> v -> ans -> ans) -> Map k v -> Set k -> ans -> ans
intersectMapSetFold _ Tip _ !acc = acc
intersectMapSetFold f (Bin _ k v lo hi) keys !acc
  | Set.null keys = acc
  | otherwise =
      let (keysLo, found, keysHi) = Set.splitMember k keys
          !accHi = intersectMapSetFold f hi keysHi acc
          !accMid = if found then f k v accHi else accHi
       in intersectMapSetFold f lo keysLo accMid
{-# INLINEABLE intersectMapSetFold #-}

-- | Fold with @accum@ over those entries of the map whose keys do NOT appear
-- in the set. Entries are combined starting from the largest key, matching
-- 'Map.foldrWithKey'', which is used directly once the set becomes empty.
disjointMapSetFold :: Ord k => (k -> v -> ans -> ans) -> Map k v -> Set k -> ans -> ans
disjointMapSetFold _ Tip _ !acc = acc
disjointMapSetFold f m@(Bin _ k v lo hi) keys !acc
  | Set.null keys = Map.foldrWithKey' f acc m
  | otherwise =
      let (keysLo, found, keysHi) = Set.splitMember k keys
          !accHi = disjointMapSetFold f hi keysHi acc
          !accMid = if found then accHi else f k v accHi
       in disjointMapSetFold f lo keysLo accMid
{-# INLINEABLE disjointMapSetFold #-}

-- =================================

-- | Delete @key@ from a map, also returning the value it was bound to.
--
-- This is a slightly adjusted version of `delete` from "containers". The
-- result is an unboxed pair of the removed value ('Nothing' when @key@ is
-- absent) and the updated map. When a subtree comes back pointer-equal
-- ('ptrEq') to the one passed down, the enclosing node is reused instead of
-- being rebalanced. Because 'ptrEq' can report false negatives, a node may
-- occasionally be rebuilt with identical contents.
extract# :: Ord k => k -> Map k a -> (# Maybe a, Map k a #)
extract# !key = descend
  where
    descend Tip = (# Nothing, Tip #)
    descend node@(Bin _ kx x l r) =
      case compare key kx of
        EQ ->
          case glue l r of
            !rest -> (# Just x, rest #)
        LT ->
          case descend l of
            (# found, l' #)
              -- Left subtree unchanged: hand back this very node.
              | l' `ptrEq` l -> (# found, node #)
              | otherwise ->
                  case balanceR kx x l' r of
                    !rebuilt -> (# found, rebuilt #)
        GT ->
          case descend r of
            (# found, r' #)
              -- Right subtree unchanged: hand back this very node.
              | r' `ptrEq` r -> (# found, node #)
              | otherwise ->
                  case balanceL kx x l r' of
                    !rebuilt -> (# found, rebuilt #)
{-# INLINE extract# #-}

-- | Just like `Map.delete`, but also returns the value if it was indeed deleted
-- from the map. When the key is absent, the input map itself is returned.
extract :: Ord k => k -> Map k b -> (Maybe b, Map k b)
extract k m =
  case extract# k m of
    (# Nothing, _ #) -> (Nothing, m)
    (# found, m' #) -> (found, m')
{-# INLINE extract #-}

-- | Partition the `Map` according to keys in the `Set`: the first component
-- holds the entries whose keys are not in the set, the second those whose keys
-- are. This is equivalent to:
--
-- > extractKeys m s === (withoutKeys m s, restrictKeys m s)
--
-- Sets with fewer than 6 elements are handled by 'extractKeysSmallSet',
-- larger ones by 'extractKeys#'.
extractKeys :: Ord k => Map k a -> Set k -> (Map k a, Map k a)
extractKeys m s
  | Set.size s >= smallSetLimit =
      case extractKeys# m s of
        (# without, restricted #) -> (without, restricted)
  | otherwise = extractKeysSmallSet m s
  where
    -- See the haddock of 'extractKeysSmallSet' for where this value comes from.
    smallSetLimit :: Int
    smallSetLimit = 6
{-# INLINE extractKeys #-}

-- | Same result as 'extractKeys': the entries whose keys are not in the set,
-- paired with those whose keys are. It removes the set's keys one at a time
-- with 'extract#'.
--
-- It has been discovered experimentally through benchmarks that for a small
-- Set of fewer than around 6 elements this function performs faster than
-- `extractKeys#`.
extractKeysSmallSet :: Ord k => Map k a -> Set.Set k -> (Map k a, Map k a)
extractKeysSmallSet initial = Set.foldl' step (initial, Map.empty)
  where
    -- Move @key@ (if present) from the remaining map into the taken map.
    step acc@(remaining, taken) key =
      case extract# key remaining of
        (# Nothing, _ #) -> acc
        (# Just v, remaining' #) ->
          case Map.insert key v taken of
            !taken' -> (remaining', taken')
{-# INLINE extractKeysSmallSet #-}

-- | This function will produce exactly the same results as
-- `extractKeysSmallSet` for all inputs, but it performs better whenever the set
-- is big.
--
-- Returns an unboxed pair: the entries whose keys are not in the set, and the
-- entries whose keys are. Where one side ends up with the whole of a sub-map,
-- the original node is reused (detected with 'ptrEq').
extractKeys# :: Ord k => Map k a -> Set k -> (# Map k a, Map k a #)
extractKeys# Tip _ = (# Tip, Tip #)
extractKeys# m Set.Tip = (# m, Tip #)
extractKeys# m@(Bin _ k x lm rm) s =
  case splitMemberSet k s of
    StrictTriple ls inSet rs ->
      case extractKeys# lm ls of
        (# lmw, lmr #) ->
          case extractKeys# rm rs of
            (# rmw, rmr #) ->
              let !w = rebuild (not inSet) lmw rmw
                  !r = rebuild inSet lmr rmr
               in (# w, r #)
  where
    -- Reassemble one side of the partition from its two halves. If the root
    -- key belongs on this side and neither half changed, reuse @m@ itself.
    rebuild keepRoot newL newR
      | keepRoot =
          if newL `ptrEq` lm && newR `ptrEq` rm
            then m
            else link k x newL newR
      | otherwise = link2 newL newR
{-# INLINEABLE extractKeys# #-}