]> gitweb.michael.orlitzky.com - numerical-analysis.git/commitdiff
Use a better implementation of backwards_substitute.
authorMichael Orlitzky <michael@orlitzky.com>
Wed, 12 Feb 2014 01:14:47 +0000 (20:14 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Wed, 12 Feb 2014 01:14:47 +0000 (20:14 -0500)
src/Linear/System.hs

index 82e1e4b68282cd56f74732e38ee7107aa16c1cc4..6d46a135c66def1b27bbdd44bb8ae15029d9eb92 100644 (file)
@@ -11,21 +11,18 @@ where
 import qualified Algebra.Algebraic as Algebraic ( C )
 import Data.Vector.Fixed ( Arity, S )
 import NumericPrelude hiding ( (*), abs )
-import qualified NumericPrelude as NP ( (*) )
 import qualified Algebra.Field as Field ( C )
 
 import Linear.Matrix (
   Col,
   Mat(..),
-  (!!!),
   cholesky,
-  construct,
   diagonal,
   dot,
   ifoldl2,
+  ifoldr2,
   is_lower_triangular,
   is_upper_triangular,
-  ncols,
   row,
   set_idx,
   zip2,
@@ -109,7 +106,9 @@ forward_substitute matrix b
 
 
 -- | Solve the system m*x = b, where m is upper-triangular. Will
---   probably crash if m is non-singular. The result is the vector x.
+--   probably crash if m is non-singular. The result is the vector
+--   x. The implementation is identical to 'forward_substitute' except
+--   with a right-fold.
 --
 --   Examples:
 --
@@ -131,34 +130,19 @@ backward_substitute :: forall m a. (Eq a, Field.C a, Arity m)
                     => Mat (S m) (S m) a
                     -> Col (S m) a
                     -> Col (S m) a
-backward_substitute m' b'
-  | not (is_upper_triangular m') =
+backward_substitute matrix b
+  | not (is_upper_triangular matrix) =
       error "backward substitution on non-upper-triangular matrix"
-  | otherwise = x'
-    where
-      x' = construct lambda
-
-      -- Convenient accessor for the elements of b'.
-      b :: Int -> a
-      b k = b' !!! (k, 0)
-
-      -- Convenient accessor for the elements of m'.
-      m :: Int -> Int -> a
-      m i j = m' !!! (i, j)
-
-      -- Convenient accessor for the elements of x'.
-      x :: Int -> a
-      x k = x' !!! (k, 0)
+  | otherwise = ifoldr2 f zero pairs
+      where
+        -- Pairs (m_ii, b_i) needed at each step.
+        pairs :: Col (S m) (a,a)
+        pairs = zip2 (diagonal matrix) b
 
-      -- The second argument to lambda should always be zero here, so we
-      -- ignore it.
-      lambda :: Int -> Int -> a
-      lambda k _
-        | k == n = (b k) / (m k k)
-        | otherwise = ((b k) - sum [ (m k j) NP.* (x j) |
-                                    j <- [k+1..n] ]) / (m k k)
-        where
-          n = (ncols m') - 1
+        f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a
+        f i _ x (mii, bi) = set_idx x (i,0) newval
+          where
+            newval = (bi - (x `dot` (transpose $ row matrix i))) / mii
 
 
 -- | Solve the linear system m*x = b where m is positive definite.