]> gitweb.michael.orlitzky.com - numerical-analysis.git/commitdiff
Implement forward substitute in terms of a fold.
authorMichael Orlitzky <michael@orlitzky.com>
Tue, 11 Feb 2014 21:10:53 +0000 (16:10 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Tue, 11 Feb 2014 21:10:53 +0000 (16:10 -0500)
src/Linear/Iteration.hs
src/Linear/System.hs

index 6fb279c78bbb64fc5122070c82bce6f956e16a14..6117770a52313c21546ed57e729854a7d1fc5f8c 100644 (file)
@@ -27,6 +27,7 @@ import qualified Algebra.RealField as RealField ( C )
 import qualified Algebra.ToRational as ToRational ( C )
 
 import Linear.Matrix (
+  Col,
   Mat(..),
   (!!!),
   (*),
@@ -41,12 +42,12 @@ import Normed ( Normed(..) )
 -- | A generalized implementation for Jacobi, Gauss-Seidel, etc. All
 --   that we really need to know is how to construct the matrix M, so we
 --   take a function that does it as an argument.
-classical_iteration :: (Eq a, Field.C a, Arity m)
-                 => (Mat m m  a -> Mat m m  a)
-                 -> Mat m m  a
-                 -> Mat m N1 a
-                 -> Mat m N1 a
-                 -> Mat m N1 a
+classical_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
+                 => (Mat m m a -> Mat m m a)
+                 -> Mat m m a
+                 -> Col m a
+                 -> Col m a
+                 -> Col m a
 classical_iteration m_function matrix b x_current =
   x_next
   where
@@ -59,8 +60,8 @@ classical_iteration m_function matrix b x_current =
 
 -- | Perform one iteration of successive over-relaxation.
 --
-sor_iteration :: forall m a.
-                 (Eq a, Field.C a, Arity m)
+sor_iteration :: forall m a.
+                 (Eq a, Field.C a, m ~ S n, Arity n)
               => a -- ^ Omega
               -> Mat m m  a -- ^ Matrix A
               -> Mat m N1 a -- ^ Vector b
@@ -79,7 +80,7 @@ sor_iteration omega =
 
 -- | Compute an infinite list of SOR iterations starting with the
 --   vector x0.
-sor_iterations :: (Eq a, Field.C a, Arity m)
+sor_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
                => a
                -> Mat m m  a
                -> Mat m N1 a
@@ -90,7 +91,7 @@ sor_iterations omega matrix b =
 
 
 -- | Perform one iteration of Gauss-Seidel.
-gauss_seidel_iteration :: (Eq a, Field.C a, Arity m)
+gauss_seidel_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
                        => Mat m m  a
                        -> Mat m N1 a
                        -> Mat m N1 a
@@ -100,7 +101,7 @@ gauss_seidel_iteration = sor_iteration one
 
 -- | Compute an infinite list of Gauss-Seidel iterations starting with
 --   the vector x0.
-gauss_seidel_iterations :: (Eq a, Field.C a, Arity m)
+gauss_seidel_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
                         => Mat m m  a
                         -> Mat m N1 a
                         -> Mat m N1 a
@@ -126,7 +127,7 @@ gauss_seidel_iterations matrix b =
 --   >>> jacobi_iteration m b x1
 --   ((0.0),(0.25))
 --
-jacobi_iteration :: (Eq a, Field.C a, Arity m)
+jacobi_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
                  => Mat m m  a
                  -> Mat m N1 a
                  -> Mat m N1 a
@@ -137,7 +138,7 @@ jacobi_iteration =
 
 -- | Compute an infinite list of Jacobi iterations starting with the
 --   vector x0.
-jacobi_iterations :: (Eq a, Field.C a, Arity m)
+jacobi_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
                   => Mat m m  a
                   -> Mat m N1 a
                   -> Mat m N1 a
index d68805a61c88ab78bf60cb04393d4e09c40604b2..82e1e4b68282cd56f74732e38ee7107aa16c1cc4 100644 (file)
@@ -9,7 +9,7 @@ module Linear.System (
 where
 
 import qualified Algebra.Algebraic as Algebraic ( C )
-import Data.Vector.Fixed ( Arity )
+import Data.Vector.Fixed ( Arity, S )
 import NumericPrelude hiding ( (*), abs )
 import qualified NumericPrelude as NP ( (*) )
 import qualified Algebra.Field as Field ( C )
@@ -20,9 +20,15 @@ import Linear.Matrix (
   (!!!),
   cholesky,
   construct,
+  diagonal,
+  dot,
+  ifoldl2,
   is_lower_triangular,
   is_upper_triangular,
   ncols,
+  row,
+  set_idx,
+  zip2,
   transpose )
 
 
@@ -83,34 +89,23 @@ import Linear.Matrix (
 --   True
 --
 forward_substitute :: forall a m. (Eq a, Field.C a, Arity m)
-                   => Mat m m a
-                   -> Col m a
-                   -> Col m a
-forward_substitute m' b'
-  | not (is_lower_triangular m') =
+                   => Mat (S m) (S m) a
+                   -> Col (S m) a
+                   -> Col (S m) a
+forward_substitute matrix b
+  | not (is_lower_triangular matrix) =
       error "forward substitution on non-lower-triangular matrix"
-  | otherwise = x'
-  where
-    x' = construct lambda
-
-    -- Convenient accessor for the elements of b'.
-    b :: Int -> a
-    b k = b' !!! (k, 0)
-
-    -- Convenient accessor for the elements of m'.
-    m :: Int -> Int -> a
-    m i j = m' !!! (i, j)
+  | otherwise = ifoldl2 f zero pairs
+      where
+        -- Pairs (m_ii, b_i) needed at each step.
+        pairs :: Col (S m) (a,a)
+        pairs = zip2 (diagonal matrix) b
 
-    -- Convenient accessor for the elements of x'.
-    x :: Int -> a
-    x k = x' !!! (k, 0)
+        f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a
+        f i _ x (mii, bi) = set_idx x (i,0) newval
+          where
+            newval = (bi - (x `dot` (transpose $ row matrix i))) / mii
 
-    -- The second argument to lambda should always be zero here, so we
-    -- ignore it.
-    lambda :: Int -> Int -> a
-    lambda 0 _ = (b 0) / (m 0 0)
-    lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) |
-                                  j <- [0..k-1] ]) / (m k k)
 
 
 -- | Solve the system m*x = b, where m is upper-triangular. Will
@@ -133,9 +128,9 @@ forward_substitute m' b'
 --   ((0.0),(0.0),(1.0))
 --
 backward_substitute :: forall m a. (Eq a, Field.C a, Arity m)
-                    => Mat m m a
-                    -> Col m a
-                    -> Col m a
+                    => Mat (S m) (S m) a
+                    -> Col (S m) a
+                    -> Col (S m) a
 backward_substitute m' b'
   | not (is_upper_triangular m') =
       error "backward substitution on non-upper-triangular matrix"
@@ -204,9 +199,9 @@ backward_substitute m' b'
 --   True
 --
 solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a)
-                        => Mat m m a
-                        -> Col m a
-                        -> Col m a
+                        => Mat (S m) (S m) a
+                        -> Col (S m) a
+                        -> Col (S m) a
 solve_positive_definite m b = x
   where
     r = cholesky m