{-# LANGUAGE BangPatterns, FlexibleContexts #-}
module Statistics.Transform
(
CD
, dct
, dct_
, idct
, idct_
, fft
, ifft
) where
import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Bits (shiftL, shiftR)
import Data.Complex (Complex(..), conjugate, realPart)
import Numeric.SpecFunctions (log2)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
-- | Shorthand for a complex number with 'Double' components.
type CD = Complex Double
-- | Discrete cosine transform of a real-valued vector.
--
-- Every sample is lifted to a complex number with zero imaginary part
-- and passed to 'dctWorker', which raises an 'error' when the vector
-- length is rejected by its length check.
dct :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v Double -> v Double
dct = dctWorker . G.map (:+ 0)
{-# INLINABLE dct #-}
{-# SPECIALIZE dct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE dct :: V.Vector Double -> V.Vector Double #-}
-- | Discrete cosine transform of a complex-valued vector.
--
-- Only the real part of each element takes part in the transform; the
-- imaginary part is replaced by zero before calling 'dctWorker'.
dct_ :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
dct_ = dctWorker . G.map dropImaginary
  where
    -- Keep the real component, zero the imaginary one.
    dropImaginary (re :+ _) = re :+ 0
{-# INLINABLE dct_ #-}
{-# SPECIALIZE dct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE dct_ :: V.Vector CD -> V.Vector Double #-}
-- | Shared implementation of 'dct' and 'dct_'.
--
-- The input is reordered (even indices ascending, then odd indices
-- descending), run through 'fft', multiplied element-wise by complex
-- weights, and the real parts of the products are returned.
--
-- Calls 'error' when the length is not accepted by 'vectorOK'.
dctWorker :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
{-# INLINE dctWorker #-}
dctWorker xs
  -- A single element is special-cased: the index shuffle below produces
  -- no indices at all when len == 1.
  | len == 1    = G.map ((2 *) . realPart) xs
  | vectorOK xs = G.map realPart $ G.zipWith (*) weights (fft interleaved)
  | otherwise   = error "Statistics.Transform.dct: bad vector length"
  where
    len = G.length xs
    n   = fi len
    -- Indices 0, 2, .., len-2 followed by len-1, len-3, .., 1.
    evenIxs     = G.enumFromThenTo 0 2 (len - 2)
    oddIxs      = G.enumFromThenTo (len - 1) (len - 3) 1
    interleaved = G.backpermute xs (evenIxs G.++ oddIxs)
    -- Weight 2 for the first coefficient, 2 * exp(-i*pi*k/(2*len)) for
    -- coefficient k = 1 .. len-1.
    weights = G.cons 2 $ G.generate (len - 1) $ \k ->
      2 * exp ((0 :+ (-1)) * fi (k + 1) * pi / (2 * n))
-- | Inverse discrete cosine transform of a real-valued vector.
--
-- Every coefficient is lifted to a complex number with zero imaginary
-- part and passed to 'idctWorker', which raises an 'error' when the
-- vector length is rejected by its length check.
idct :: (G.Vector v CD, G.Vector v Double) => v Double -> v Double
idct = idctWorker . G.map (:+ 0)
{-# INLINABLE idct #-}
{-# SPECIALIZE idct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE idct :: V.Vector Double -> V.Vector Double #-}
-- | Inverse discrete cosine transform of a complex-valued vector.
--
-- Only the real part of each element takes part in the transform; the
-- imaginary part is replaced by zero before calling 'idctWorker'.
idct_ :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
idct_ = idctWorker . G.map dropImaginary
  where
    -- Keep the real component, zero the imaginary one.
    dropImaginary (re :+ _) = re :+ 0
{-# INLINABLE idct_ #-}
{-# SPECIALIZE idct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE idct_ :: V.Vector CD -> V.Vector Double #-}
-- | Shared implementation of 'idct' and 'idct_'.
--
-- The input is multiplied element-wise by complex weights, run through
-- 'ifft', and the real parts are then read back in interleaved order:
-- even output positions come from the front of the intermediate vector,
-- odd positions from its back.
--
-- Calls 'error' when the length is not accepted by 'vectorOK'.
idctWorker :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
{-# INLINE idctWorker #-}
idctWorker xs
  | vectorOK xs = G.generate len interleave
  -- The message names idct (it previously said "dct", which pointed
  -- callers at the wrong function).
  | otherwise   = error "Statistics.Transform.idct: bad vector length"
  where
    len = G.length xs
    n   = fi len
    -- Output position z reads vals[z/2] when z is even and
    -- vals[len - z/2 - 1] when z is odd; both stay within 0 .. len-1
    -- for z in 0 .. len-1, which is what makes unsafeIndex acceptable.
    interleave z
      | even z    = vals `G.unsafeIndex` halve z
      | otherwise = vals `G.unsafeIndex` (len - halve z - 1)
    vals = G.map realPart . ifft $ G.zipWith (*) weights xs
    -- Weight len for the first coefficient,
    -- 2 * len * exp(i*pi*k/(2*len)) for coefficient k = 1 .. len-1.
    weights = G.cons n $ G.generate (len - 1) $ \k ->
      2 * n * exp ((0 :+ 1) * fi (k + 1) * pi / (2 * n))
-- | Inverse fast Fourier transform.
--
-- Computed through the forward transform: conjugate the input, apply
-- 'fft', conjugate the result and divide every element by the vector
-- length.
--
-- Calls 'error' when the length is not accepted by 'vectorOK'.
ifft :: G.Vector v CD => v CD -> v CD
ifft xs
  | vectorOK xs = G.map ((/ scale) . conjugate) . fft . G.map conjugate $ xs
  | otherwise   = error "Statistics.Transform.ifft: bad vector length"
  where
    scale = fi (G.length xs)
{-# INLINABLE ifft #-}
{-# SPECIALIZE ifft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE ifft :: V.Vector CD -> V.Vector CD #-}
-- | Fast Fourier transform.
--
-- The input is copied into a mutable vector ('G.thaw'), transformed in
-- place by 'mfft', and the result is frozen again by 'G.create'; the
-- argument itself is left untouched.
--
-- Calls 'error' when the length is not accepted by 'vectorOK'.
fft :: G.Vector v CD => v CD -> v CD
fft v
  | vectorOK v = G.create $ do
      mv <- G.thaw v
      mfft mv
      return mv
  | otherwise  = error "Statistics.Transform.fft: bad vector length"
{-# INLINABLE fft #-}
{-# SPECIALIZE fft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE fft :: V.Vector CD -> V.Vector CD #-}
-- | In-place fast Fourier transform of a mutable vector.
--
-- Works in two phases: 'bitReverse' first reorders the elements, then
-- 'stage' runs @log2 len@ rounds of butterfly updates, doubling the
-- butterfly span each round.
--
-- No length check is done here; the visible caller ('fft') validates
-- the length with 'vectorOK' before calling.
mfft :: (M.MVector v CD) => v s CD -> ST s ()
{-# INLINE mfft #-}
mfft vec = bitReverse 0 0
  where
    len = M.length vec
    -- Number of butterfly rounds to run.
    m   = log2 len

    -- Reordering pass: i walks 0 .. len-2 while j tracks its partner
    -- index; elements are swapped only when i < j, so each pair is
    -- exchanged once. When i reaches len-1 the butterfly phase starts.
    bitReverse i j
      | i == len - 1 = stage 0 1
      | otherwise    = do
          when (i < j) $ M.swap vec i j
          -- Derive the partner index for i+1 from j.
          let inner k l
                | k <= l    = inner (k `shiftR` 1) (l - k)
                | otherwise = bitReverse (i + 1) (l + k)
          inner (len `shiftR` 1) j

    -- Round l of the butterfly phase; l1 is the distance between the two
    -- elements combined by one butterfly (1, 2, 4, ...).
    stage l !l1
      | l == m    = return ()
      | otherwise = do
          let !l2 = l1 `shiftL` 1
              -- Angle increment -2*pi / l2 between successive flights.
              !e  = -6.283185307179586 / fromIntegral l2
              -- flight j a: all butterflies whose lower index is j
              -- modulo l2, rotated by angle a.
              flight j !a
                | j == l1   = stage (l + 1) l2
                | otherwise = do
                    let butterfly i
                          | i >= len  = flight (j + 1) (a + e)
                          | otherwise = do
                              let i1 = i + l1
                              xi1 :+ yi1 <- M.read vec i1
                              let !c = cos a
                                  !s = sin a
                                  -- vec[i1] multiplied by (cos a + i*sin a).
                                  d  = (c * xi1 - s * yi1) :+ (s * xi1 + c * yi1)
                              ci <- M.read vec i
                              M.write vec i1 (ci - d)
                              M.write vec i (ci + d)
                              butterfly (i + l2)
                    butterfly j
          flight 0 0
-- | Convert an 'Int' to a complex number ('CD') via 'fromIntegral'.
fi :: Int -> CD
fi k = fromIntegral k
-- | Halve an 'Int' by shifting it right by one bit.
halve :: Int -> Int
halve k = k `shiftR` 1
-- | Check whether a vector has a length the transforms accept.
--
-- Returns 'True' when the length @n@ is positive and
-- @1 `shiftL` log2 n == n@. An empty vector is rejected up front, so
-- 'log2' is never applied to 0 (its behaviour on non-positive input is
-- not visible in this module).
vectorOK :: G.Vector v a => v a -> Bool
{-# INLINE vectorOK #-}
vectorOK v = n > 0 && (1 `shiftL` log2 n) == n
  where
    n = G.length v