{-# LANGUAGE CPP                   #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}

module Servant.API.WebSocket where

import Control.Monad                              (void, (>=>))
import Control.Monad.IO.Class                     (liftIO)
import Data.Proxy                                 (Proxy (..))
import Data.String                                (fromString)

import Control.Monad.Trans.Resource               (runResourceT)
import Network.Wai.Handler.WebSockets             (websocketsOr)
import Network.WebSockets                         (Connection, PendingConnection, acceptRequest, defaultConnectionOptions)
import Servant.Server                             (HasServer (..), ServerError (..), ServerT, runHandler)
import Servant.Server.Internal.Delayed            (runDelayed)
import Servant.Server.Internal.RouteResult        (RouteResult (..))
import Servant.Server.Internal.Router             (leafRouter)

-- | Endpoint for defining a route that serves a websocket. The
-- handler function receives a websocket 'Connection' whose opening
-- handshake has already been accepted, and uses it to send and
-- receive data.
--
-- Requests to this route that are not websocket handshakes are
-- rejected with HTTP status 426 (Upgrade Required).
--
-- A 'ServerError' returned by the handler (e.g. via 'throwError' in
-- 'Handler') is discarded: the handshake has already been accepted by
-- the time the handler runs, so there is no HTTP response left to
-- report it on.
--
-- 'WebSocket' has no constructors; it is only used at the type level
-- in API descriptions.
--
-- Example:
--
-- > type WebSocketApi = "stream" :> WebSocket
-- >
-- > server :: Server WebSocketApi
-- > server = streamData
-- >  where
-- >   streamData :: MonadIO m => Connection -> m ()
-- >   streamData c = do
-- >     liftIO $ forkPingThread c 10
-- >     liftIO . forM_ [1..] $ \i -> do
-- >        sendTextData c (pack $ show (i :: Int)) >> threadDelay 1000000
data WebSocket

instance HasServer WebSocket ctx where

  -- The server for a 'WebSocket' endpoint is a function that receives
  -- the already-accepted websocket 'Connection'.
  type ServerT WebSocket m = Connection -> m ()

#if MIN_VERSION_servant_server(0,12,0)
  -- The handler is a plain function into @m@, so hoisting it only needs
  -- to post-compose the natural transformation.
  hoistServerWithContext _ _ nat svr = nat . svr
#endif

  -- Run the endpoint's delayed checks. On success the request is served
  -- by 'websocketsOr'; a routing failure is passed on unchanged.
  route Proxy _ app = leafRouter $ \env request respond -> runResourceT $
    runDelayed app env request >>= liftIO . go request respond
   where
    -- Websocket handshakes are handed to 'runApp'; any other request
    -- falls through to 'backupApp'.
    go request respond (Route app') =
      websocketsOr defaultConnectionOptions (runApp app') (backupApp respond) request (respond . Route)
    go _ respond (Fail e)      = respond $ Fail e
    go _ respond (FailFatal e) = respond $ FailFatal e

    -- Accept the handshake first, so the handler always receives an open
    -- 'Connection'. The handler's 'Either ServerError' result is dropped:
    -- once the handshake is accepted there is no HTTP response left to
    -- carry an error.
    runApp a = acceptRequest >=> \c -> void (runHandler $ a c)

    -- Fallback for plain HTTP requests that reach this route.
    backupApp respond _ _ = respond $ Fail upgradeRequired

    -- 426 Upgrade Required. RFC 7231 section 6.5.15 requires this status
    -- to carry an @Upgrade@ header naming the protocol to switch to.
    -- RFC 7230 section 6.7 requires a sender of @Upgrade@ to also list it
    -- in @Connection@. 'fromString' builds the header name/value without
    -- relying on OverloadedStrings.
    upgradeRequired = ServerError
      { errHTTPCode     = 426
      , errReasonPhrase = "Upgrade Required"
      , errBody         = mempty
      , errHeaders      = [ (fromString "Upgrade", fromString "websocket")
                          , (fromString "Connection", fromString "Upgrade")
                          ]
      }


-- | Endpoint for defining a route that serves a websocket, where the
-- handler function receives the raw 'PendingConnection' rather than
-- an accepted 'Connection'. The handler is responsible for the
-- handshake: it must either 'rejectRequest' or 'acceptRequest'. This
-- endpoint exists for extra flexibility, e.g. to reject connections.
--
-- Requests to this route that are not websocket handshakes are
-- rejected with HTTP status 426 (Upgrade Required).
--
-- A 'ServerError' returned by the handler (e.g. via 'throwError' in
-- 'Handler') is discarded rather than turned into an HTTP response.
--
-- 'WebSocketPending' has no constructors; it is only used at the type
-- level in API descriptions.
--
-- Example:
--
-- > type WebSocketApi = "stream" :> WebSocketPending
-- >
-- > server :: Server WebSocketApi
-- > server = streamData
-- >  where
-- >   streamData :: MonadIO m => PendingConnection -> m ()
-- >   streamData pc = do
-- >      c <- liftIO $ acceptRequest pc
-- >      liftIO $ forkPingThread c 10
-- >      liftIO . forM_ [1..] $ \i ->
-- >        sendTextData c (pack $ show (i :: Int)) >> threadDelay 1000000
data WebSocketPending

instance HasServer WebSocketPending ctx where

  -- The server for a 'WebSocketPending' endpoint is a function that
  -- receives the not-yet-accepted 'PendingConnection'.
  type ServerT WebSocketPending m = PendingConnection -> m ()

#if MIN_VERSION_servant_server(0,12,0)
  -- The handler is a plain function into @m@, so hoisting it only needs
  -- to post-compose the natural transformation.
  hoistServerWithContext _ _ nat svr = nat . svr
#endif

  -- Run the endpoint's delayed checks. On success the request is served
  -- by 'websocketsOr'; a routing failure is passed on unchanged.
  route Proxy _ app = leafRouter $ \env request respond -> runResourceT $
    runDelayed app env request >>= liftIO . go request respond
   where
    -- Websocket handshakes are handed to 'runApp'; any other request
    -- falls through to 'backupApp'.
    go request respond (Route app') =
      websocketsOr defaultConnectionOptions (runApp app') (backupApp respond) request (respond . Route)
    go _ respond (Fail e)      = respond $ Fail e
    go _ respond (FailFatal e) = respond $ FailFatal e

    -- The pending connection goes to the handler untouched; accepting or
    -- rejecting it is the handler's job. The handler's
    -- 'Either ServerError' result is dropped.
    runApp a c = void (runHandler $ a c)

    -- Fallback for plain HTTP requests that reach this route.
    backupApp respond _ _ = respond $ Fail upgradeRequired

    -- 426 Upgrade Required. RFC 7231 section 6.5.15 requires this status
    -- to carry an @Upgrade@ header naming the protocol to switch to.
    -- RFC 7230 section 6.7 requires a sender of @Upgrade@ to also list it
    -- in @Connection@. 'fromString' builds the header name/value without
    -- relying on OverloadedStrings.
    upgradeRequired = ServerError
      { errHTTPCode     = 426
      , errReasonPhrase = "Upgrade Required"
      , errBody         = mempty
      , errHeaders      = [ (fromString "Upgrade", fromString "websocket")
                          , (fromString "Connection", fromString "Upgrade")
                          ]
      }