]> gitweb.michael.orlitzky.com - numerical-analysis.git/commitdiff
Fix backward_substitute.
authorMichael Orlitzky <michael@orlitzky.com>
Tue, 11 Feb 2014 20:07:53 +0000 (15:07 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Tue, 11 Feb 2014 20:07:53 +0000 (15:07 -0500)
Add 'solve_positive_definite' to Linear.System.
Add tests for everything in Linear.System.

src/Linear/System.hs

index 2d75f611de1dc85825677ce59e72fe008bdbd21d..d68805a61c88ab78bf60cb04393d4e09c40604b2 100644 (file)
@@ -4,23 +4,36 @@
 
 module Linear.System (
   backward_substitute,
 
 module Linear.System (
   backward_substitute,
-  forward_substitute )
+  forward_substitute,
+  solve_positive_definite )
 where
 
 where
 
-import Data.Vector.Fixed ( Arity, N1 )
+import qualified Algebra.Algebraic as Algebraic ( C )
+import Data.Vector.Fixed ( Arity )
 import NumericPrelude hiding ( (*), abs )
 import qualified NumericPrelude as NP ( (*) )
 import qualified Algebra.Field as Field ( C )
 
 import NumericPrelude hiding ( (*), abs )
 import qualified NumericPrelude as NP ( (*) )
 import qualified Algebra.Field as Field ( C )
 
-import Linear.Matrix ( Mat(..), (!!!), construct, transpose )
+import Linear.Matrix (
+  Col,
+  Mat(..),
+  (!!!),
+  cholesky,
+  construct,
+  is_lower_triangular,
+  is_upper_triangular,
+  ncols,
+  transpose )
 
 
 
 
--- | Solve the system m' * x = b', where m' is upper-triangular. Will
+-- | Solve the system m' * x = b', where m' is lower-triangular. Will
 --   probably crash if m' is non-singular. The result is the vector x.
 --
 --   Examples:
 --
 --   probably crash if m' is non-singular. The result is the vector x.
 --
 --   Examples:
 --
---   >>> import Linear.Matrix ( Mat2, Mat3, fromList, vec2d, vec3d )
+--   >>> import Linear.Matrix ( Mat2, Mat3, frobenius_norm, fromList )
+--   >>> import Linear.Matrix ( vec2d, vec3d )
+--   >>> import Naturals ( N7 )
 --
 --   >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
 --   >>> let b = vec3d (1, 2, 3::Double)
 --
 --   >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
 --   >>> let b = vec3d (1, 2, 3::Double)
@@ -39,11 +52,44 @@ import Linear.Matrix ( Mat(..), (!!!), construct, transpose )
 --   >>> forward_substitute m b
 --   ((0.5),(0.75))
 --
 --   >>> forward_substitute m b
 --   ((0.5),(0.75))
 --
-forward_substitute :: forall a m. (Field.C a, Arity m)
+--   >>> let f1 = [0.0418]
+--   >>> let f2 = [0.0805]
+--   >>> let f3 = [0.1007]
+--   >>> let f4 = [-0.0045]
+--   >>> let f5 = [-0.0332]
+--   >>> let f6 = [-0.0054]
+--   >>> let f7 = [-0.0267]
+--   >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
+--   >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
+--   >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
+--   >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
+--   >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
+--   >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
+--   >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
+--   >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
+--   >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
+--   >>> let r = cholesky big_K
+--   >>> let rt = transpose r
+--   >>> let e1 = [0.0170647785413895] :: [Double]
+--   >>> let e2 = [0.0338] :: [Double]
+--   >>> let e3 = [0.07408] :: [Double]
+--   >>> let e4 = [-0.00183711730708738] :: [Double]
+--   >>> let e5 = [-0.0135538432434003] :: [Double]
+--   >>> let e6 = [-0.00220454076850486] :: [Double]
+--   >>> let e7 = [-0.00689391035624920] :: [Double]
+--   >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
+--   >>> let actual = forward_substitute rt big_F
+--   >>> frobenius_norm (actual - expected) < 1e-10
+--   True
+--
+forward_substitute :: forall a m. (Eq a, Field.C a, Arity m)
                    => Mat m m a
                    => Mat m m a
-                   -> Mat m N1 a
-                   -> Mat m N1 a
-forward_substitute m' b' = x'
+                   -> Col m a
+                   -> Col m a
+forward_substitute m' b'
+  | not (is_lower_triangular m') =
+      error "forward substitution on non-lower-triangular matrix"
+  | otherwise = x'
   where
     x' = construct lambda
 
   where
     x' = construct lambda
 
@@ -67,7 +113,7 @@ forward_substitute m' b' = x'
                                   j <- [0..k-1] ]) / (m k k)
 
 
                                   j <- [0..k-1] ]) / (m k k)
 
 
--- | Solve the system m*x = b, where m is lower-triangular. Will
+-- | Solve the system m*x = b, where m is upper-triangular. Will
 --   probably crash if m is non-singular. The result is the vector x.
 --
 --   Examples:
 --   probably crash if m is non-singular. The result is the vector x.
 --
 --   Examples:
@@ -81,21 +127,91 @@ forward_substitute m' b' = x'
 --   >>> (backward_substitute identity b) == b
 --   True
 --
 --   >>> (backward_substitute identity b) == b
 --   True
 --
-backward_substitute :: (Field.C a, Arity m)
+--   >>> let m1 = fromList [[1,1,1], [0,1,1], [0,0,1]] :: Mat3 Double
+--   >>> let b = vec3d (1,1,1::Double)
+--   >>> backward_substitute m1 b
+--   ((0.0),(0.0),(1.0))
+--
+backward_substitute :: forall m a. (Eq a, Field.C a, Arity m)
                     => Mat m m a
                     => Mat m m a
-                    -> Mat m N1 a
-                    -> Mat m N1 a
-backward_substitute m =
-  forward_substitute (transpose m)
+                    -> Col m a
+                    -> Col m a
+backward_substitute m' b'
+  | not (is_upper_triangular m') =
+      error "backward substitution on non-upper-triangular matrix"
+  | otherwise = x'
+    where
+      x' = construct lambda
+
+      -- Convenient accessor for the elements of b'.
+      b :: Int -> a
+      b k = b' !!! (k, 0)
+
+      -- Convenient accessor for the elements of m'.
+      m :: Int -> Int -> a
+      m i j = m' !!! (i, j)
+
+      -- Convenient accessor for the elements of x'.
+      x :: Int -> a
+      x k = x' !!! (k, 0)
+
+      -- The second argument to lambda should always be zero here, so we
+      -- ignore it.
+      lambda :: Int -> Int -> a
+      lambda k _
+        | k == n = (b k) / (m k k)
+        | otherwise = ((b k) - sum [ (m k j) NP.* (x j) |
+                                    j <- [k+1..n] ]) / (m k k)
+        where
+          n = (ncols m') - 1
 
 
 -- | Solve the linear system m*x = b where m is positive definite.
 
 
 -- | Solve the linear system m*x = b where m is positive definite.
-{-
-solve_positive_definite :: Mat v w a -> Mat w z a
+--
+--   Examples:
+--
+--   >>> import Linear.Matrix ( Col4, frobenius_norm, fromList )
+--   >>> import Naturals ( N7 )
+--
+--   >>> let f1 = [0.0418]
+--   >>> let f2 = [0.0805]
+--   >>> let f3 = [0.1007]
+--   >>> let f4 = [-0.0045]
+--   >>> let f5 = [-0.0332]
+--   >>> let f6 = [-0.0054]
+--   >>> let f7 = [-0.0267]
+--   >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
+--
+--   >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
+--   >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
+--   >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
+--   >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
+--   >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
+--   >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
+--   >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
+--   >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
+--
+--   >>> let e1 = [1871/75000] :: [Double]
+--   >>> let e2 = [899/25000] :: [Double]
+--   >>> let e3 = [463/15625] :: [Double]
+--   >>> let e4 = [-3/4000] :: [Double]
+--   >>> let e5 = [-83/15000] :: [Double]
+--   >>> let e6 = [-9/10000] :: [Double]
+--   >>> let e7 = [-89/50000] :: [Double]
+--   >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
+--   >>> let actual = solve_positive_definite big_K big_F
+--   >>> frobenius_norm (actual - expected) < 1e-12
+--   True
+--
+solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a)
+                        => Mat m m a
+                        -> Col m a
+                        -> Col m a
 solve_positive_definite m b = x
   where
     r = cholesky m
 solve_positive_definite m b = x
   where
     r = cholesky m
-    -- First we solve r^T * y == b for y. Then let y=r*x
-    rx = forward_substitute (transpose r) b
-    -- Now solve r*x == b.
--}
+    -- Now, r^T*r*x = b. Let r*x = y, so the system looks like
+    -- r^T * y = b. We can solve this for y.
+    y = forward_substitute (transpose r) b
+    -- Now solve r*x = y to find the value of x.
+    x = backward_substitute r y