-- |
-- Module      : Data.ASN1.BitArray
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
{-# LANGUAGE DeriveDataTypeable #-}
module Data.ASN1.BitArray
    ( BitArray(..)
    , BitArrayOutOfBound(..)
    , bitArrayLength
    , bitArrayGetBit
    , bitArraySetBitValue
    , bitArraySetBit
    , bitArrayClearBit
    , bitArrayGetData
    , toBitArray
    ) where

import Data.Bits
import Data.Word
import Data.Maybe
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Typeable
import Control.Exception (Exception, throw)

-- | Exception thrown when a bit index falls outside a 'BitArray'.
--
-- The carried 'Word64' is the offending bit index.
data BitArrayOutOfBound
    = BitArrayOutOfBound Word64
    deriving (Eq, Show, Typeable)

-- Relies entirely on the default 'Exception' methods.
instance Exception BitArrayOutOfBound

-- | A fixed-length array of bits stored in a strict 'ByteString'.
--
-- The first field is the number of meaningful bits and the second field
-- holds the raw bytes. Bits are numbered from the most significant bit of
-- the first byte: bit 0 is the MSB of byte 0, bit 7 its LSB, bit 8 the MSB
-- of byte 1, and so on. The constructor itself does not check that the
-- bit count fits within the bytes.
data BitArray
    = BitArray Word64 ByteString
    deriving (Eq, Show)

-- | The declared number of bits in the array (the first constructor field).
bitArrayLength :: BitArray -> Word64
bitArrayLength arr = case arr of
    BitArray bitCount _ -> bitCount

-- | Abort with a 'BitArrayOutOfBound' exception carrying the given bit
-- index. The result type is polymorphic so the call can stand in for any
-- value in a guard.
bitArrayOutOfBound :: Word64 -> a
bitArrayOutOfBound = throw . BitArrayOutOfBound

-- | Read the bit at index @n@.
--
-- Bit @n@ lives in byte @n `div` 8@, at bit position @7 - n `mod` 8@ of
-- that byte. Bit 0 is therefore the most significant bit of the first byte.
--
-- Throws 'BitArrayOutOfBound' when @n@ is not below the array's declared
-- length. It also throws when the backing bytes are too short to contain
-- bit @n@, which can happen for values built directly with the 'BitArray'
-- constructor.
bitArrayGetBit :: BitArray -> Word64 -> Bool
bitArrayGetBit (BitArray l d) n
    | n >= l                              = bitArrayOutOfBound n
    -- Compare in Word64 rather than converting offset to Int, so a large
    -- offset cannot wrap; this also keeps B.index below from failing with
    -- an unrelated error.
    | offset >= fromIntegral (B.length d) = bitArrayOutOfBound n
    | otherwise =
        testBit (B.index d (fromIntegral offset)) (7 - fromIntegral bitn)
  where
    (offset, bitn) = n `divMod` 8

-- | Return a copy of the array with bit @n@ set to @v@. 'True' sets the bit
-- and 'False' clears it.
--
-- Bit numbering matches 'bitArrayGetBit': bit @n@ is bit @7 - n `mod` 8@
-- of byte @n `div` 8@.
--
-- Throws 'BitArrayOutOfBound' when @n@ is not below the declared length,
-- or when the backing bytes are too short to contain bit @n@.
bitArraySetBitValue :: BitArray -> Word64 -> Bool -> BitArray
bitArraySetBitValue (BitArray l d) n v
    | n >= l                              = bitArrayOutOfBound n
    -- Explicit check instead of relying on an unenforced invariant (and a
    -- partial fromJust) to guarantee the target byte exists.
    | offset >= fromIntegral (B.length d) = bitArrayOutOfBound n
    | otherwise =
        let i     = fromIntegral offset
            byte' = setter (B.index d i) (7 - fromIntegral bitn)
        in BitArray l (B.concat [B.take i d, B.singleton byte', B.drop (i + 1) d])
  where
    (offset, bitn) = n `divMod` 8
    setter :: Word8 -> Int -> Word8
    setter = if v then setBit else clearBit

-- | Return a copy of the array with bit @idx@ set to 1. This delegates to
-- 'bitArraySetBitValue', so it has the same bounds behaviour.
bitArraySetBit :: BitArray -> Word64 -> BitArray
bitArraySetBit arr idx = bitArraySetBitValue arr idx True

-- | Return a copy of the array with bit @idx@ set to 0. This delegates to
-- 'bitArraySetBitValue', so it has the same bounds behaviour.
bitArrayClearBit :: BitArray -> Word64 -> BitArray
bitArrayClearBit arr idx = bitArraySetBitValue arr idx False

-- | The raw backing bytes of the array, including any padding bits past
-- the declared length.
bitArrayGetData :: BitArray -> ByteString
bitArrayGetData arr = case arr of
    BitArray _ bytes -> bytes

-- | Build a 'BitArray' from its backing bytes and a count of padding bits
-- to ignore at the end of the data.
--
-- The resulting length is @8 * B.length bytes - toSkip@ bits. The padding
-- count is clamped to @[0, 8 * B.length bytes]@. Without the clamp, a
-- negative count would claim bits the data does not hold, and a count
-- larger than the data would wrap the unsigned length to a huge value.
toBitArray :: ByteString -> Int -> BitArray
toBitArray bytes toSkip = BitArray (totalBits - padding) bytes
  where
    totalBits, padding :: Word64
    totalBits = fromIntegral (B.length bytes) * 8
    -- Apply max 0 before converting, so the Int -> Word64 conversion never
    -- sees a negative value.
    padding   = min totalBits (fromIntegral (max 0 toSkip))