{-# LANGUAGE TupleSections #-}
{-# LANGUAGE CPP, TypeOperators, FlexibleContexts, TypeFamilies
  , GeneralizedNewtypeDeriving, StandaloneDeriving, UndecidableInstances #-}
{-# OPTIONS_GHC -Wall -fno-warn-orphans #-}
----------------------------------------------------------------------
-- |
-- Module      :  Data.LinearMap
-- Copyright   :  (c) Conal Elliott 2008-2016
-- License     :  BSD3
--
-- Maintainer  :  conal@conal.net
-- Stability   :  experimental
--
-- Linear maps
----------------------------------------------------------------------

module Data.LinearMap
   ( (:-*) , linear, lapply, atBasis, idL, (*.*)
   , inLMap, inLMap2, inLMap3
   , liftMS, liftMS2, liftMS3
   , liftL, liftL2, liftL3
   , exlL, exrL, forkL, firstL, secondL
   , inlL, inrL, joinL -- , leftL, rightL
   )
  where

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative (Applicative)
#endif
import Control.Applicative (liftA2, liftA3)
import Control.Arrow       (first,second)

import Data.MemoTrie      (HasTrie(..),(:->:))
import Data.AdditiveGroup (Sum(..), AdditiveGroup(..))
import Data.VectorSpace   (VectorSpace(..))
import Data.Basis         (HasBasis(..), linearCombo)

-- Linear maps are almost, but not quite, a Control.Category: the type
-- class constraints get in the way.  They are almost an Arrow too, except
-- for those constraints and the generality of 'arr' (which accepts any
-- function, not just linear ones).

-- | An optional additive value.  'Nothing' represents zero (it is read
-- as 'zeroV' by 'fromMS' and 'atZ'), which lets operations such as
-- 'liftMS2' short-circuit when all of their arguments are zero.
type MSum a = Maybe (Sum a)

-- | Inject a value as an 'MSum'.  Always produces 'Just', never the
-- 'Nothing' zero shortcut, even if the value happens to be zero.
jsum :: a -> MSum a
jsum = Just . Sum

-- | Representation of a linear map from @u@ to @v@: an optional memo
-- trie from the basis of @u@ to the images of those basis elements in
-- @v@.  'Nothing' is the zero map.
type LMap' u v = MSum (Basis u :->: v)

infixr 1 :-*
-- | Linear map, represented as an optional memo-trie from basis to
-- values, where 'Nothing' means the zero map (an optimization).
--
-- A newtype (rather than a synonym for 'LMap'') so that @u :-* v@
-- determines @u@ even when different types share the same 'Basis'.
newtype u :-* v = LMap { unLMap :: LMap' u v }

-- Addition, negation and zero are inherited from the representation
-- 'LMap'' via newtype deriving (GeneralizedNewtypeDeriving is enabled at
-- the top of the file).  The zero map is therefore whatever 'zeroV' is
-- for @Maybe (Sum ...)@ -- presumably 'Nothing'; confirm against the
-- vector-space instance.
deriving instance (HasTrie (Basis u), AdditiveGroup v) => AdditiveGroup (u :-* v)

-- Scaling a linear map scales each image stored in its trie.  'liftMS'
-- leaves the zero map ('Nothing') untouched, so scaling zero is free.
instance (HasTrie (Basis u), VectorSpace v) =>
         VectorSpace (u :-* v) where
  type Scalar (u :-* v) = Scalar v
  (*^) s = (inLMap . liftMS . fmap) (s *^)

-- Why UndecidableInstances is enabled: without it, GHC 7.10 rejects the
-- instances above with
--   Constraint is no smaller than the instance head
--   in the constraint: HasTrie (Basis u)
--   (Use UndecidableInstances to permit this)

-- | Linear projection onto the first component of a pair.
exlL :: ( HasBasis a, HasTrie (Basis a), HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a,b) :-* a
exlL = linear fst

-- | Linear projection onto the second component of a pair.
exrL :: ( HasBasis a, HasTrie (Basis a), HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a,b) :-* b
exrL = linear snd

-- | Pair up two linear maps that share a domain.  Works directly on the
-- trie representations: the two tries are combined with @liftA2 (,)@
-- through 'liftL2', and if both maps are zero ('Nothing') the result
-- stays zero.
forkL :: (HasTrie (Basis a), HasBasis c, HasBasis d)
      => (a :-* c) -> (a :-* d) -> (a :-* (c,d))
forkL = (inLMap2 . liftL2) (,)

-- | Apply a linear map to the first component of a pair and pass the
-- second component through unchanged.  Built by sampling
-- @first (lapply f)@ on the basis of the pair space with 'linear'.
firstL  :: ( HasBasis u, HasBasis u', HasBasis v
           , HasTrie (Basis u), HasTrie (Basis v)
           , Scalar u ~ Scalar v, Scalar u ~ Scalar u'
           ) =>
           (u :-* u') -> ((u,v) :-* (u',v))
firstL  = linear . first . lapply

-- | Apply a linear map to the second component of a pair and pass the
-- first component through unchanged.  Built by sampling
-- @second (lapply f)@ on the basis of the pair space with 'linear'.
secondL :: ( HasBasis u, HasBasis v, HasBasis v'
           , HasTrie (Basis u), HasTrie (Basis v)
           , Scalar u ~ Scalar v, Scalar u ~ Scalar v'
           ) =>
           (v :-* v') -> ((u,v) :-* (u,v'))
secondL = linear . second . lapply

-- TODO: more efficient firstL

-- liftMS :: (AdditiveGroup a) => (a -> b) -> (MSum a -> MSum b)

-- (s *^) :: v -> v
-- fmap (s *^) :: (Basis u :->: v) -> (Basis u :->: v)
-- (liftMS.fmap) (s *^) :: LMap' u v -> LMap' u v
-- (inLMap.liftMS.fmap) (s *^) :: (u :-* v) -> (u :-* v)


-- | Linear injection into the first component of a pair; the second
-- component is 'zeroV'.
inlL :: (HasBasis a, HasTrie (Basis a), HasBasis b)
     => a :-* (a,b)
inlL = linear (,zeroV)

-- | Linear injection into the second component of a pair; the first
-- component is 'zeroV'.
inrL :: (HasBasis a, HasBasis b, HasTrie (Basis b))
     => b :-* (a,b)
inrL = linear (zeroV,)

-- | Combine a map out of each component into a single map out of the
-- pair, adding the two results: @(a,b)@ goes to
-- @lapply f a ^+^ lapply g b@.
joinL :: ( HasBasis a, HasTrie (Basis a)
         , HasBasis b, HasTrie (Basis b)
         , Scalar a ~ Scalar b, Scalar a ~ Scalar c
         , VectorSpace c )
      => (a :-* c) -> (b :-* c) -> ((a,b) :-* c)
f `joinL` g = linear (\ (a,b) -> lapply f a ^+^ lapply g b)

-- Before version 0.7, u :-* v was a type synonym, resulting in a subtle
-- ambiguity: u:-*v == u':-*v' does not imply that u==u', since Basis
-- might map different types to the same basis (e.g., Float & Double).
-- See <http://hackage.haskell.org/trac/ghc/ticket/1897>.
-- See also <http://thread.gmane.org/gmane.comp.lang.haskell.cafe/73271/focus=73332>.

-- TODO: Try a partial trie instead, excluding (known) zero elements.
-- Then 'lapply' could be much faster for sparse situations.  Make sure to
-- correctly sum them.  It'd be more like Jason Foutz's formulation
-- <http://metavar.blogspot.com/2008/02/higher-order-multivariate-automatic.html>
-- which uses an @IntMap@.

-- | Function (assumed linear) as linear map.
--
-- The function is sampled on each basis element ('basisValue') and the
-- samples are memoized in a trie.  Linearity of @f@ is not checked.  The
-- result is always a 'Just' representation, even if @f@ is the zero
-- function.
linear :: (HasBasis u, HasTrie (Basis u)) =>
          (u -> v) -> (u :-* v)
linear f = LMap (jsum (trie (f . basisValue)))

-- | Eliminate an optional additive value: 'Nothing' (zero) becomes
-- 'zeroV', and @Just (Sum a)@ becomes @f a@.
atZ :: AdditiveGroup b => (a -> b) -> (MSum a -> b)
atZ f = maybe zeroV (f . getSum)

-- atZ :: AdditiveGroup b => (a -> b) -> (a -> b)
-- atZ = id

-- | Lift a transformation of the representation 'LMap'' to a
-- transformation of linear maps: unwrap with 'unLMap', apply, and rewrap
-- with 'LMap'.
inLMap :: (LMap' r s -> LMap' t u) -> ((r :-* s) -> (t :-* u))
inLMap = unLMap ~> LMap

-- | Binary version of 'inLMap': unwrap both arguments, apply, and rewrap
-- the result.
inLMap2 :: (LMap' r s -> LMap' t u -> LMap' v w)
        -> ((r :-* s) -> (t :-* u) -> (v :-* w))
inLMap2 = unLMap ~> inLMap

-- | Ternary version of 'inLMap': unwrap all three arguments, apply, and
-- rewrap the result.
inLMap3 :: (LMap' r s -> LMap' t u -> LMap' v w -> LMap' x y)
        -> ((r :-* s) -> (t :-* u) -> (v :-* w) -> (x :-* y))
inLMap3 = unLMap ~> inLMap2

-- | Apply a linear map to a vector.
--
-- For the zero map ('Nothing') the result is 'zeroV' at type @u -> v@
-- (presumably @const zeroV@, so the argument is never decomposed).
-- Otherwise the work is done by 'lapply''.
lapply :: ( VectorSpace v, Scalar u ~ Scalar v
          , HasBasis u, HasTrie (Basis u) ) =>
          (u :-* v) -> (u -> v)
lapply = atZ lapply' . unLMap

-- | Evaluate a linear map on a basis element by looking that element up
-- in the trie.  The zero map gives 'zeroV'.
atBasis :: (AdditiveGroup v, HasTrie (Basis u)) =>
           (u :-* v) -> Basis u -> v
LMap m `atBasis` b = atZ (`untrie` b) m

-- | Worker for 'lapply' (also used by alternative definitions of
-- '(*.*)' sketched in comments).
--
-- Apply a trie-represented map: decompose the argument into
-- (basis element, coefficient) pairs, replace each basis element with its
-- image from the trie, and form the resulting linear combination.
lapply' :: ( VectorSpace v, Scalar u ~ Scalar v
           , HasBasis u, HasTrie (Basis u) ) =>
           (Basis u :->: v) -> (u -> v)
lapply' tr = linearCombo . fmap (first (untrie tr)) . decompose

-- | Identity linear map, built by sampling 'id' on each basis element.
idL :: (HasBasis u, HasTrie (Basis u)) =>
       u :-* u
idL = linear id


infixr 9 *.*
-- | Compose linear maps.
--
-- Applies @lapply vw@ to every image stored in @uv@'s trie, under the
-- 'Maybe' and 'Sum' wrappers.  A zero @uv@ ('Nothing') therefore yields
-- the zero map directly; a zero @vw@ is not special-cased.
(*.*) :: ( HasTrie (Basis u)
         , HasBasis v, HasTrie (Basis v)
         , VectorSpace w
         , Scalar v ~ Scalar w ) =>
         (v :-* w) -> (u :-* v) -> (u :-* w)

-- Simple definition, but only optimizes out uv == zero

-- vw *.* uv = LMap ((fmap.fmap.fmap) (lapply vw) (unLMap uv))

(*.*) vw = (inLMap . fmap . fmap . fmap) (lapply vw)

-- Eep:
--     (*.*) = inLMap.fmap.fmap.fmap.lapply


-- Instead, use Nothing/zero if /either/ map is zeroV (exploiting linearity
-- when uv == zeroV.)

-- LMap Nothing         *.* _                    = LMap Nothing
-- _                    *.* LMap Nothing         = LMap Nothing
-- LMap (Just (Sum vw)) *.* LMap (Just (Sum uv)) = LMap (Just (Sum (lapply' vw <$> uv)))

-- (*.*) = liftA2 (\ (LMap (Sum vw)) (LMap (Sum uv)) -> LMap (Sum (lapply' vw <$> uv)))

-- (*.*) = (liftA2.inSum2.inLMap2) (\ vw uv -> lapply' vw <$> uv)

-- (*.*) = (liftA2.inSum2.inLMap2) (\ vw -> fmap (lapply' vw))

-- (*.*) = (liftA2.inSum2.inLMap2) (fmap . lapply')


-- It may be helpful that @lapply vw@ is evaluated just once and not
-- once per uv.  'untrie' can strip off all of its trie constructors.

-- Less efficient definition:
--
--   vw `compL` uv = linear (lapply vw . lapply uv)
--
--   i.e., compL = inL2 (.)
--
-- The problem with these definitions is that basis elements get converted
-- to values and then decomposed, followed by recombination of the
-- results.

-- | Map a function over an optional additive value.  'Nothing' (zero)
-- stays 'Nothing' without calling the function.  This is only faithful
-- when the function sends zero to zero (for example, when it is linear).
liftMS :: (a -> b) -> (MSum a -> MSum b)
-- liftMS _ Nothing = Nothing
-- liftMS h ma = Just (Sum (h (z ma)))

liftMS = fmap . fmap

-- | Combine two optional additive values with a binary function.
--
-- If both arguments are zero ('Nothing'), the result is 'Nothing' and
-- @h@ is not called.  Otherwise any missing argument is read as 'zeroV'
-- (via 'fromMS') and @h@ is applied.  This shortcut assumes
-- @h zeroV zeroV == zeroV@.
liftMS2 :: (AdditiveGroup a, AdditiveGroup b) =>
           (a -> b -> c) ->
           (MSum a -> MSum b -> MSum c)
liftMS2 _ Nothing Nothing = Nothing
liftMS2 h ma mb = Just (Sum (h (fromMS ma) (fromMS mb)))

-- | Combine three optional additive values with a ternary function.
--
-- If all three arguments are zero ('Nothing'), the result is 'Nothing'
-- and @h@ is not called.  Otherwise any missing argument is read as
-- 'zeroV' (via 'fromMS') and @h@ is applied.  This shortcut assumes
-- @h zeroV zeroV zeroV == zeroV@.
liftMS3 :: (AdditiveGroup a, AdditiveGroup b, AdditiveGroup c) =>
           (a -> b -> c -> d) ->
           (MSum a -> MSum b -> MSum c -> MSum d)
liftMS3 _ Nothing Nothing Nothing = Nothing
liftMS3 h ma mb mc =
  Just (Sum (h (fromMS ma) (fromMS mb) (fromMS mc)))

-- | Extract the value from an optional additive value, reading 'Nothing'
-- as 'zeroV'.
fromMS :: AdditiveGroup u => MSum u -> u
fromMS Nothing        = zeroV
fromMS (Just (Sum u)) = u


-- | Apply a linear function to each element inside an optional container,
-- such as the trie in a linear map's representation 'LMap''.
--
-- For linear @f@, @inLMap (liftL f) l@ agrees with @linear f *.* l@, but
-- it avoids building a trie for @linear f@ and applying it with 'lapply'.
liftL :: Functor f => (a -> b) -> MSum (f a) -> MSum (f b)
liftL = liftMS . fmap

-- | Apply a linear binary function (not to be confused with a bilinear
-- function) to each element of a linear map.
--
-- The containers are combined element-wise with 'liftA2'.  Zero handling
-- comes from 'liftMS2': two zeros give zero, and a single zero argument is
-- read as 'zeroV'.
liftL2 :: (Applicative f, AdditiveGroup (f a), AdditiveGroup (f b)) =>
          (a -> b -> c)
       -> (MSum (f a) -> MSum (f b) -> MSum (f c))
liftL2 = liftMS2 . liftA2

-- | Apply a linear ternary function (not to be confused with a trilinear
-- function) to each element of a linear map.
--
-- The containers are combined element-wise with 'liftA3'.  Zero handling
-- comes from 'liftMS3': three zeros give zero, and any missing argument is
-- read as 'zeroV'.
liftL3 :: ( Applicative f
          , AdditiveGroup (f a), AdditiveGroup (f b), AdditiveGroup (f c)) =>
          (a -> b -> c -> d)
       -> (MSum (f a) -> MSum (f b) -> MSum (f c) -> MSum (f d))
liftL3 = liftMS3 . liftA3

{-


infixr 9 *.*
-- | Compose linear maps
(*.*) :: ( HasBasis u, HasTrie (Basis u)
         , HasBasis v, HasTrie (Basis v)
         , VectorSpace w
         , Scalar v ~ Scalar w ) =>
         (v :-* w) -> (u :-* v) -> (u :-* w)

-- Simple definition, but only optimizes out uv == zero
--
-- (*.*) vw = (fmap.fmap) (lapply vw)

-- Instead, use Nothing/zero if /either/ map is zeroV (exploiting linearity
-- when uv == zeroV.)

-- Nothing       *.* _             = Nothing
-- _             *.* Nothing       = Nothing
-- Just (Sum vw) *.* Just (Sum uv) = Just (Sum (lapply' vw <$> uv))

-- (*.*) = liftA2 (\ (Sum vw) (Sum uv) -> Sum (lapply' vw <$> uv))

-- (*.*) = (liftA2.inSum2) (\ vw uv -> lapply' vw <$> uv)
(*.*) = (liftA2.inSum2) (\ vw uv -> lapply' vw <$> uv)

-- (*.*) = (liftA2.inSum2) (\ vw -> fmap (lapply' vw))

-- (*.*) = (liftA2.inSum2) (fmap . lapply')


-}

-----

-- | Pre- and post-compose: @(f ~> h) g = h . g . f@.  Adapts a function
-- @g@ to new input and output types, e.g. to work through newtype
-- wrappers.
(~>) :: (a' -> a) -> (b -> b') -> ((a -> b) -> (a' -> b'))
(f ~> h) g = h . g . f