{-# LANGUAGE TypeOperators, MultiParamTypeClasses, UndecidableInstances
           , TypeSynonymInstances, FlexibleInstances
           , FlexibleContexts, TypeFamilies
           , ScopedTypeVariables, CPP  #-}

-- The ScopedTypeVariables is there just as a bug work-around.  Without it
-- I get a bogus error about context mismatch for mutually recursive
-- definitions.  This bug was introduced between ghc 6.9.20080622 and
-- 6.10.0.20081007.


-- {-# OPTIONS_GHC -ddump-simpl-stats -ddump-simpl #-}

-- TODO: remove FlexibleContexts

{-# OPTIONS_GHC -Wall #-}
----------------------------------------------------------------------
-- |
-- Module      :  Data.Maclaurin
-- Copyright   :  (c) Conal Elliott 2008
-- License     :  BSD3
-- 
-- Maintainer  :  conal@conal.net
-- Stability   :  experimental
-- 
-- Infinite derivative towers via linear maps, using the Maclaurin
-- representation.  See blog posts <http://conal.net/blog/tag/derivative/>.
----------------------------------------------------------------------

module Data.Maclaurin
  (
    (:>)(D), powVal, derivative, derivAtBasis  -- maybe not D
  , (:~>), pureD
  , fmapD, (<$>>){-, (<*>>)-}, liftD2, liftD3
  , idD, fstD, sndD
  , linearD, distrib
  -- , (@.)
  , (>-<)
  -- * Misc
  , pairD, unpairD, tripleD, untripleD
  ) 
    where

-- import Control.Applicative (liftA2)
import Data.Function (on)

import Data.VectorSpace
import Data.NumInstances ()
import Data.MemoTrie
import Data.Basis
import Data.LinearMap

import Data.Boolean

#if MIN_VERSION_base(4,8,0)
import Prelude hiding ((<*))
#endif

infixr 9 `D`

-- | Tower of derivatives: a value together with its derivative, where the
-- derivative is a linear map whose results are themselves towers.
data a :> b = D
  { powVal     :: b                -- ^ the value (zeroth entry of the tower)
  , derivative :: a :-* (a :> b)   -- ^ linear map to the derivative tower
  }

-- | Infinitely differentiable functions: each argument is mapped to a full
-- derivative tower.
type a :~> b = a -> (a :> b)

-- Handy for class methods that have no meaningful definition on towers:
-- fails with an error that names the offending operation.
noOv :: String -> a
noOv opName = error (concat [opName, ": not defined on a :> b"])

-- | Constant derivative tower: the given value with a zero derivative.
pureD :: (AdditiveGroup b, HasBasis a, HasTrie (Basis a)) => b -> a :> b
pureD x = D x zeroV


infixl 4 <$>>
-- | Map a /linear/ function over a derivative tower.  The function is
-- applied to the value and, recursively, to every entry of every
-- derivative.
fmapD, (<$>>) :: HasTrie (Basis a) => (b -> c) -> (a :> b) -> (a :> c)
fmapD f = go
  where
    go (D x dx) = D (f x) (inLMap (liftL go) dx)

-- Operator synonym for 'fmapD'.
(<$>>) = fmapD

-- | Apply a /linear/ binary function over derivative towers.  The function
-- combines the two values and, recursively, the corresponding entries of
-- the two derivatives.
liftD2 :: (HasBasis a, HasTrie (Basis a), AdditiveGroup b, AdditiveGroup c) =>
          (b -> c -> d) -> (a :> b) -> (a :> c) -> (a :> d)
liftD2 f = go
  where
    go (D x dx) (D y dy) = D (f x y) (inLMap2 (liftL2 go) dx dy)


-- | Apply a /linear/ ternary function over derivative towers.  The function
-- combines the three values and, recursively, the corresponding entries of
-- the three derivatives.
liftD3 :: (HasBasis a, HasTrie (Basis a)
          , AdditiveGroup b, AdditiveGroup c, AdditiveGroup d) =>
          (b -> c -> d -> e)
       -> (a :> b) -> (a :> c) -> (a :> d) -> (a :> e)
liftD3 f = go
  where
    go (D x dx) (D y dy) (D z dz) =
      D (f x y z) (inLMap3 (liftL3 go) dx dy dz)


-- TODO: Can liftD2 and liftD3 be defined in terms of a (<*>>) similar to
-- (<*>)?  If so, can the speed be as good?

-- liftD2 f a b = (f <$>> a) <*>> b
-- 
-- liftD3 f a b c = liftD2 f a b <*>> c


-- | Differentiable identity function.  It is sometimes called the
-- /derivation variable/ or similar, but it is not really a variable.
-- The identity is linear, so it is built with 'linearD'.
idD :: (VectorSpace u, HasBasis u, HasTrie (Basis u)) => u :~> u
idD = linearD id

-- or
--   dId v = D v pureD

-- | Every linear function has a constant derivative equal to the function
-- itself (as a linear map).  At each point @u@, the tower is the value
-- @f u@ followed by that constant derivative.
linearD :: (HasBasis u, HasTrie (Basis u), AdditiveGroup v) =>
           (u -> v) -> (u :~> v)
linearD f = (`D` dmap) . f
  where
    -- The derivative does not depend on the point.  It is bound here,
    -- outside the returned function, so the trie built by 'linear' is
    -- shared by every application of @linearD f@.  The more direct
    --
    --     linearD f u = f u `D` linear (pureD . f)
    --
    -- is wasteful.  Since @linear f = trie (f . basisValue)@, it rebuilds
    -- the trie for each @u@.
    dmap = linear (pureD . f)


-- Other examples of linear functions

-- | Differentiable version of 'fst'.  Projection is linear, so this is
-- 'linearD' applied to 'fst'.
fstD :: ( HasBasis a, HasTrie (Basis a)
        , HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a, b) :~> a
fstD = linearD fst

-- | Differentiable version of 'snd'.  Projection is linear, so this is
-- 'linearD' applied to 'snd'.
sndD :: ( HasBasis a, HasTrie (Basis a)
        , HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a, b) :~> b
sndD = linearD snd

-- | Derivative tower for applying a binary function that distributes over
-- addition, such as multiplication.  This is a slightly weaker assumption
-- than bilinearity.  (Open question from the author: is bilinearity
-- needed for correctness here?)
--
-- The value is @p0 `op` q0@.  The derivative has the product-rule shape:
-- the derivative of the first tower combined with the whole second tower,
-- plus the whole first tower combined with the derivative of the second.
distrib :: forall a b c u. (HasBasis a, HasTrie (Basis a) , AdditiveGroup u) =>
           (b -> c -> u) -> (a :> b) -> (a :> c) -> (a :> u)
distrib op = go
  where
    go p@(D p0 dp) q@(D q0 dq) =
      D (p0 `op` q0) (alongP ^+^ alongQ)
      where
        -- vary the first factor (through its derivative), holding q fixed
        alongP = inLMap (liftMS (inTrie ((`go` q) .))) dp
        -- vary the second factor (through its derivative), holding p fixed
        alongQ = inLMap (liftMS (inTrie ((p `go`) .))) dq


-- TODO: I think this distrib is exponential in increasing degree.  Switch
-- to the Horner representation.  See /The Music of Streams/ by Doug
-- McIlroy.


-- instance Show b => Show (a :> b) where show    = noOv "show"

-- | Shows only the value of a tower.  The derivatives are elided with a
-- trailing ellipsis.
instance Show b => Show (a :> b) where
  show (D x _) = concat ["D ", show x, " ..."]

-- | Equality is not supported on towers.  '(==)' always fails via 'noOv'.
instance Eq (a :> b) where
  (==) = noOv "(==)"

-- | A tower's boolean type is the boolean type of its values.
type instance BooleanOf (a :> b) = BooleanOf b

-- | Conditional on towers.  The same condition selects between the values
-- and, recursively through 'liftD2', between all derivative entries.
instance (AdditiveGroup v, HasBasis u, HasTrie (Basis u), IfB v) =>
         IfB (u :> v) where
  ifB c = liftD2 (ifB c)

-- | Boolean-valued comparison of towers looks only at their values
-- ('powVal').
instance OrdB v => OrdB (u :> v) where
  p <* q = powVal p <* powVal q

-- | Ordering on towers compares only their values ('powVal').  The
-- derivatives are ignored.
instance ( AdditiveGroup b, HasBasis a, HasTrie (Basis a)
         , OrdB b, IfB b, Ord b ) => Ord (a :> b) where
  compare = compare `on` powVal
  -- 'maxB' and 'minB' use 'ifB', so they also work when @b@ is an
  -- expression type, as in deep DSELs.
  max = maxB
  min = minB

-- | Towers form an additive group componentwise: values are added to values,
-- and derivatives to derivatives.
instance (HasBasis a, HasTrie (Basis a), AdditiveGroup u) => AdditiveGroup (a :> u) where
  zeroV = pureD zeroV
  negateV = fmapD negateV
  (^+^) (D x dx) (D y dy) = D (x ^+^ y) (dx ^+^ dy)
  -- Defining (^+^) as @liftD2 (^+^)@ would also work, but it is less
  -- efficient because it adds zero.

-- | Scaling by a scalar-valued tower.  'distrib' supplies the
-- product-rule derivative.
instance (HasBasis a, HasTrie (Basis a), VectorSpace u)
      => VectorSpace (a :> u) where
  type Scalar (a :> u) = (a :> Scalar u)
  s *^ v = distrib (*^) s v

-- | Inner product of towers.  'distrib' supplies the product-rule
-- derivative.
instance ( InnerSpace u, s ~ Scalar u, AdditiveGroup s
         , HasBasis a, HasTrie (Basis a) ) =>
     InnerSpace (a :> u) where
  x <.> y = distrib (<.>) x y

-- infixr 9 @.
-- -- | Chain rule.  See also '(>-<)'.
-- (@.) :: (HasTrie (Basis b), HasTrie (Basis a), VectorSpace c s) =>
--         (b :~> c) -> (a :~> b) -> (a :~> c)
-- (h @. g) a0 = D c0 (inL2 (@.) c' b')
--   where
--     D b0 b' = g a0
--     D c0 c' = h b0

infix 0 >-<

-- | Specialized chain rule.  @f >-< f'@ lifts @f@ to towers, where @f'@ is
-- the derivative of @f@, written as a function on towers.  The result's
-- value is @f@ applied to the input's value.  Its derivative is the input's
-- derivative with every entry scaled by @f'@ applied to the whole input
-- tower.  (A general chain rule, @(\@.)@, is only sketched in a comment in
-- this module.)
(>-<) :: (HasBasis a, HasTrie (Basis a), VectorSpace u) =>
         (u -> u) -> ((a :> u) -> (a :> Scalar u))
      -> (a :> u) -> (a :> u)
(>-<) f f' u@(D u0 du) = D (f u0) (inLMap (liftMS (f' u *^)) du)


-- TODO: express '(>-<)' in terms of '(@.)'.  If I can't, then understand why not.

-- | Numeric operations on scalar-valued towers.  Each non-linear method is
-- given via '(>-<)' as the scalar function paired with its derivative.
-- The derivatives are functions on towers.  Literals such as @-1@ and @0@
-- below are used at function type, through the instances imported from
-- "Data.NumInstances".
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Num s, VectorSpace s, Scalar s ~ s )
      => Num (a :> s) where
  fromInteger n = pureD (fromInteger n)   -- constants: zero derivative
  (+)           = (^+^)
  (*)           = distrib (*)             -- product rule via 'distrib'
  negate        = negate >-< -1
  abs           = abs    >-< signum
  signum        = signum >-< 0            -- derivative wrong at zero

-- | Fractional operations on scalar-valued towers.  'recip' is given via
-- '(>-<)' with its derivative @-1/x^2@, written at function type.
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Fractional s, VectorSpace s, Scalar s ~ s )
      => Fractional (a :> s) where
  fromRational r = pureD (fromRational r)
  recip          = recip >-< - recip sqr

-- | Square a number.  Also passed as a function value in this module's
-- derivative rules.
sqr :: Num a => a -> a
sqr v = v * v

-- | Floating-point functions on scalar-valued towers.  Each method pairs the
-- scalar function with its derivative via '(>-<)'.  The derivative is
-- written as a function on towers (at function type, through
-- "Data.NumInstances"), so higher derivatives follow automatically.
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Floating s, VectorSpace s, Scalar s ~ s )
         => Floating (a :> s) where
  pi    = pureD pi
  exp   = exp   >-< exp                        -- exp'   x = exp x
  log   = log   >-< recip                      -- log'   x = 1/x
  sqrt  = sqrt  >-< recip (2 * sqrt)           -- sqrt'  x = 1/(2 sqrt x)
  sin   = sin   >-< cos                        -- sin'   x = cos x
  cos   = cos   >-< - sin                      -- cos'   x = -sin x
  sinh  = sinh  >-< cosh                       -- sinh'  x = cosh x
  cosh  = cosh  >-< sinh                       -- cosh'  x = sinh x
  asin  = asin  >-< recip (sqrt (1 - sqr))     -- asin'  x = 1/sqrt(1-x^2)
  acos  = acos  >-< recip (- sqrt (1 - sqr))   -- acos'  x = -1/sqrt(1-x^2)
  atan  = atan  >-< recip (1 + sqr)            -- atan'  x = 1/(1+x^2)
  asinh = asinh >-< recip (sqrt (1 + sqr))     -- asinh' x = 1/sqrt(x^2+1)
  -- acosh' x = 1/sqrt(x^2-1), which is positive on the domain x > 1.
  -- Unlike 'acos', there is no negation here.
  acosh = acosh >-< recip (sqrt (sqr - 1))
  atanh = atanh >-< recip (1 - sqr)            -- atanh' x = 1/(1-x^2)


-- | Sample the derivative at a basis element.  Intended for partial
-- application to the tower alone (@derivAtBasis f@), which the author notes
-- saves work for non-scalar derivatives.
derivAtBasis :: (HasTrie (Basis a), HasBasis a, AdditiveGroup b) =>
                (a :> b) -> (Basis a -> (a :> b))
derivAtBasis = atBasis . derivative


---- Misc

-- | Combine two towers over the same domain into a tower of pairs.
pairD :: (HasBasis a, HasTrie (Basis a), VectorSpace b, VectorSpace c)
      => (a :> b, a :> c) -> a :> (b, c)
pairD (p, q) = liftD2 (,) p q

-- | Split a tower of pairs into a pair of towers by projecting each
-- component.
unpairD :: HasTrie (Basis a) => (a :> (b, c)) -> (a :> b, a :> c)
unpairD t = (fmapD fst t, fmapD snd t)


-- | Combine three towers over the same domain into a tower of triples.
tripleD :: ( HasBasis a, HasTrie (Basis a)
           , VectorSpace b, VectorSpace c, VectorSpace d
           ) => (a :> b, a :> c, a :> d) -> a :> (b, c, d)
tripleD (p, q, r) = liftD3 (,,) p q r

-- | Split a tower of triples into a triple of towers by projecting each
-- component.
untripleD :: HasTrie (Basis a) => (a :> (b, c, d)) -> (a :> b, a :> c, a :> d)
untripleD t = (fmapD sel1 t, fmapD sel2 t, fmapD sel3 t)
  where
    sel1 (x, _, _) = x
    sel2 (_, y, _) = y
    sel3 (_, _, z) = z