{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.DNS.IO (
    -- * Receiving from socket
    receive
  , receiveVC
    -- * Sending to socket
  , send
  , sendVC
    -- ** Creating Query
  , encodeQuestions
  , composeQuery
  , composeQueryAD
    -- ** Creating Response
  , responseA
  , responseAAAA
  ) where

#if !defined(mingw32_HOST_OS)
#define POSIX
#else
#define WIN
#endif

#if __GLASGOW_HASKELL__ < 709
#define GHC708
#endif

import qualified Control.Exception as E
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Char (ord)
import Data.IP (IPv4, IPv6)
import Network.Socket (Socket)
import System.IO.Error


#if defined(WIN) && defined(GHC708)
import Network.Socket (send, recv)
import qualified Data.ByteString.Char8 as BS
#else
import Network.Socket.ByteString (sendAll, recv)
#endif

import Network.DNS.Decode (decode)
import Network.DNS.Encode (encode)
import Network.DNS.Imports
import Network.DNS.Types

----------------------------------------------------------------

-- | Receive one DNS message from a 'Socket' and decode it.
--
-- A single 'recv' of at most 'maxUdpSize' bytes is performed, so this is
-- sized for datagram (UDP) use; compare 'receiveVC'. An @IOException@
-- raised by the socket is rethrown as 'NetworkFailure'. If the received
-- bytes do not decode, the 'DNSError' reported by 'decode' is thrown.
receive :: Socket -> IO DNSMessage
receive sock = do
    raw <- recv sock (fromIntegral maxUdpSize) `E.catch` asNetworkFailure
    either E.throwIO return (decode raw)
  where
    asNetworkFailure err = E.throwIO (NetworkFailure err)

-- | Receive and decode one DNS query or response framed for a virtual
-- circuit (TCP).
--
-- The message must be preceded by a two-byte, most-significant-first
-- length field (the framing 'sendVC' produces). Decoding failures are
-- thrown as the 'DNSError' reported by 'decode'. No timeout is applied
-- here; it is up to the caller to implement any desired timeout.
receiveVC :: Socket -> IO DNSMessage
receiveVC sock = do
    prefix <- recvDNS sock 2
    body   <- recvDNS sock (msgLength prefix)
    either E.throwIO return (decode body)
  where
    -- Combine the two length bytes, high byte first.
    msgLength :: ByteString -> Int
    msgLength p = case BS.unpack p of
        [h, l] -> ord h * 256 + ord l
        _      -> 0 -- unreachable: recvDNS sock 2 yields exactly 2 bytes

-- | Read exactly @len@ bytes from a stream 'Socket', issuing as many
-- 'recv' calls as needed.
--
-- If the peer closes the connection before @len@ bytes arrive, an EOF
-- @IOError@ (\"connection terminated\") is raised. That error, like any
-- other @IOException@ from the socket, is rethrown as 'NetworkFailure'.
--
-- A non-positive @len@ returns the empty string without touching the
-- socket. Calling 'recv' with a zero length would otherwise fail (or look
-- like EOF), misreporting a zero-length frame as a network failure instead
-- of letting the caller's decoder reject it.
recvDNS :: Socket -> Int -> IO ByteString
recvDNS sock len
    | len <= 0  = return BS.empty
    | otherwise = go len [] `E.catch` \e -> E.throwIO $ NetworkFailure e
  where
    -- Keep receiving until no bytes remain. Chunks are accumulated in
    -- reverse and concatenated once at the end, instead of re-copying the
    -- growing prefix with 'BS.append' on every iteration.
    go left acc = do
        chunk <- recvCore left
        let acc'  = chunk : acc
            left' = left - BS.length chunk
        if left' <= 0
            then return (BS.concat (reverse acc'))
            else go left' acc'
    eofE = mkIOError eofErrorType "connection terminated" Nothing Nothing
    -- An empty result from 'recv' is treated as the peer closing the stream.
    recvCore len0 = do
        bs <- recv sock len0
        if BS.null bs
            then E.throwIO eofE
            else return bs

----------------------------------------------------------------

-- | Send an already-composed query or response over a 'Socket' unchanged,
-- with no length framing (compare 'sendVC'). The whole buffer is written
-- via 'sendAll'.
send :: Socket -> ByteString -> IO ()
send = sendAll

-- | Send a composed query or response over a single virtual circuit (TCP).
-- The message is first framed with its two-byte length by 'encodeVC', and
-- the whole framed buffer is written via 'sendAll'.
sendVC :: Socket -> ByteString -> IO ()
sendVC vc = sendAll vc . encodeVC

-- | Frame a message for a virtual circuit (TCP) by prepending its length
-- as a 16-bit big-endian unsigned integer.
--
-- The length is written with 'BB.word16BE' because it is never negative.
-- Lengths above 65535 cannot be represented in the two-byte field and are
-- truncated modulo 65536, which would corrupt the stream. Callers must not
-- pass larger messages.
encodeVC :: ByteString -> ByteString
encodeVC legacyQuery = lenPrefix <> legacyQuery
  where
    lenPrefix = LBS.toStrict . BB.toLazyByteString
              . BB.word16BE . fromIntegral $ BS.length legacyQuery

#if defined(WIN) && defined(GHC708)
-- | Send the whole buffer, retrying until every byte has been written.
--
-- Windows does not support sendAll in Network.Socket.ByteString for older
-- GHCs, so it is emulated here on top of the 'String'-based socket send.
-- That call is written qualified because this module also defines its own
-- 'send'; an unqualified @send@ would be an ambiguous occurrence.
--
-- NOTE(review): in this configuration 'recv' is also imported from
-- "Network.Socket" (presumably 'String'-based), and the export of @send@
-- may likewise be ambiguous. Confirm this branch still builds.
sendAll :: Socket -> BS.ByteString -> IO ()
sendAll sock bs = do
  sent <- Network.Socket.send sock (BS.unpack bs)
  when (sent < fromIntegral (BS.length bs)) $ sendAll sock (BS.drop (fromIntegral sent) bs)
#endif

----------------------------------------------------------------

-- | Encode a query message for the given questions.
--
-- The message is 'defaultQuery' with:
--
-- * its header identifier replaced by the given 'Identifier',
-- * its @authenData@ (AD) flag set to the given 'Bool',
-- * its question section replaced by the given questions, and
-- * its additional section replaced by the given records.
encodeQuestions :: Identifier
                -> [Question]
                -> [ResourceRecord] -- ^ Additional RRs for EDNS.
                -> Bool             -- ^ Authentication
                -> ByteString
encodeQuestions idt qs adds auth = encode $ defaultQuery
    { header     = newHeader
    , question   = qs
    , additional = adds
    }
  where
    baseHeader = header defaultQuery
    newHeader  = baseHeader
        { identifier = idt
        , flags      = (flags baseHeader) { authenData = auth }
        }

{-# DEPRECATED composeQuery "Use encodeQuestions instead" #-}
-- | Compose a query without EDNS0 (no additional records) and with the
-- authentic-data flag cleared. Equivalent to
-- @encodeQuestions idt qs [] False@.
composeQuery :: Identifier -> [Question] -> ByteString
composeQuery idt qs = encodeQuestions idt qs noAdditional False
  where
    noAdditional = []

{-# DEPRECATED composeQueryAD "Use encodeQuestions instead" #-}
-- | Compose a query without EDNS0 (no additional records) and with the
-- authentic-data flag set. Equivalent to
-- @encodeQuestions idt qs [] True@.
composeQueryAD :: Identifier -> [Question] -> ByteString
composeQueryAD idt qs = encodeQuestions idt qs noAdditional True
  where
    noAdditional = []

----------------------------------------------------------------

-- | Compose a response to question @q@ from IPv4 addresses.
--
-- Starts from 'defaultResponse'. The header identifier is set to @ident@,
-- @q@ is echoed as the sole question, and each address becomes one 'A'
-- record (class 'classIN', TTL 300) owned by @qname q@.
responseA :: Identifier -> Question -> [IPv4] -> DNSMessage
responseA ident q ips = defaultResponse
    { header   = (header defaultResponse) { identifier = ident }
    , question = [q]
    , answer   = map toRecord ips
    }
  where
    toRecord ip = ResourceRecord (qname q) A classIN 300 (RD_A ip)

-- | Compose a response to question @q@ from IPv6 addresses.
--
-- Starts from 'defaultResponse'. The header identifier is set to @ident@,
-- @q@ is echoed as the sole question, and each address becomes one 'AAAA'
-- record (class 'classIN', TTL 300) owned by @qname q@.
responseAAAA :: Identifier -> Question -> [IPv6] -> DNSMessage
responseAAAA ident q ips = defaultResponse
    { header   = (header defaultResponse) { identifier = ident }
    , question = [q]
    , answer   = map toRecord ips
    }
  where
    toRecord ip = ResourceRecord (qname q) AAAA classIN 300 (RD_AAAA ip)