{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE DerivingVia         #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

module Ouroboros.Network.Snocket
  ( -- * Snocket Interface
    Accept (..)
  , Accepted (..)
  , AddressFamily (..)
  , Snocket (..)
    -- ** Socket based Snockets
  , SocketSnocket
  , socketSnocket
    -- ** Local Snockets
    -- Using unix sockets (posix) or named pipes (windows)
    --
  , LocalSnocket
  , localSnocket
  , LocalSocket (..)
  , LocalAddress (..)
  , localAddressFromPath
  , TestAddress (..)
  , FileDescriptor (..)
  , socketFileDescriptor
  , localSocketFileDescriptor
  ) where

import           Control.Exception
import           Control.Monad (when)
import           Control.Monad.Class.MonadTime (DiffTime)
import           Control.Tracer (Tracer)
import           Data.Bifoldable (Bifoldable (..))
import           Data.Bifunctor (Bifunctor (..))
import           Data.Hashable
import           Data.Typeable (Typeable)
import           Data.Word
import           GHC.Generics (Generic)
import           Quiet (Quiet (..))
#if !defined(mingw32_HOST_OS)
import           Network.Socket (Family (AF_UNIX))
#endif
import           Network.Socket (SockAddr (..), Socket)
#if defined(mingw32_HOST_OS)
import           Data.Bits
import           Foreign.Ptr (IntPtr (..), ptrToIntPtr)
import qualified System.Win32 as Win32
import qualified System.Win32.Async as Win32.Async
import qualified System.Win32.NamedPipes as Win32

import           Network.Mux.Bearer.NamedPipe (namedPipeAsBearer)
#endif
import qualified Network.Socket as Socket

import qualified Network.Mux.Bearer.Socket as Mx
import           Network.Mux.Trace (MuxTrace)
import           Network.Mux.Types (MuxBearer)

import           Ouroboros.Network.IOManager
import           Ouroboros.Network.Linger (StructLinger (..))


-- | Named pipes and Berkeley sockets have different API when accepting
-- a connection.  For named pipes the file descriptor created by 'createNamedPipe' is
-- supposed to be used for the first connected client.  A named pipe accept
-- loop looks like this:
--
-- > acceptLoop k = do
-- >   h <- createNamedPipe name
-- >   connectNamedPipe h
-- >   -- h is now in connected state
-- >   forkIO (k h)
-- >   acceptLoop k
--
-- For Berkeley sockets, the equivalent loop starts by creating a socket
-- which accepts connections, and 'accept' returns a new socket in the
-- connected state:
--
-- > acceptLoop k = do
-- >     s <- socket ...
-- >     bind s address
-- >     listen s
-- >     loop s
-- >   where
-- >     loop s = do
-- >       (s' , _addr') <- accept s
-- >       -- s' is in connected state
-- >       forkIO (k s')
-- >       loop s
--
-- To provide a common API for both, we use the recursive type 'Accept', see
-- 'berkeleyAccept' below.  Creation of a socket / named pipe is part of
-- 'Snocket', which means we need a different recursion step for named pipes
-- and for sockets.  For sockets, the recursion step always calls the 'accept'
-- syscall; for named pipes, the first callback reuses the file descriptor
-- created by 'open' and only subsequent calls create a new file descriptor
-- with `createNamedPipe`, see 'localSnocket'.
--
newtype Accept m fd addr = Accept
  { runAccept :: m (Accepted fd addr, Accept m fd addr)
    -- ^ Run one accept step, returning its outcome together with the
    -- 'Accept' to use for the following step.
  }

-- | Maps over the file descriptor and the address of every connection the
-- loop produces: the current step and, recursively, all its continuations.
instance Functor m => Bifunctor (Accept m) where
    bimap f g = Accept . fmap mapStep . runAccept
      where
        -- transform this step's outcome and defer to the continuation's
        -- own 'bimap' for the rest of the loop
        mapStep (accepted, next) = (bimap f g accepted, bimap f g next)


-- | Outcome of a single accept step: either a connected file descriptor
-- together with its address, or the exception that made the attempt fail.
data Accepted fd addr where
    AcceptFailure :: !SomeException -> Accepted fd addr
    Accepted      :: !fd -> !addr -> Accepted fd addr

-- | Maps the file descriptor and the address; failures are passed through
-- unchanged.
instance Bifunctor Accepted where
    bimap _ _ (AcceptFailure err) = AcceptFailure err
    bimap f g (Accepted fd addr)  = Accepted (f fd) (g addr)

-- | Folds the file descriptor followed by the address; a failure contributes
-- nothing.
instance Bifoldable Accepted where
    bifoldMap _ _ AcceptFailure {}   = mempty
    bifoldMap f g (Accepted fd addr) = f fd <> g addr


-- | BSD accept loop.
--
-- Builds the 'Accept' continuation for a listening Berkeley socket.  Each
-- step accepts one connection ('Win32.Async.accept' on Windows), associates
-- it with the 'IOManager' and returns it together with the next step.
-- Non-asynchronous exceptions are reported as 'AcceptFailure' (the loop can
-- continue with the returned step); asynchronous ones are re-thrown.
--
-- Unix domain sockets do not give the remote end a unique address, so for
-- them the remote address is built from the local path and a counter that is
-- incremented after every successful accept.
--
berkeleyAccept :: IOManager
               -> Socket
               -> IO (Accept IO Socket SockAddr)
berkeleyAccept ioManager sock = do
    listeningAddr <- Socket.getSocketName sock
    return (step 0 listeningAddr)
  where
    -- One step of the loop.  @counter@ is the number of connections
    -- accepted so far; @listeningAddr@ is the local address of 'sock'.
    step :: Word64 -> SockAddr -> Accept IO Socket SockAddr
    step !counter !listeningAddr =
      Accept $ acceptOne listeningAddr counter
                 `catch` handleException listeningAddr counter

    acceptOne
      :: SockAddr
      -> Word64
      -> IO ( Accepted  Socket SockAddr
            , Accept IO Socket SockAddr
            )
    acceptOne listeningAddr counter =
      bracketOnError
#if !defined(mingw32_HOST_OS)
        (Socket.accept sock)
#else
        (Win32.Async.accept sock)
#endif
        -- release action: only runs if the body below throws
        (uninterruptibleMask_ . Socket.close . fst)
        $ \(conn, peerAddr) -> do
          associateWithIOManager ioManager (Right conn)

          -- UNIX sockets don't provide a unique endpoint for the remote
          -- side, but the InboundGovernor/Server requires one in order to
          -- track connections.  To tell clients apart we use a simple
          -- counter as the remote end's address.
          let remoteAddr = case listeningAddr of
                Socket.SockAddrUnix path ->
                  Socket.SockAddrUnix (path ++ "@" ++ show counter)
                _ -> peerAddr

          return ( Accepted conn remoteAddr
                 , step (succ counter) listeningAddr
                 )

    -- Only non-async exceptions are caught and put into the 'AcceptFailure'
    -- variant; the loop then continues with the same counter.
    handleException
      :: SockAddr
      -> Word64
      -> SomeException
      -> IO ( Accepted  Socket SockAddr
            , Accept IO Socket SockAddr
            )
    handleException listeningAddr counter err
      | Just (SomeAsyncException _) <- fromException err
      = throwIO err
      | otherwise
      = pure (AcceptFailure err, step counter listeningAddr)

-- | Local address: on Unix it is associated with the `Socket.AF_UNIX`
-- family, on Windows with `named-pipes`.
--
newtype LocalAddress = LocalAddress { getFilePath :: FilePath }
  -- ^ Wraps the file system path of the unix socket / named pipe.
  deriving stock (Eq, Ord, Generic)
  deriving Show via Quiet LocalAddress

-- | Hashes the wrapped 'FilePath'.
instance Hashable LocalAddress where
    hashWithSalt salt = hashWithSalt salt . getFilePath

-- | Address wrapper used by the 'TestFamily' address family, for testing
-- purposes.
newtype TestAddress addr = TestAddress { getTestAddress :: addr }
  deriving stock (Eq, Ord, Generic)
  deriving Typeable
  deriving Show via Quiet (TestAddress addr)

-- | We support either sockets or named pipes.
--
-- There are three families of addresses: 'SocketFamily' used for Berkeley
-- sockets, 'LocalFamily' used for 'LocalAddress'es (either Unix sockets or
-- Windows named pipe addresses), and 'TestFamily' for testing purposes.
--
-- 'LocalFamily' requires a 'LocalAddress'; this is needed to provide the
-- path of the opened Win32 'HANDLE'.
--
data AddressFamily addr where

    -- | A Berkeley socket family, e.g. 'Socket.AF_INET'.
    SocketFamily :: !Socket.Family -> AddressFamily Socket.SockAddr

    -- | A local (unix socket / named pipe) family, carrying the full address.
    LocalFamily  :: !LocalAddress -> AddressFamily LocalAddress

    -- | Using a newtype wrapper 'TestAddress' makes pattern matches on
    -- @AddressFamily@ complete, i.e. it makes 'AddressFamily' injective:
    -- if @AddressFamily addr == AddressFamily addr'@ then @addr == addr'@.
    --
    TestFamily   :: AddressFamily (TestAddress addr)

deriving stock instance Eq   addr => Eq   (AddressFamily addr)
deriving stock instance Show addr => Show (AddressFamily addr)


-- | Abstract communication interface that can be backed by more than just a
-- 'Socket'.  A 'Snocket' is polymorphic in the monad it runs in, which is
-- useful for testing and/or simulations.
--
data Snocket m fd addr = Snocket {
    getLocalAddr  :: fd -> m addr
    -- ^ Address of the local end of a file descriptor.
  , getRemoteAddr :: fd -> m addr
    -- ^ Address of the remote end of a file descriptor.

  , addrFamily    :: addr -> AddressFamily addr
    -- ^ Classify an address into its 'AddressFamily'.

  , open          :: AddressFamily addr -> m fd
    -- ^ Open a file descriptor (socket / named pipe).  For named pipes this
    -- uses the 'CreateNamedPipe' syscall, for Berkeley sockets 'socket' is
    -- used.

  , openToConnect :: addr -> m fd
    -- ^ Create an @fd@ to pass to 'connect'.  For named pipes it uses the
    -- 'CreateFile' syscall; for Berkeley sockets it is the same as 'open'.
    --
    -- Named pipes need the full @addr@ rather than just the address family,
    -- which is all that sockets need.

  , connect       :: fd -> addr -> m ()
    -- ^ Only needed for Berkeley sockets; for named pipes this is a no-op.
  , bind          :: fd -> addr -> m ()
    -- ^ Bind @fd@ to a local address.
  , listen        :: fd -> m ()
    -- ^ Put @fd@ into listening mode.

    -- SomeException is chosen here to avoid having to include it in the
    -- Snocket type, and therefore refactoring a bunch of stuff.
    -- FIXME probably a good idea to abstract it.
  , accept        :: fd -> m (Accept m fd addr)
    -- ^ Start the accept loop on a listening @fd@, see 'Accept'.

  , close         :: fd -> m ()
    -- ^ Close a file descriptor.

  , toBearer      :: DiffTime -> Tracer m MuxTrace -> fd -> m (MuxBearer m)
    -- ^ Turn a connected @fd@ into a 'MuxBearer'.
  }


-- | Lift a pure bearer constructor into the monadic shape of 'toBearer'.
pureBearer :: Monad m
           => (DiffTime -> Tracer m MuxTrace -> fd ->    MuxBearer m)
           ->  DiffTime -> Tracer m MuxTrace -> fd -> m (MuxBearer m)
pureBearer mkBearer timeout tracer = return . mkBearer timeout tracer

--
-- Socket based Snockets
--


-- | The Berkeley socket family of an address: 'Socket.AF_INET',
-- 'Socket.AF_INET6' or 'Socket.AF_UNIX'.
socketAddrFamily
    :: Socket.SockAddr
    -> AddressFamily Socket.SockAddr
socketAddrFamily sockAddr = SocketFamily $ case sockAddr of
    Socket.SockAddrInet  {} -> Socket.AF_INET
    Socket.SockAddrInet6 {} -> Socket.AF_INET6
    Socket.SockAddrUnix  {} -> Socket.AF_UNIX


type SocketSnocket = Snocket IO Socket SockAddr


-- | Create a 'Snocket' backed by Berkeley sockets.
--
-- Before binding an @AF_INET@ / @AF_INET6@ socket, the 'bind' method sets
-- 'Socket.ReuseAddr', 'Socket.ReusePort' (not on Windows), 'Socket.NoDelay'
-- and 'Socket.Linger' (on, with a zero timeout).  @AF_INET6@ sockets
-- additionally get 'Socket.IPv6Only'.
--
socketSnocket
  :: IOManager
  -- ^ 'IOManager' interface.  Every socket this snocket creates ('open',
  -- 'openToConnect') or accepts ('accept') is associated with it.
  --
  -> SocketSnocket
socketSnocket ioManager = Snocket {
      getLocalAddr   = Socket.getSocketName
    , getRemoteAddr  = Socket.getPeerName
    , addrFamily     = socketAddrFamily
    , open           = openSocket
    , openToConnect  = \addr -> openSocket (socketAddrFamily addr)
    , connect        = \s a -> do
#if !defined(mingw32_HOST_OS)
        Socket.connect s a
#else
        Win32.Async.connect s a
#endif
    , bind = \sd addr -> do
        let SocketFamily fml = socketAddrFamily addr
        when (fml == Socket.AF_INET ||
              fml == Socket.AF_INET6) $ do
          Socket.setSocketOption sd Socket.ReuseAddr 1
#if !defined(mingw32_HOST_OS)
          -- not supported on Windows 10
          Socket.setSocketOption sd Socket.ReusePort 1
#endif
          Socket.setSocketOption sd Socket.NoDelay 1
          -- It is safe to set the 'SO_LINGER' option (which implies that
          -- every close will reset the connection), since our protocols are
          -- robust.  In particular, if invalid data arrives (which includes
          -- the rare case of a late packet from a previous connection), we
          -- will abandon (and close) the connection.
          Socket.setSockOpt sd Socket.Linger
                              (StructLinger { sl_onoff  = 1,
                                              sl_linger = 0 })
        when (fml == Socket.AF_INET6)
          -- An AF_INET6 socket can be used to talk to both IPv4 and IPv6 end
          -- points, and it is enabled by default on some systems.  Disabled
          -- here since we run a separate IPv4 server instance if configured
          -- to use IPv4.
          $ Socket.setSocketOption sd Socket.IPv6Only 1

        Socket.bind sd addr
      -- backlog of 8 pending connections
    , listen   = \s -> Socket.listen s 8
    , accept   = berkeleyAccept ioManager
      -- TODO: 'Socket.close' is interruptible by asynchronous exceptions; it
      -- should be fixed upstream, once that's done we can remove
      -- `uninterruptibleMask_'
    , close    = uninterruptibleMask_ . Socket.close
    , toBearer = pureBearer Mx.socketAsMuxBearer
    }
  where
    -- Create a stream socket of the given family and associate it with the
    -- 'IOManager'.  If the association throws for any reason (synchronous or
    -- asynchronous), the socket is closed before the exception is re-thrown,
    -- so it is not leaked.
    openSocket :: AddressFamily SockAddr -> IO Socket
    openSocket (SocketFamily family_) = do
      sd <- Socket.socket family_ Socket.Stream Socket.defaultProtocol
      -- 'open' is designed to be used in `bracket`, and thus it's called with
      -- async exceptions masked.  'associateWithIOManager' is a blocking
      -- operation, so it may fail or be interrupted.
      --
      -- 'onException' covers every exception type.  The previous pair of
      -- chained `catch` handlers nested the async handler inside the
      -- IOException handler (backtick operators default to infixl 9, and a
      -- lambda extends as far right as possible).  As a result, it neither
      -- caught async exceptions thrown by the association itself nor
      -- re-threw the async exception it did catch.  'Socket.close' is
      -- interruptible (see the 'close' field above), hence
      -- 'uninterruptibleMask_'.
      associateWithIOManager ioManager (Right sd)
        `onException` uninterruptibleMask_ (Socket.close sd)
      return sd



--
-- LocalSnockets either based on unix sockets or named pipes.
--

#if defined(mingw32_HOST_OS)
type LocalHandle = Win32.HANDLE
#else
type LocalHandle = Socket
#endif

-- | System dependent 'LocalSocket' type.
--
-- On /Windows/ it bundles the named pipe 'Win32.HANDLE' with the path it was
-- created from and a remote identifier; on other systems it wraps a 'Socket'.
--
#if defined(mingw32_HOST_OS)
data LocalSocket = LocalSocket
    { getLocalHandle :: !LocalHandle
      -- ^ underlying windows 'HANDLE'
    , getLocalPath   :: !LocalAddress
      -- ^ original path, used when creating the handle
    , getRemotePath  :: !LocalAddress
      -- ^ unique identifier (not a real path).  It makes the pair of local
      -- and remote addresses unique.
    }
    deriving stock (Eq, Generic)
    deriving Show via Quiet LocalSocket
#else
newtype LocalSocket = LocalSocket { getLocalHandle :: LocalHandle }
    deriving stock (Eq, Generic)
    deriving Show via Quiet LocalSocket
#endif

-- | System dependent 'Snocket' over 'LocalSocket's and 'LocalAddress'es.
type LocalSnocket = Snocket IO LocalSocket LocalAddress


-- | Create a 'LocalSnocket'.
--
-- On /Windows/, there is no way to get the path associated with a named pipe.
-- To work around this, the address passed to 'open' via 'LocalFamily' is
-- stored in the 'LocalSocket' (and returned by 'getLocalAddr').
--
localSnocket :: IOManager -> LocalSnocket
#if defined(mingw32_HOST_OS)
localSnocket ioManager = Snocket {
      getLocalAddr  = return . getLocalPath
    , getRemoteAddr = return . getRemotePath
    , addrFamily    = LocalFamily

      -- 'open' is normally called from within 'bracket' (async exceptions
      -- masked), but 'associateWithIOManager' is a blocking operation and may
      -- still throw, including asynchronous exceptions delivered while it
      -- blocks.  'bracketOnError' closes the freshly created pipe instance on
      -- /any/ exception, so the handle cannot leak.
    , open = \(LocalFamily addr) ->
        bracketOnError (newPipeInstance addr) Win32.closeHandle $ \hpipe -> do
          associateWithIOManager ioManager (Left hpipe)
          pure (LocalSocket hpipe addr (LocalAddress ""))

      -- To connect, simply open the existing named pipe by its name (as one
      -- would open a file).  As in 'open', 'bracketOnError' closes the handle
      -- if registering it with the I/O manager throws.
    , openToConnect = \(LocalAddress pipeName) ->
        bracketOnError
          (Win32.connect pipeName
                         (Win32.gENERIC_READ .|. Win32.gENERIC_WRITE)
                         Win32.fILE_SHARE_NONE
                         Nothing
                         Win32.oPEN_EXISTING
                         Win32.fILE_FLAG_OVERLAPPED
                         Nothing)
          Win32.closeHandle
          $ \hpipe -> do
            associateWithIOManager ioManager (Left hpipe)
            return (LocalSocket hpipe (LocalAddress pipeName) (LocalAddress pipeName))
    , connect  = \_ _ -> pure ()

    -- Bind and listen are no-op.
    , bind     = \_ _ -> pure ()
    , listen   = \_ -> pure ()

      -- The first accept waits for a client on the pipe instance created by
      -- 'open' and hands back that very socket, with the pipe path as the
      -- remote address.  Every further accept goes through 'acceptNext',
      -- which creates a new pipe instance per connection.
    , accept   = \sock@(LocalSocket hpipe addr _) -> pure $ Accept $ do
          Win32.Async.connectNamedPipe hpipe
          return (Accepted sock addr, acceptNext 0 addr)

      -- Win32.closeHandle is not interruptible
    , close    = Win32.closeHandle . getLocalHandle

    , toBearer = \_sduTimeout tr -> pure . namedPipeAsBearer tr . getLocalHandle
    }
  where
    -- Create a new instance of the named pipe at the given address:
    -- duplex, overlapped, byte mode, unlimited instances, 64KiB outbound and
    -- 16KiB inbound buffers, default timeout and default security.
    newPipeInstance :: LocalAddress -> IO Win32.HANDLE
    newPipeInstance addr =
      Win32.createNamedPipe
        (getFilePath addr)
        (Win32.pIPE_ACCESS_DUPLEX .|. Win32.fILE_FLAG_OVERLAPPED)
        (Win32.pIPE_TYPE_BYTE     .|. Win32.pIPE_READMODE_BYTE)
        Win32.pIPE_UNLIMITED_INSTANCES
        65536   -- outbound pipe size
        16384   -- inbound pipe size
        0       -- default timeout
        Nothing -- default security

    -- Accept loop used after the first connection.  Each step creates a new
    -- pipe instance at @addr@ and waits for a client on it.  @cnt@ numbers the
    -- connections so that each one gets a distinct remote address.  An
    -- 'IOException' is reported as 'AcceptFailure' and the loop continues with
    -- the next counter value.
    acceptNext :: Word64 -> LocalAddress -> Accept IO LocalSocket LocalAddress
    acceptNext !cnt addr = Accept (acceptOne `catch` handleIOException)
      where
        handleIOException
          :: IOException
          -> IO ( Accepted  LocalSocket LocalAddress
                , Accept IO LocalSocket LocalAddress
                )
        handleIOException err =
          pure ( AcceptFailure (toException err)
               , acceptNext (succ cnt) addr
               )

        acceptOne
          :: IO ( Accepted  LocalSocket LocalAddress
                , Accept IO LocalSocket LocalAddress
                )
        acceptOne =
          bracketOnError (newPipeInstance addr) Win32.closeHandle $ \hpipe -> do
            associateWithIOManager ioManager (Left hpipe)
            Win32.Async.connectNamedPipe hpipe
            -- InboundGovernor/Server requires a unique address for the
            -- remote end in order to track connections.  So to
            -- differentiate clients we use a simple counter as the remote
            -- end's address.
            --
            let addr' = LocalAddress $ getFilePath addr ++ "@" ++ show cnt
            return (Accepted (LocalSocket hpipe addr addr') addr', acceptNext (succ cnt) addr)

-- local snocket on unix
#else

localSnocket ioManager =
    Snocket {
        getLocalAddr  = fmap toLocalAddress . Socket.getSocketName . getLocalHandle
      , getRemoteAddr = fmap toLocalAddress . Socket.getPeerName . getLocalHandle
      , addrFamily    = LocalFamily
      , connect       = \(LocalSocket s) addr ->
          Socket.connect s (fromLocalAddress addr)
      , bind          = \(LocalSocket fd) addr -> Socket.bind fd (fromLocalAddress addr)
        -- listen backlog of 8 pending connections
      , listen        = flip Socket.listen 8 . getLocalHandle
      , accept        = fmap (bimap LocalSocket toLocalAddress)
                      . berkeleyAccept ioManager
                      . getLocalHandle
      , open          = openSocket
      , openToConnect = \addr -> openSocket (LocalFamily addr)
        -- closing runs with async exceptions fully masked so it cannot be
        -- interrupted
      , close         = uninterruptibleMask_ . Socket.close . getLocalHandle
      , toBearer      = \df tr (LocalSocket sd) -> pure (Mx.socketAsMuxBearer df tr sd)
      }
  where
    -- Only unix-domain addresses can come out of a 'LocalSocket'.
    toLocalAddress :: SockAddr -> LocalAddress
    toLocalAddress (SockAddrUnix path) = LocalAddress path
    toLocalAddress _                   = error "localSnocket.toLocalAddr: impossible happened"

    fromLocalAddress :: LocalAddress -> SockAddr
    fromLocalAddress = SockAddrUnix . getFilePath

    -- Create a unix-domain stream socket and register it with the I/O
    -- manager.  The address is ignored here; it is used by 'bind'/'connect'.
    --
    -- 'open' is designed to be used in 'bracket', and thus it's called with
    -- async exceptions masked.  'associateWithIOManager' is a blocking
    -- operation and thus it may throw, including asynchronous exceptions
    -- delivered while it blocks.  'bracketOnError' closes the socket on any
    -- exception, so the descriptor cannot leak.
    openSocket :: AddressFamily LocalAddress -> IO LocalSocket
    openSocket (LocalFamily _addr) =
      bracketOnError
        (Socket.socket AF_UNIX Socket.Stream Socket.defaultProtocol)
        Socket.close
        $ \sd -> do
          associateWithIOManager ioManager (Right sd)
          return (LocalSocket sd)
#endif

-- | Build a 'LocalAddress' from a path: a unix socket path on POSIX systems,
-- or a named pipe name on /Windows/.
localAddressFromPath :: FilePath -> LocalAddress
localAddressFromPath = LocalAddress

-- | Socket file descriptor.
--
-- For unix sockets this is the descriptor number of the socket; on /Windows/
-- it is the integer value of the named-pipe handle (see
-- 'localSocketFileDescriptor').
--
newtype FileDescriptor = FileDescriptor { getFileDescriptor :: Int }
  deriving Generic
  -- 'Quiet' renders the value without the record field name.
  deriving Show via Quiet FileDescriptor

-- | Get the file descriptor of a 'Socket' via 'Socket.unsafeFdSocket', wrapped
-- in 'FileDescriptor'.
--
-- The wrapper is meant to discourage passing the result straight to low-level
-- functions which operate on raw file descriptors.
-- NOTE(review): the module export list exports @FileDescriptor (..)@, i.e. the
-- constructor too, so this is a convention rather than an enforced
-- restriction.  Per the @network@ docs, the descriptor is only meaningful
-- while the 'Socket' is still open — confirm callers keep it alive.
--
socketFileDescriptor :: Socket -> IO FileDescriptor
socketFileDescriptor = fmap (FileDescriptor . fromIntegral) . Socket.unsafeFdSocket

-- | File descriptor of a 'LocalSocket'.
--
-- On /Windows/ this is the integer value of the named-pipe handle; elsewhere
-- it is the descriptor of the underlying unix socket (see
-- 'socketFileDescriptor').
--
localSocketFileDescriptor :: LocalSocket -> IO FileDescriptor
#if defined(mingw32_HOST_OS)
localSocketFileDescriptor (LocalSocket fd _ _) =
  case ptrToIntPtr fd of
    IntPtr i -> return (FileDescriptor i)
#else
localSocketFileDescriptor = socketFileDescriptor . getLocalHandle
#endif