module Cardano.Ledger.NonIntegral
  ( (***),
    exp',
    ln',
    findE,
    splitLn,
    scaleExp,
    CompareResult (..),
    taylorExpCmp,
  )
where

-- | Outcome of 'taylorExpCmp'.
data CompareResult a
  = -- | The comparison value is strictly below the partial sum minus its
    -- error margin; carries that partial sum and the number of steps taken.
    BELOW a Int
  | -- | The comparison value is at or above the partial sum plus its error
    -- margin; carries that partial sum and the number of steps taken.
    ABOVE a Int
  | -- | The step limit was hit before the comparison could be decided;
    -- carries the step count.
    MaxReached Int
  deriving (Eq, Show)

-- | Split @x@ into @(n, x / n)@ with @n = ceiling x@, so that for non-zero
-- @n@ we have @e^x = (e^(x / n))^n@. For positive @x@ the second component
-- lies in @(0, 1]@, which keeps the subsequent Taylor series small.
--
-- NOTE(review): when @ceiling x == 0@ (i.e. @x@ in @(-1, 0]@) the second
-- component is @x / 0@. 'exp'' only reaches this case for @x == 0@, where
-- that component is raised to the power 0 and never forced.
scaleExp :: (RealFrac a) => a -> (Integer, a)
scaleExp x = (n, scaled)
  where
    n = ceiling x
    scaled = x / fromIntegral n

-- | Real exponentiation: @base *** power@ approximates @base^power@ as
-- @exp' (power * ln' base)@.
--
-- Three cases are answered directly, checked in this order: a zero power
-- gives 1 (so @0 *** 0 == 1@), a zero base gives 0, and a base of 1 gives 1.
-- Any other non-positive base is rejected by 'ln'' with an error.
(***) :: (RealFrac a, Enum a, Show a) => a -> a -> a
base *** power
  | power == 0 = 1
  | base == 0 = 0
  | base == 1 = 1
  | otherwise = exp' (power * ln' base)

-- | Integer power by repeated squaring, for non-negative exponents.
--
-- An even exponent squares the half power; an odd exponent peels off one
-- factor of the base. A negative exponent never reaches 0 and therefore does
-- not terminate; use 'ipow' for those.
ipow' :: Num a => a -> Integer -> a
ipow' base e
  | e == 0 = 1
  | even e = square (ipow' base (e `div` 2))
  | otherwise = base * ipow' base (e - 1)
  where
    square y = y * y

-- | Integer power for exponents of either sign. A negative exponent yields
-- the reciprocal of the corresponding positive power.
ipow :: Fractional a => a -> Integer -> a
ipow base e
  | e >= 0 = ipow' base e
  | otherwise = 1 / ipow' base (negate e)

-- | The infinite list @a^2, a^2, (a+1)^2, (a+1)^2, ...@: every square from
-- @a@ onwards, each one repeated twice. 'lncf' uses it for the @k^2@ factors
-- of its partial numerators.
logAs :: (Num a) => a -> [a]
logAs a = concatMap twice (iterate (+ 1) a)
  where
    twice k = let sq = k * k in [sq, sq]

-- | Approximate @ln(1+x)@ for @x@ in @[0, infinity)@ with the continued
-- fraction
--
-- > ln(1+x) = x / (1 + x / (2 + x / (3 + 4x / (4 + 4x / (5 + 9x / (6 + ...))))))
--
-- that is, @a_1 = x@, @a_{2k} = a_{2k+1} = x*k^2@ for @k >= 1@, and
-- @b_n = n@ for @n >= 0@. The evaluation is delegated to 'cf' with step
-- limit @maxN@ and tolerance 'eps'. Negative @x@ is rejected with an error.
lncf :: (Fractional a, Enum a, Ord a, Show a) => Int -> a -> a
lncf maxN x
  | x < 0 = error ("x = " ++ show x ++ " is not inside domain [0,..)")
  | otherwise = cf maxN 0 eps Nothing 1 0 0 1 numerators [1, 2 ..]
  where
    -- a_1 = x, followed by x * k^2 for k = 1, 1, 2, 2, 3, 3, ...
    numerators = x : [k * x | k <- logAs 1]

-- | Tolerance @10^-24@. It is the convergence threshold handed to 'cf' by
-- 'lncf' and the term-size cut-off in 'taylorExp'.
eps :: (Fractional a) => a
eps = 1 / tenPow24
  where
    tenPow24 = 10 ^ (24 :: Int)

-- | Evaluate a continued fraction from its partial numerators @a_n@ and
-- partial denominators @b_n@ via the fundamental recurrence
--
-- > A_{-1} = 1,    A_0 = b_0
-- > B_{-1} = 0,    B_0 = 1
-- > A_n = b_n*A_{n-1} + a_n*A_{n-2}
-- > B_n = b_n*B_{n-1} + a_n*B_{n-2}
--
-- whose convergents @x_n = A_n / B_n@ approximate
--
-- >                        a_1
-- > result = b_0 + ---------------------
-- >                           a_2
-- >                b_1 + ---------------
-- >                              a_3
-- >                      b_2 + ---------
-- >                                  .
-- >                            b_3 +  .
-- >                                    .
--
-- Arguments, in order: @maxN@ (step limit), @n@ (current step), @epsilon@
-- (convergence tolerance), @lastVal@ (previous convergent, 'Nothing' on the
-- first call), @aNm2@ / @bNm2@ (A_{n-2} / B_{n-2}), @aNm1@ / @bNm1@
-- (A_{n-1} / B_{n-1}), and the remaining @a_n@ and @b_n@ values as lists.
--
-- The newest convergent is returned once @n@ reaches @maxN@ (the current
-- @a_n@ / @b_n@ pair is still consumed first), or once it differs from
-- @lastVal@ by less than @epsilon@. If either list runs out, @A_{n-1} / B_{n-1}@
-- is returned instead.
cf ::
  (Fractional a, Ord a, Show a) =>
  Int ->
  Int ->
  a ->
  Maybe a ->
  a ->
  a ->
  a ->
  a ->
  [a] ->
  [a] ->
  a
cf maxN n epsilon lastVal aNm2 bNm2 aNm1 bNm1 (an : as) (bn : bs)
  | n == maxN || settled = xn
  | otherwise = cf maxN (n + 1) epsilon (Just xn) aNm1 bNm1 aN bN as bs
  where
    aN = bn * aNm1 + an * aNm2
    bN = bn * bNm1 + an * bNm2
    -- the n-th convergent
    xn = aN / bN
    settled = case lastVal of
      Nothing -> False
      Just previous -> abs (previous - xn) < epsilon
cf _ _ _ _ _ _ aN bN _ _ = aN / bN

-- | Bracket @x@ between two powers of a base by repeated squaring.
--
-- Starting from the bracket @[lo0, hi0]@ with exponents @(l0, u0)@ (callers
-- pass @1/factor@, @factor@, @-1@, @1@), both ends are squared and both
-- exponents doubled until @lo <= x <= hi@; the exponents at that point are
-- returned, giving @factor^l <= x <= factor^u@. The @factor@ argument itself
-- is not consulted, because the starting bracket already encodes it.
-- Assumes @x > 0@ and @factor > 1@; otherwise a bracket may never be found.
bound ::
  (Fractional a, Ord a) =>
  a ->
  a ->
  a ->
  a ->
  Integer ->
  Integer ->
  (Integer, Integer)
bound _factor x lo0 hi0 l0 u0 = exponents (until brackets widen (lo0, hi0, l0, u0))
  where
    brackets (lo, hi, _, _) = lo <= x && x <= hi
    widen (lo, hi, l, u) = (lo * lo, hi * hi, 2 * l, 2 * u)
    exponents (_, _, l, u) = (l, u)

-- | Bisect the exponent range @[l, u]@ until it is one step wide and return
-- its lower end.
--
-- The lower end only moves to @mid@ when @factor^mid <= x@, and the upper
-- end only moves to @mid@ when @x < factor^mid@. Starting from
-- @factor^l <= x <= factor^u@, the result @n@ therefore satisfies
-- @factor^n <= x <= factor^(n+1)@. The right inequality is strict unless the
-- upper end never moved.
contract ::
  (Fractional a, Ord a) =>
  a ->
  a ->
  Integer ->
  Integer ->
  Integer
contract factor x = search
  where
    search lo hi
      | lo + 1 == hi = lo
      | x < ipow factor mid = search lo mid
      | otherwise = search mid hi
      where
        mid = lo + (hi - lo) `div` 2

-- | Euler's number @e@, obtained as @exp' 1@.
--
-- NOTE(review): being an overloaded constant, this is likely recomputed at
-- each use site rather than shared.
exp1 :: (RealFrac a, Show a) => a
exp1 = exp' 1

-- | Find @n@ with @e^n <= x <= e^(n+1)@ (normally @x < e^(n+1)@; equality
-- on the right can only occur when @x@ is exactly the upper power found by
-- 'bound').
--
-- The search first brackets @x@ between @e^l@ and @e^u@ by repeated
-- squaring ('bound') and then bisects that exponent range ('contract').
-- Bracketing can only succeed for a base @e > 1@ and a positive @x@. Any
-- other input, including NaN, is rejected with an error instead of looping
-- forever.
findE :: (RealFrac a) => a -> a -> Integer
findE e x
  -- Written as `not (_ > _)` rather than `_ <= _` so NaN is rejected too.
  | not (e > 1) = error "findE: base must be greater than 1"
  | not (x > 0) = error "findE: argument must be positive"
  | otherwise = contract e x lower upper
  where
    (lower, upper) = bound e x (1 / e) e (-1) 1

-- | Natural logarithm of a positive number.
--
-- 'splitLn' first writes @x = e^n * (1 + x')@ with an integer @n@, so that
-- @ln x = n + ln(1 + x')@. The remaining @ln(1 + x')@ is approximated with
-- the continued fraction 'lncf', using at most 1000 steps. Non-positive
-- arguments are rejected with an error.
ln' :: (RealFrac a, Enum a, Show a) => a -> a
ln' x
  | x <= 0 = error (show x ++ " is not in domain of ln")
  | otherwise =
      let (n, x') = splitLn x
       in fromIntegral n + lncf 1000 x'

-- | Split @x > 0@ into @(n, x')@ with @x = e^n * (1 + x')@.
--
-- @n@ comes from 'findE', so @e^n <= x@ and @x'@ is (up to rounding)
-- non-negative, as 'lncf' requires.
splitLn :: (RealFrac a, Show a) => a -> (Integer, a)
splitLn x = (n, x')
  where
    -- Bind e once. 'exp1' is an overloaded constant (a full Taylor-series
    -- evaluation), and each separate occurrence may be recomputed. This
    -- local binding is monomorphic, so both uses below share one value.
    e = exp1
    n = findE e x
    x' = x / ipow e n - 1 -- x / e^n >= 1

-- | Exponential function.
--
-- Negative arguments use @e^x = 1 / e^(-x)@. A non-negative @x@ is first
-- scaled down to @x / n@ with @n = ceiling x@ ('scaleExp'). The exponential
-- of that scaled value is summed as a Taylor series ('taylorExp', at most
-- 1000 steps), and the result is raised back to the @n@-th power.
exp' :: (RealFrac a, Show a) => a -> a
exp' x
  | x < 0 = 1 / exp' (-x)
  | otherwise = ipow (taylorExp 1000 1 scaled 1 1 1) n
  where
    (n, scaled) = scaleExp x

-- | Sum the Taylor series of @e^x@.
--
-- The arguments are:
--
-- * @maxN@: the step limit, and @n@: the current step.
-- * @lastX@: the previous term.
-- * @acc@: the running sum.
-- * @divisor@: the index of the next term, which is computed as
--   @lastX * x / divisor@.
--
-- The running sum is returned once @n@ reaches @maxN@, or once the next
-- term's magnitude falls below 'eps'. In the latter case that small term is
-- not added.
taylorExp :: (RealFrac a, Show a) => Int -> Int -> a -> a -> a -> a -> a
taylorExp maxN n x lastX acc divisor
  | n == maxN || abs term < eps = acc
  | otherwise = taylorExp maxN (n + 1) x term (acc + term) (divisor + 1)
  where
    -- x^k / k! derived from x^(k-1) / (k-1)!
    term = lastX * x / divisor

-- | Decide whether @cmp@ lies above or below @e^x@, summing the Taylor
-- series of @e^x@ only as far as needed.
--
-- Each step adds one term to the partial sum, which starts at @1@. It then
-- forms an error margin @|t * boundX|@ from the following term @t@:
--
-- * If @cmp@ is at or above the partial sum plus the margin, the result is
--   'ABOVE'.
-- * If @cmp@ is strictly below the partial sum minus the margin, the result
--   is 'BELOW'.
--
-- Both carry the partial sum and the number of steps taken. After 1000
-- undecided steps the result is 'MaxReached'.
--
-- NOTE(review): stopping early is only sound if @|t * boundX|@ bounds the
-- remaining tail of the series; presumably callers choose @boundX@
-- accordingly (TODO confirm).
taylorExpCmp :: (RealFrac a) => a -> a -> a -> CompareResult a
taylorExpCmp boundX cmp x = step 0 x 1 1
  where
    stepLimit = 1000 :: Int
    -- Arguments: step count, current term, partial sum before adding the
    -- current term, and index of the current term.
    step n term acc k
      | n == stepLimit = MaxReached n
      | cmp >= acc' + margin = ABOVE acc' (n + 1)
      | cmp < acc' - margin = BELOW acc' (n + 1)
      | otherwise = step (n + 1) term' acc' k'
      where
        acc' = acc + term
        k' = k + 1
        term' = term * x / k'
        margin = abs (term' * boundX)