-- |
-- Module      : Network.TLS.Handshake.Certificate
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Handshake.Certificate
    ( certificateRejected
    , badCertificate
    , rejectOnException
    , verifyLeafKeyUsage
    , extractCAname
    ) where

import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.X509
import Control.Monad (unless)
import Control.Monad.State.Strict
import Control.Exception (SomeException)
import Data.X509 (ExtKeyUsage(..), ExtKeyUsageFlag, extensionGet)

-- | Abort because a certificate was rejected.
--
-- Each 'CertificateRejectReason' is mapped to a protocol error message and the
-- matching TLS alert, and the resulting 'Error_Protocol' is thrown via
-- 'throwCore'. The result type is fully polymorphic, so this never returns
-- normally.
--
-- The 'Bool' component of 'Error_Protocol' is 'True' for every reason
-- (NOTE(review): presumably the "fatal" flag — confirm against the
-- 'TLSError' definition).
certificateRejected :: MonadIO m => CertificateRejectReason -> m a
certificateRejected CertificateRejectRevoked =
    throwCore $ Error_Protocol ("certificate is revoked", True, CertificateRevoked)
certificateRejected CertificateRejectExpired =
    throwCore $ Error_Protocol ("certificate has expired", True, CertificateExpired)
certificateRejected CertificateRejectUnknownCA =
    throwCore $ Error_Protocol ("certificate has unknown CA", True, UnknownCa)
certificateRejected CertificateRejectAbsent =
    throwCore $ Error_Protocol ("certificate is missing", True, CertificateRequired)
certificateRejected (CertificateRejectOther s) =
    -- Free-form reasons have no dedicated alert, so fall back to CertificateUnknown.
    throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown)

-- | Abort with a protocol error that carries the given message and the
-- 'BadCertificate' alert. The error is thrown via 'throwCore', so this never
-- returns normally.
badCertificate :: MonadIO m => String -> m a
badCertificate msg = throwCore $ Error_Protocol (msg, True, BadCertificate)

-- | Convert any exception into a certificate rejection instead of letting it
-- propagate.
--
-- The result is a 'CertificateUsageReject' whose reason is
-- 'CertificateRejectOther' holding the 'show'n exception text. The type
-- @SomeException -> IO a@ matches the shape of a 'Control.Exception.catch'
-- handler.
rejectOnException :: SomeException -> IO CertificateUsage
rejectOnException e =
    return $ CertificateUsageReject $ CertificateRejectOther $ show e

-- | Check that the leaf (first) certificate of a chain permits at least one of
-- the given key-usage flags.
--
-- * A certificate without an 'ExtKeyUsage' extension is treated as
--   unrestricted and always passes.
-- * If the extension is present, at least one of its flags must appear in
--   @validFlags@; otherwise the handshake is aborted via 'badCertificate'.
-- * An empty chain passes without any check (NOTE(review): presumably an
--   absent certificate is rejected elsewhere — confirm with callers).
verifyLeafKeyUsage :: MonadIO m => [ExtKeyUsageFlag] -> CertificateChain -> m ()
verifyLeafKeyUsage _ (CertificateChain []) = return ()
verifyLeafKeyUsage validFlags (CertificateChain (signed:_)) =
    unless verified $
        badCertificate $ "certificate is not allowed for any of " ++ show validFlags
  where
    cert = getCertificate signed
    verified =
        case extensionGet (certExtensions cert) of
            Nothing                  -> True -- unrestricted cert
            Just (ExtKeyUsage flags) -> any (`elem` validFlags) flags

-- | Subject distinguished name of a signed certificate. When applied to a CA
-- certificate, this is the CA's own name (presumably used to build lists of
-- acceptable CA names — confirm with callers).
extractCAname :: SignedCertificate -> DistinguishedName
extractCAname = certSubjectDN . getCertificate