{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Network.DNS.Transport (
Resolver(..)
, resolve
) where
import Control.Concurrent.Async (async, waitAnyCancel)
import Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import qualified Data.List.NonEmpty as NE
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)
import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types
import Network.DNS.Types.Internal
-- | Decide whether a response belongs to the query we sent.
--
-- The response is accepted only when the identifier in its header equals
-- the expected one and its question section equals the list of questions
-- that was asked.
checkResp :: [Question] -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = sameIdent && sameQuestion
  where
    sameIdent    = identifier (header resp) == seqno
    sameQuestion = q == question resp
-- | Internal signalling exception: thrown by 'udpLookup' when a UDP
-- response has the truncation flag set, and caught by 'udpTcpLookup',
-- which then repeats the query over TCP.
data TCPFallback = TCPFallback deriving (Show, Typeable)

instance Exception TCPFallback
-- | Final stage of a lookup: the flag handed on to 'encodeQuestions'
-- (presumably the AD bit -- TODO confirm against "Network.DNS.IO") and the
-- action used to receive one message from a socket.
type Rslv0 = Bool -> (Socket -> IO DNSMessage)
          -> IO (Either DNSError DNSMessage)

-- | A lookup against already chosen name server(s).
type Rslv1 = [Question]
          -> [ResourceRecord] -- EDNS records added to the query
          -> Int              -- Timeout (passed to 'timeout', i.e. microseconds)
          -> Int              -- Retry
          -> Rslv0

-- | A single lookup over TCP.
type TcpRslv = Identifier -> AddrInfo -> [Question] -> Int -- Timeout
            -> Bool -> IO DNSMessage

-- | A single lookup over UDP; takes the same trailing arguments as 'TcpRslv'.
type UdpRslv = [ResourceRecord] -> Int -- Retry
            -> (Socket -> IO DNSMessage) -> TcpRslv
-- | Look up the records of type @typ@ for @dom@ using the name servers and
-- settings carried by the given 'Resolver'.
--
-- * An illegal domain (see 'isIllegal') yields @Left IllegalDomain@ without
--   any network traffic.
-- * With exactly one name server, that server is queried directly.
-- * With several, they are queried concurrently when 'resolvConcurrent' is
--   set and one after another otherwise.
--
-- The question is always asked for the absolute name: a trailing dot is
-- appended to @dom@ when it is missing.
resolve :: Domain -> TYPE -> Resolver -> Rslv0
resolve dom typ rlv ad rcv
  | isIllegal dom = return $ Left IllegalDomain
  | onlyOne       = resolveOne (NE.head servers) (NE.head idGens)
                               q edns tm retry ad rcv
  | concurrent    = resolveConcurrent nss gens q edns tm retry ad rcv
  | otherwise     = resolveSequential nss gens q edns tm retry ad rcv
  where
    -- 'BS.last' is partial, but 'q' is only demanded after the 'isIllegal'
    -- guard has rejected the empty domain.
    q = case BS.last dom of
          '.' -> [Question dom typ]
          _   -> [Question (dom <> ".") typ]

    -- One identifier generator per name server; both are kept as 'NonEmpty'
    -- so that the single-server case can use the total 'NE.head'.
    idGens  = genIds rlv
    gens    = NE.toList idGens

    seed    = resolvseed rlv
    servers = nameservers seed
    nss     = NE.toList servers
    onlyOne = null (NE.tail servers)

    conf       = resolvconf seed
    concurrent = resolvConcurrent conf
    tm         = resolvTimeout conf
    retry      = resolvRetry conf
    edns       = resolvEDNS conf
-- | Query the name servers one after another, pairing each server with its
-- identifier generator, and stop at the first 'Right' result.
--
-- When every attempted server fails, the error of the last one tried is
-- returned.  Calling this with an empty server or generator list is a
-- programming error and still aborts via 'error'.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q edns tm retry ad rcv = loop nss gs
  where
    loop (ai:ais) (gen:gens) = do
        eres <- resolveOne ai gen q edns tm retry ad rcv
        case eres of
          Left _ | moreToTry -> loop ais gens
          -- Either a success, or the failure of the last usable pair.
          _                  -> return eres
      where
        -- Only recurse while both lists still have an element, so lists of
        -- different lengths yield the last error instead of crashing.
        moreToTry = not (null ais) && not (null gens)
    loop _ _ = error "resolveSequential:loop"
-- | Query all name servers at once, one 'async' per (server, identifier
-- generator) pair, and return the result of whichever lookup finishes
-- first; the remaining lookups are cancelled by 'waitAnyCancel'.
--
-- Note that "first to finish" includes a lookup that finished with a
-- 'Left': no attempt is made to wait for a later success.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q edns tm retry ad rcv = do
    asyncs <- mapM spawn (zip nss gens)
    (_, result) <- waitAnyCancel asyncs
    return result
  where
    spawn (ai, gen) = async (resolveOne ai gen q edns tm retry ad rcv)
-- | Query a single name server.
--
-- A fresh identifier is drawn from @gen@, then the UDP lookup with TCP
-- fallback is run.  Any 'DNSError' thrown by the lookup is caught with
-- 'E.try' and returned as a 'Left'.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q edns tm retry ad rcv = do
    ident <- gen
    E.try (udpTcpLookup edns retry rcv ident ai q tm ad)
-- | Try the lookup over UDP first; if 'udpLookup' signals 'TCPFallback',
-- repeat the same query over TCP with 'tcpLookup'.
udpTcpLookup :: UdpRslv
udpTcpLookup edns retry rcv ident ai q tm ad =
    viaUdp `E.catch` \TCPFallback -> viaTcp
  where
    viaUdp = udpLookup edns retry rcv ident ai q tm ad
    viaTcp = tcpLookup ident ai q tm ad
-- | Rethrow an 'IOError' as the 'DNSError' 'NetworkFailure'.
--
-- Before wrapping, the error is annotated via 'annotateIOError': the shown
-- 'AddrInfo' is supplied as the location and @tag@ (the callers in this
-- module pass @\"UDP\"@ or @\"TCP\"@) is supplied in the file-name slot.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai tag ioe = throwIO (NetworkFailure annotated)
  where
    annotated = annotateIOError ioe (show ai) Nothing (Just tag)
-- | Create a socket from the family, socket type and protocol recorded in
-- the 'AddrInfo' and connect it to the 'AddrInfo''s address.
--
-- The socket is closed again if 'connect' throws, so a failed connect does
-- not leak the descriptor.  On success, ownership passes to the caller.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    E.bracketOnError
        (socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai))
        close
        (\sock -> connect sock (addrAddress ai) >> return sock)
-- | Perform a lookup over UDP.
--
-- The encoded query is sent up to @retry@ times, each attempt bounded by
-- @tm@ (as interpreted by 'timeout').  Outcomes:
--
-- * no matching answer within the retry budget: throws 'RetryLimitExceeded';
-- * answer with the truncation flag set: throws 'TCPFallback';
-- * answer with rcode 'FormatErr' while EDNS records were sent: the query
--   is re-encoded without EDNS and tried again (once, and without
--   consuming a retry);
-- * any other answer is returned.
--
-- 'IOError's are converted to 'NetworkFailure' by 'ioErrorToDNSError', and
-- the socket is always closed on exit.
udpLookup :: UdpRslv
udpLookup edns retry rcv ident ai q tm ad = do
    let qry       = encodeQuestions ident q edns ad
        ednsRetry = not (null edns)
    E.handle (ioErrorToDNSError ai "UDP") $
      bracket (udpOpen ai) close (loop qry ednsRetry 0 RetryLimitExceeded)
  where
    loop qry ednsRetry cnt err sock
      -- '>=' rather than '==' so that a negative retry count fails at once
      -- instead of retrying forever.
      | cnt >= retry = E.throwIO err
      | otherwise    = do
          mres <- timeout tm (send sock qry >> getAns sock)
          case mres of
            Nothing  -> loop qry ednsRetry (cnt + 1) RetryLimitExceeded sock
            Just res -> do
                let flgs      = flags (header res)
                    truncated = trunCation flgs
                    rc        = rcode flgs
                if truncated then
                    E.throwIO TCPFallback
                else if ednsRetry && rc == FormatErr then
                    -- Retry without EDNS; passing False prevents a second
                    -- downgrade, and 'cnt' is deliberately left unchanged.
                    let nonednsQuery = encodeQuestions ident q [] ad
                    in loop nonednsQuery False cnt RetryLimitExceeded sock
                else
                    return res

    -- Keep receiving until a message matching our identifier and question
    -- arrives; the surrounding 'timeout' bounds this loop.
    getAns sock = do
        mres <- rcv sock
        if checkResp q ident mres
          then return mres
          else getAns sock
-- | Create an unconnected stream socket whose address family matches the
-- peer address.  Only IPv4 and IPv6 peers are supported; any other kind of
-- address makes this throw 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen SockAddrInet{}  = socket AF_INET  Stream defaultProtocol
tcpOpen SockAddrInet6{} = socket AF_INET6 Stream defaultProtocol
tcpOpen _               = E.throwIO ServerFailure
-- | Perform a lookup over TCP.
--
-- A stream socket is opened for the server's address; connecting, sending
-- the query and receiving the answer together must finish within @tm@ (as
-- interpreted by 'timeout'), otherwise 'TimeoutExpired' is thrown.  An
-- answer that does not match the query's identifier and question raises
-- 'SequenceNumberMismatch'.  'IOError's are converted to 'NetworkFailure'
-- by 'ioErrorToDNSError', and the socket is always closed on exit.
--
-- The query is encoded without any EDNS records.
tcpLookup :: TcpRslv
tcpLookup ident ai q tm ad =
    E.handle (ioErrorToDNSError ai "TCP") $
      bracket (tcpOpen addr) close perform
  where
    addr = addrAddress ai

    perform vc = do
        let qry = encodeQuestions ident q [] ad
        mres <- timeout tm $ do
            connect vc addr
            sendVC vc qry
            receiveVC vc
        case mres of
          Nothing -> E.throwIO TimeoutExpired
          Just res
            | checkResp q ident res -> return res
            | otherwise             -> E.throwIO SequenceNumberMismatch
-- | 'True' when the domain is empty or too long: more than 254 bytes when
-- it ends with a dot, more than 253 bytes otherwise.
badLength :: Domain -> Bool
badLength dom
  | BS.null dom        = True
  -- 'BS.last' is safe here: the empty case was handled above.
  | BS.last dom == '.' = len > 254
  | otherwise          = len > 253
  where
    len = BS.length dom
-- | 'True' when the domain must not be sent to a name server, i.e. when
-- any of the following holds:
--
-- * its length is rejected by 'badLength';
-- * it contains no @\'.\'@ at all;
-- * it contains a @\':\'@ or a @\'/\'@;
-- * one of its dot-separated labels is longer than 63 bytes.
isIllegal :: Domain -> Bool
isIllegal dom =
       badLength dom
    || '.' `BS.notElem` dom
    || ':' `BS.elem` dom
    || '/' `BS.elem` dom
    || any labelTooLong (BS.split '.' dom)
  where
    labelTooLong label = BS.length label > 63