-- |
-- Module      : Crypto.Cipher.RC4
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : Good
--
-- Simple implementation of the RC4 stream cipher.
-- http://en.wikipedia.org/wiki/RC4
--
-- Initial FFI implementation by Peter White <peter@janrain.com>
--
-- Reorganized and simplified to have an opaque context.
--
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.RC4
    ( initialize
    , combine
    , generate
    , State
    ) where

import           Data.Word
import           Foreign.Ptr
import           Crypto.Internal.ByteArray (ScrubbedBytes, ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B

import           Crypto.Internal.Compat
import           Crypto.Internal.Imports

-- | The encryption state for RC4.
--
-- An opaque wrapper around a 'ScrubbedBytes' buffer that holds the
-- context manipulated by the C routines below. The 'ByteArrayAccess' and
-- 'NFData' instances are inherited from the underlying buffer through
-- GeneralizedNewtypeDeriving. The constructor is not exported, so callers
-- can only obtain a 'State' via 'initialize', 'combine' or 'generate'.
newtype State = State ScrubbedBytes
    deriving (ByteArrayAccess, NFData)

-- | FFI binding to the C routine that initializes an RC4 context from a key.
--
-- The context is written into the buffer passed as the third argument;
-- 'initialize' hands it a freshly allocated 264-byte buffer.
-- Imported as an @unsafe@ call, so the C routine must not call back into
-- Haskell.
foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_init"
    c_rc4_init :: Ptr Word8 -- ^ The rc4 key
               -> Word32    -- ^ The key length, in bytes
               -> Ptr State -- ^ The context buffer to initialize
               -> IO ()

-- | FFI binding to the C routine that XORs the RC4 keystream with an input
-- buffer, writing the result to an output buffer.
--
-- The context argument is advanced in place; 'combine' therefore passes it
-- a copy of the previous context so the caller's 'State' is left untouched.
-- Imported as an @unsafe@ call, so the C routine must not call back into
-- Haskell.
foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_combine"
    c_rc4_combine :: Ptr State        -- ^ Pointer to the RC4 context (permutation), advanced in place
                  -> Ptr Word8      -- ^ Pointer to the clear text
                  -> Word32         -- ^ Length of the clear text, in bytes
                  -> Ptr Word8      -- ^ Output buffer (at least as long as the clear text)
                  -> IO ()

-- | RC4 context initialization.
--
-- Seed the context with an initial key. The key size needs to be
-- adequate, otherwise security takes a hit.
--
-- A fresh 264-byte context buffer is allocated and filled in by the C
-- routine @cryptonite_rc4_init@ from the key bytes and the key length.
--
-- NOTE(review): an empty key is passed straight through to C with a length
-- of 0; if the C key schedule indexes the key modulo its length this is
-- undefined behaviour — confirm against cryptonite_rc4.c and consider
-- rejecting empty keys.
initialize :: ByteArrayAccess key
           => key   -- ^ The key
           -> State -- ^ The RC4 context with the key mixed in
initialize key = unsafeDoIO $ do
    -- 'unsafeDoIO' is sound here: the buffer is freshly allocated and its
    -- contents depend only on the key, so the result is a pure function of it.
    st <- B.alloc 264 $ \stPtr ->
        B.withByteArray key $ \keyPtr ->
            c_rc4_init keyPtr (fromIntegral $ B.length key) (castPtr stPtr)
    return $ State st

-- | Generate the next @len@ bytes of the RC4 keystream without combining
-- them with any input.
--
-- Implemented by 'combine'-ing the context with @len@ zero bytes: XOR with
-- zero leaves the keystream bytes unchanged. Returns the advanced context
-- together with the keystream.
generate :: ByteArray ba => State -> Int -> (State, ba)
generate ctx len = combine ctx (B.zero len)

-- | RC4 xor combination of the rc4 stream with an input.
--
-- The previous context is copied before the C routine advances it, so the
-- 'State' passed in is never mutated and may be reused; the advanced copy
-- is returned alongside an output of the same length as the input.
--
-- The C routine takes its length as a 'Word32'. Inputs that do not fit are
-- fed to it in successive chunks over the same context, which keeps the
-- keystream contiguous. Narrowing the whole length at once would otherwise
-- silently process only part of the input and leave the tail of the
-- output buffer unwritten.
combine :: ByteArray ba
        => State               -- ^ rc4 context
        -> ba                  -- ^ input
        -> (State, ba)         -- ^ new rc4 context, and the output
combine (State prevSt) clearText = unsafeDoIO $
    B.allocRet len            $ \outPtr ->
    B.withByteArray clearText $ \clearPtr -> do
        st <- B.copy prevSt $ \stPtr ->
                  combineChunks (castPtr stPtr) clearPtr outPtr len
        return $! State st
  where
    len :: Int
    len = B.length clearText

    -- Largest chunk that fits both the C 'Word32' length and the host 'Int'
    -- (computed via 'Integer' so it is also correct where 'Int' is 32 bits).
    maxChunk :: Int
    maxChunk = fromInteger $
        min (toInteger (maxBound :: Int)) (toInteger (maxBound :: Word32))

    -- Run the C routine over @remaining@ bytes, at most 'maxChunk' per call,
    -- advancing the source and destination pointers after each call.
    combineChunks :: Ptr State -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    combineChunks ctx src dst remaining
        | remaining <= 0 = return ()
        | otherwise = do
            let n = min remaining maxChunk
            c_rc4_combine ctx src (fromIntegral n) dst
            combineChunks ctx (src `plusPtr` n) (dst `plusPtr` n) (remaining - n)