-- |
-- Module    : Numeric.MathFunctions.Comparison
-- Copyright : (c) 2011 Bryan O'Sullivan
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Functions for approximate comparison of floating point numbers.
--
-- Approximate floating point comparison, based on Bruce Dawson's
-- \"Comparing floating point numbers\":
-- <http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm>
module Numeric.MathFunctions.Comparison
    ( -- * Relative errors
      relativeError
    , eqRelErr
      -- * Ulps-based comparison
    , addUlps
    , ulpDistance
    , ulpDelta
    , within
    ) where

import Control.Monad.ST (runST)
import Data.Primitive.ByteArray (newByteArray, readByteArray, writeByteArray)
import Data.Word (Word64)
import Data.Int (Int64)



----------------------------------------------------------------
-- Relative errors
----------------------------------------------------------------

-- |
-- Calculate relative error of two numbers:
--
-- \[ \frac{|a - b|}{\max(|a|,|b|)} \]
--
-- Mathematically it lies in the [0,1) interval for nonzero numbers
-- with the same sign and in (1,2] for numbers with different signs;
-- it is exactly 1 when exactly one argument is zero. If both
-- arguments are zero or negative zero the function returns 0. If at
-- least one argument is infinite or NaN it returns NaN.
relativeError :: Double -> Double -> Double
relativeError a b
  -- Both zeros (of either sign): avoid 0/0 and report no error.
  | a == 0 && b == 0 = 0
  | otherwise        = abs (a - b) / max (abs a) (abs b)

-- | Check that the relative error between two numbers @a@ and @b@ is
-- strictly less than @eps@. If 'relativeError' returns NaN it
-- returns @False@, since any comparison with NaN is false.
eqRelErr :: Double -- ^ /eps/ relative error should be in [0,1) range
         -> Double -- ^ /a/
         -> Double -- ^ /b/
         -> Bool
eqRelErr eps a b = relativeError a b < eps



----------------------------------------------------------------
-- Ulps-based comparison
----------------------------------------------------------------

-- |
-- Add N ULPs (units of least precision) to @Double@ number.
--
-- The number is reinterpreted as its 64-bit pattern, mapped onto a
-- signed integer scale on which adjacent representable values differ
-- by one, shifted by @n@ and mapped back. A negative @n@ moves
-- towards smaller values. On this scale positive zero maps to 0 and
-- negative zero to -1, so the two zeros count as distinct adjacent
-- values. The addition is performed in 'Int64' without overflow
-- checks.
addUlps :: Int -> Double -> Double
addUlps n a = runST $ do
  -- Reinterpret the bits of the Double as a Word64 by round-tripping
  -- through an 8-byte scratch buffer.
  buf <- newByteArray 8
  ai0 <- writeByteArray buf 0 a >> readByteArray buf 0
  -- Convert to ulps number represented as Int64
  let big :: Word64
      big = 0x8000000000000000          -- the sign bit
      -- Bit pattern -> position on the signed ulps scale.
      order :: Word64 -> Int64
      order i | i < big   = fromIntegral i
              | otherwise = fromIntegral $ maxBound - (i - big)
      -- Inverse of 'order': signed ulps position -> bit pattern.
      unorder :: Int64 -> Word64
      unorder i | i >= 0    = fromIntegral i
                | otherwise = big + (maxBound - fromIntegral i)
  let ai0' :: Word64
      ai0' = unorder $ order ai0 + fromIntegral n
  -- Reinterpret the shifted bit pattern as a Double again.
  writeByteArray buf 0 ai0' >> readByteArray buf 0

-- |
-- Measure distance between two @Double@s in ULPs (units of least
-- precision). Note that it's different from @abs (ulpDelta a b)@
-- since it returns correct result even when 'ulpDelta' overflows.
ulpDistance :: Double
            -> Double
            -> Word64
ulpDistance a b = runST $ do
  -- Reinterpret the bits of each Double as a Word64 by round-tripping
  -- through an 8-byte scratch buffer.
  buf <- newByteArray 8
  ai0 <- writeByteArray buf 0 a >> readByteArray buf 0
  bi0 <- writeByteArray buf 0 b >> readByteArray buf 0
  -- IEEE754 floats use most significant bit as sign bit (not
  -- two's complement) and we need to rearrange representations of
  -- float number so that they could be compared lexicographically as
  -- Word64.
  let big :: Word64
      big = 0x8000000000000000          -- the sign bit
      order :: Word64 -> Word64
      order i | i < big   = i + big     -- sign bit clear: shift into upper half
              | otherwise = maxBound - i -- sign bit set: reverse into lower half
      ai = order ai0
      bi = order bi0
      -- Subtract the smaller from the larger so that the unsigned
      -- difference never wraps around.
      d :: Word64
      d  | ai > bi   = ai - bi
         | otherwise = bi - ai
  return $! d

-- |
-- Measure signed distance between two @Double@s in ULPs (units of least
-- precision). The result is the position of the second argument
-- relative to the first. Note that unlike 'ulpDistance' it can
-- overflow.
--
-- > >>> ulpDelta 1 (1 + m_epsilon)
-- > 1
ulpDelta :: Double
         -> Double
         -> Int64
ulpDelta a b = runST $ do
  -- Reinterpret the bits of each Double as a Word64 by round-tripping
  -- through an 8-byte scratch buffer.
  buf <- newByteArray 8
  ai0 <- writeByteArray buf 0 a >> readByteArray buf 0
  bi0 <- writeByteArray buf 0 b >> readByteArray buf 0
  -- IEEE754 floats use most significant bit as sign bit (not
  -- two's complement) and we need to rearrange representations of
  -- float number so that they could be compared lexicographically as
  -- Word64.
  let big :: Word64
      big = 0x8000000000000000          -- the sign bit
      order :: Word64 -> Word64
      order i | i < big   = i + big     -- sign bit clear: shift into upper half
              | otherwise = maxBound - i -- sign bit set: reverse into lower half
      ai = order ai0
      bi = order bi0
  -- The difference is taken modulo 2^64 and then reinterpreted as a
  -- signed number, which is where the overflow can happen.
  return $! fromIntegral $ bi - ai


-- | Compare two 'Double' values for approximate equality, using
-- Dawson's method.
--
-- The required accuracy is specified in ULPs (units of least
-- precision).  If the two numbers differ by the given number of ULPs
-- or less, this function returns @True@. A negative number of ULPs
-- always yields @False@.
within :: Int                   -- ^ Number of ULPs of accuracy desired.
       -> Double -> Double -> Bool
within ulps a b
  -- Reject negative tolerances up front: converting them to Word64
  -- below would wrap around to a huge value.
  | ulps < 0  = False
  | otherwise = ulpDistance a b <= fromIntegral ulps