{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
module Data.Conduit.Network
    ( -- * Basic utilities
      sourceSocket
    , sinkSocket
      -- * Simple TCP server/client interface.
    , SN.AppData
    , appSource
    , appSink
    , SN.appSockAddr
    , SN.appLocalAddr
      -- ** Server
    , SN.ServerSettings
    , serverSettings
    , SN.runTCPServer
    , SN.runTCPServerWithHandle
    , forkTCPServer
    , runGeneralTCPServer
      -- ** Client
    , SN.ClientSettings
    , clientSettings
    , SN.runTCPClient
    , runGeneralTCPClient
      -- ** Getters
    , SN.getPort
    , SN.getHost
    , SN.getAfterBind
    , SN.getNeedLocalAddr
      -- ** Setters
    , SN.setPort
    , SN.setHost
    , SN.setAfterBind
    , SN.setNeedLocalAddr
      -- * Types
    , SN.HostPreference
    ) where

import Prelude
import Data.Conduit
import Network.Socket (Socket)
import Network.Socket.ByteString (sendAll)
import Data.ByteString (ByteString)
import qualified GHC.Conc as Conc (yield)
import qualified Data.ByteString as S
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad (unless)
import Control.Monad.Trans.Class (lift)
import Control.Concurrent (forkIO, newEmptyMVar, putMVar, takeMVar, MVar, ThreadId)
import Control.Concurrent (tryPutMVar)
import Control.Exception (SomeException, throwIO, try)
import qualified Data.Streaming.Network as SN
import Control.Monad.IO.Unlift (MonadUnliftIO, withRunInIO)

-- | Stream data from the socket.
--
-- Each iteration requests up to 4096 bytes with 'SN.safeRecv' and yields
-- the chunk downstream. An empty chunk is treated as end of stream and
-- terminates the source.
--
-- This function does /not/ automatically close the socket.
--
-- Since 0.0.0
sourceSocket :: MonadIO m => Socket -> ConduitT i ByteString m ()
sourceSocket socket =
    loop
  where
    loop = do
        bs <- lift $ liftIO $ SN.safeRecv socket 4096
        if S.null bs
            then return ()
            else yield bs >> loop

-- | Stream data to the socket.
--
-- Every chunk received from upstream is written in full with 'sendAll'.
-- The sink completes once upstream has no more input.
--
-- This function does /not/ automatically close the socket.
--
-- Since 0.0.0
sinkSocket :: MonadIO m => Socket -> ConduitT ByteString o m ()
sinkSocket socket =
    loop
  where
    loop = await >>= maybe (return ()) send
    send bs = lift (liftIO $ sendAll socket bs) >> loop

-- | Thin alias for 'SN.serverSettingsTCP': builds TCP server settings from
-- a port number and a host preference.
serverSettings :: Int -> SN.HostPreference -> SN.ServerSettings
serverSettings = SN.serverSettingsTCP

-- | Thin alias for 'SN.clientSettingsTCP': builds TCP client settings from
-- a port number and a host name.
clientSettings :: Int -> ByteString -> SN.ClientSettings
clientSettings = SN.clientSettingsTCP

-- | Stream data read from an application connection.
--
-- Repeatedly calls 'SN.appRead' and yields each chunk downstream. The
-- first empty chunk ends the stream. The connection is not closed here.
appSource :: (SN.HasReadWrite ad, MonadIO m) => ad -> ConduitT i ByteString m ()
appSource ad =
    loop
  where
    read' = SN.appRead ad
    loop = do
        bs <- liftIO read'
        unless (S.null bs) $ do
            yield bs
            loop

-- | Write every chunk received from upstream to an application connection
-- with 'SN.appWrite'.
--
-- After each write the current thread calls 'Conc.yield', giving other
-- Haskell threads a chance to run. The connection is not closed here.
appSink :: (SN.HasReadWrite ad, MonadIO m) => ad -> ConduitT ByteString o m ()
appSink ad = awaitForever $ \d -> liftIO $ SN.appWrite ad d >> Conc.yield

-- | Wrap the after-bind hook of the given settings.
--
-- The settings' existing after-bind hook runs first. Then 'Nothing' is put
-- into the supplied 'MVar' to announce that the listening socket has been
-- bound successfully. If the original hook throws, no signal is sent.
addBoundSignal :: MVar (Maybe SomeException) -> SN.ServerSettings -> SN.ServerSettings
addBoundSignal signal set = SN.setAfterBind afterBind set
  where
    originalAfterBind :: Socket -> IO ()
    originalAfterBind = SN.getAfterBind set

    afterBind :: Socket -> IO ()
    afterBind socket = do
        originalAfterBind socket
        putMVar signal Nothing

-- | Fork a TCP Server
--
-- Will fork the runGeneralTCPServer function but will only return from
-- this call when the server is bound to the port and accepting incoming
-- connections. Will return the thread id of the server.
--
-- If the server thread dies before the after-bind hook has signalled
-- success (for example because binding fails, or because a user-supplied
-- after-bind hook throws), that exception is re-thrown in the calling
-- thread. Previously the caller would block on an MVar that is never
-- filled.
--
-- Since 1.1.4
forkTCPServer
  :: MonadUnliftIO m
  => SN.ServerSettings
  -> (SN.AppData -> m ())
  -> m ThreadId
forkTCPServer set f =
    withRunInIO $ \run -> do
        signal <- newEmptyMVar
        let setWithWaitForBind = addBoundSignal signal set

            server :: IO ()
            server = run $ runGeneralTCPServer setWithWaitForBind f

            -- Only a failure needs reporting. tryPutMVar never blocks: if
            -- the bind was already signalled and consumed, the value is
            -- simply left in an MVar nobody reads any more.
            reportFailure :: Either SomeException () -> IO ()
            reportFailure (Left e) = tryPutMVar signal (Just e) >> return ()
            reportFailure (Right ()) = return ()

        -- Nothing else holds this ThreadId until we return it, so the
        -- thread cannot be interrupted before 'try' is in place.
        threadId <- forkIO $ try server >>= reportFailure
        outcome <- takeMVar signal
        case outcome of
            Nothing -> return threadId
            Just e -> throwIO e



-- | Run a general TCP server
--
-- Same as 'SN.runTCPServer', except monad can be any instance of
-- 'MonadUnliftIO'.
--
-- Each connection handler is run in 'IO' via the unlifting function from
-- 'withRunInIO'.
--
-- Note that any changes to the monadic state performed by individual
-- client handlers will be discarded. If you have mutable state you want
-- to share among multiple handlers, you need to use some kind of mutable
-- variables.
--
-- Since 1.1.3
runGeneralTCPServer
  :: MonadUnliftIO m
  => SN.ServerSettings
  -> (SN.AppData -> m ())
  -> m a
runGeneralTCPServer set f = withRunInIO $ \run ->
    SN.runTCPServer set $ run . f

-- | Run a general TCP client
--
-- Same as 'SN.runTCPClient', except monad can be any instance of 'MonadUnliftIO'.
--
-- The handler is run in 'IO' via the unlifting function from
-- 'withRunInIO', and its result is returned to the caller.
--
-- Since 1.1.3
runGeneralTCPClient
  :: MonadUnliftIO m
  => SN.ClientSettings
  -> (SN.AppData -> m a)
  -> m a
runGeneralTCPClient set f = withRunInIO $ \run ->
    SN.runTCPClient set $ run . f