module Network.DNS.Memo where

import qualified Control.Reaper as R
import qualified Data.ByteString as B
import Data.OrdPSQ (OrdPSQ)
import qualified Data.OrdPSQ as PSQ
import Data.Time (UTCTime, getCurrentTime)

import Network.DNS.Imports
import Network.DNS.Types

-- | Tag with two alternatives, 'Answer' and 'Authority'.
--
-- NOTE(review): presumably names the section of a DNS response that a record
-- set was taken from; nothing in this module refers to it, so confirm against
-- the callers.
data Section = Answer | Authority deriving (Eq, Ord, Show)

-- | Cache key: a 'ByteString' paired with a 'TYPE'. 'insertCache' binds the
-- two components as @dom@ and @typ@.
type Key = (ByteString
           ,TYPE)
-- | Priority of an entry in the queue. 'prune' drops every entry whose
-- priority is at most the current time, so it acts as an expiry time.
type Prio = UTCTime

-- | Cached value: either a 'DNSError' or a list of 'RData'.
type Entry = Either DNSError [RData]

-- | The cache contents: a priority search queue mapping a 'Key' to an
-- 'Entry', with a 'Prio' as the priority.
type DB = OrdPSQ Key Prio Entry

-- | A reaper whose workload is the 'DB' and whose items are
-- @(key, priority, entry)@ triples.
type Cache = R.Reaper DB (Key,Prio,Entry)

-- | Create a 'Cache' with 'R.mkReaper'.
--
-- The argument is the reaping interval in seconds: it is multiplied by
-- 1,000,000 before being stored in 'R.reaperDelay', which the reaper
-- interprets as microseconds.
--
-- The reaper is configured to start from an empty queue, to add an item
-- with 'PSQ.insert', to clean up with 'prune', and to regard the workload
-- as empty when 'PSQ.null' holds.
newCache :: Int -> IO Cache
newCache delay = R.mkReaper R.defaultReaperSettings {
    R.reaperEmpty  = PSQ.empty
  , R.reaperCons   = \(k, tim, v) psq -> PSQ.insert k tim v psq
  , R.reaperAction = prune
    -- seconds -> microseconds
  , R.reaperDelay  = delay * 1000000
  , R.reaperNull   = PSQ.null
  }

-- | Look up a key in the current contents of the cache.
--
-- Returns the stored priority and entry, or 'Nothing' when the key is not
-- present. The priority is not compared against the current time here; an
-- entry stays visible until 'prune' removes it.
lookupCache :: Key -> Cache -> IO (Maybe (Prio, Entry))
lookupCache key reaper = PSQ.lookup key <$> R.reaperRead reaper

-- | Add an entry to the cache under the given key and priority.
--
-- The 'ByteString' in the key and every 'RData' of a successful entry are
-- duplicated with 'B.copy' / 'copy' before being stored; an error entry
-- ('Left') is stored unchanged.
--
-- NOTE(review): the copying presumably keeps the cache from retaining the
-- larger buffers the strings were sliced from -- confirm.
insertCache :: Key -> Prio -> Entry -> Cache -> IO ()
insertCache (dom,typ) tim ent0 reaper = R.reaperAdd reaper (key,tim,ent)
  where
    key = (B.copy dom,typ)
    -- 'fmap' over 'Either' touches only the 'Right' case.
    ent = fmap (map copy) ent0

-- | Reaper action: drop the entries that have expired.
--
-- Every entry of the given queue whose priority is at most the current time
-- is removed with 'PSQ.atMostView'. The result is a function that re-inserts
-- all entries of the queue it is given into the pruned queue. Per
-- "Control.Reaper", that function is applied to the workload that
-- accumulated while this action was running. When both queues hold the same
-- key, the entry from the argument queue replaces the pruned one, because
-- 'PSQ.insert' overwrites.
--
-- In theory 'PSQ.atMostView' alone would be enough for pruning, but
-- auto-update is designed around a list-based workload, which has no
-- @atMost@ operations, so the new entries have to be merged back in this
-- redundant way.
prune :: DB -> IO (DB -> DB)
prune oldpsq = do
    tim <- getCurrentTime
    -- The expired entries (first component) are simply discarded.
    let (_, pruned) = PSQ.atMostView tim oldpsq
    return $ \newpsq -> foldl' ins pruned $ PSQ.toList newpsq
  where
    ins psq (k,p,v) = PSQ.insert k p v psq

-- | Rebuild an 'RData' so that every 'ByteString' field is a fresh copy
-- made with 'B.copy'.
--
-- Constructors that carry no 'ByteString' ('RD_A', 'RD_AAAA', 'RD_NULL') are
-- returned as they are; numeric fields are passed through unchanged; the
-- options of 'RD_OPT' are copied with 'copyOData'.
copy :: RData -> RData
copy r@(RD_A _)                 = r
copy (RD_NS dom)                = RD_NS $ B.copy dom
copy (RD_CNAME dom)             = RD_CNAME $ B.copy dom
copy (RD_SOA mn mr a b c d e)   = RD_SOA (B.copy mn) (B.copy mr) a b c d e
copy (RD_PTR dom)               = RD_PTR $ B.copy dom
copy RD_NULL                    = RD_NULL
copy (RD_MX prf dom)            = RD_MX prf $ B.copy dom
copy (RD_TXT txt)               = RD_TXT $ B.copy txt
copy r@(RD_AAAA _)              = r
copy (RD_SRV a b c dom)         = RD_SRV a b c $ B.copy dom
copy (RD_DNAME dom)             = RD_DNAME $ B.copy dom
copy (RD_OPT od)                = RD_OPT $ map copyOData od
copy (RD_DS t a dt dv)          = RD_DS t a dt $ B.copy dv
copy (RD_DNSKEY f p a k)        = RD_DNSKEY f p a $ B.copy k
copy (RD_TLSA a b c dgst)       = RD_TLSA a b c $ B.copy dgst
copy (RD_NSEC3PARAM a b c salt) = RD_NSEC3PARAM a b c $ B.copy salt
copy (UnknownRData is)          = UnknownRData $ B.copy is

-- | Rebuild an 'OData' so that its 'ByteString' payload, if any, is a fresh
-- copy made with 'B.copy'.
--
-- 'OD_ClientSubnet' is returned as it is; 'UnknownOData' keeps its option
-- code and gets a copied payload.
copyOData :: OData -> OData
copyOData o@(OD_ClientSubnet _ _ _) = o
copyOData (UnknownOData c b)        = UnknownOData c $ B.copy b