{-# LANGUAGE RankNTypes #-}
module Data.Conduit.Network.UDP
    ( -- * UDP message representation
      SN.Message (..)
      -- * Basic utilities
    , sourceSocket
    , sinkSocket
    , sinkAllSocket
    , sinkToSocket
    , sinkAllToSocket
      -- * Helper Utilities
    , SN.HostPreference
    ) where

import Data.Conduit
import Network.Socket (Socket)
import Network.Socket.ByteString (recvFrom, send, sendAll, sendTo, sendAllTo)
import Data.ByteString (ByteString)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad (void)
import Control.Monad.Trans.Class (lift)
import qualified Data.Streaming.Network as SN

-- | Stream messages from the socket.
--
-- The given @len@ defines the maximum packet size passed to 'recvFrom'.
-- Every produced item contains the message payload and the origin address.
--
-- The internal loop never returns on its own: it keeps receiving and
-- yielding. Any exception thrown by 'recvFrom' propagates to the caller.
--
-- This function does /not/ automatically close the socket.
sourceSocket :: MonadIO m => Socket -> Int -> ConduitT i SN.Message m ()
sourceSocket socket len = loop
  where
    -- Receive one datagram, emit it downstream, then repeat.
    loop = do
        (bs, addr) <- lift $ liftIO $ recvFrom socket len
        yield (SN.Message bs addr)
        loop

-- | Stream messages to the connected socket.
--
-- Each payload is sent with a single call to @send@, so some of it might be
-- lost. The byte count returned by @send@ is discarded, so a partial send is
-- not reported.
--
-- This function does /not/ automatically close the socket.
sinkSocket :: MonadIO m => Socket -> ConduitT ByteString o m ()
sinkSocket = sinkSocketHelper (\sock bs -> void $ send sock bs)

-- | Stream messages to the connected socket.
--
-- Each payload is sent with @sendAll@, which keeps sending until every byte
-- has been written, so a payload might end up in multiple packets.
--
-- This function does /not/ automatically close the socket.
sinkAllSocket :: MonadIO m => Socket -> ConduitT ByteString o m ()
sinkAllSocket = sinkSocketHelper sendAll

-- | Stream messages to the socket.
--
-- Every handled item contains the message payload and the destination
-- address. Each payload is sent with a single call to @sendTo@, so some of
-- it might be lost. The byte count returned by @sendTo@ is discarded, so a
-- partial send is not reported.
--
-- This function does /not/ automatically close the socket.
sinkToSocket :: MonadIO m => Socket -> ConduitT SN.Message o m ()
sinkToSocket = sinkSocketHelper (\sock (SN.Message bs addr) -> void $ sendTo sock bs addr)

-- | Stream messages to the socket.
--
-- Every handled item contains the message payload and the destination
-- address. Each payload is sent with @sendAllTo@, which keeps sending until
-- every byte has been written, so a payload might end up in multiple packets.
--
-- This function does /not/ automatically close the socket.
sinkAllToSocket :: MonadIO m => Socket -> ConduitT SN.Message o m ()
sinkAllToSocket = sinkSocketHelper (\sock (SN.Message bs addr) -> sendAllTo sock bs addr)

-- Internal

-- | Shared driver for the exported sinks.
--
-- Awaits every item from upstream and runs @act socket item@ in 'IO' for
-- each one. It finishes with @()@ once upstream is exhausted.
sinkSocketHelper :: MonadIO m => (Socket -> a -> IO ())
                              -> Socket
                              -> ConduitT a o m ()
sinkSocketHelper act socket =
    -- 'awaitForever' is the library form of
    -- @await >>= maybe (return ()) (\a -> ... >> loop)@.
    awaitForever $ \item -> lift (liftIO (act socket item))
{-# INLINE sinkSocketHelper #-}