X-Git-Url: http://gitweb.michael.orlitzky.com/?p=numerical-analysis.git;a=blobdiff_plain;f=src%2FLinear%2FSystem.hs;h=82e1e4b68282cd56f74732e38ee7107aa16c1cc4;hp=d68805a61c88ab78bf60cb04393d4e09c40604b2;hb=64e43e504a8716cb1784de5fc33d7f02e915e2ac;hpb=3c226d5d5ceb0781b10d86dcb958846f1cc9b075 diff --git a/src/Linear/System.hs b/src/Linear/System.hs index d68805a..82e1e4b 100644 --- a/src/Linear/System.hs +++ b/src/Linear/System.hs @@ -9,7 +9,7 @@ module Linear.System ( where import qualified Algebra.Algebraic as Algebraic ( C ) -import Data.Vector.Fixed ( Arity ) +import Data.Vector.Fixed ( Arity, S ) import NumericPrelude hiding ( (*), abs ) import qualified NumericPrelude as NP ( (*) ) import qualified Algebra.Field as Field ( C ) @@ -20,9 +20,15 @@ import Linear.Matrix ( (!!!), cholesky, construct, + diagonal, + dot, + ifoldl2, is_lower_triangular, is_upper_triangular, ncols, + row, + set_idx, + zip2, transpose ) @@ -83,34 +89,23 @@ import Linear.Matrix ( -- True -- forward_substitute :: forall a m. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a -forward_substitute m' b' - | not (is_lower_triangular m') = + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a +forward_substitute matrix b + | not (is_lower_triangular matrix) = error "forward substitution on non-lower-triangular matrix" - | otherwise = x' - where - x' = construct lambda - - -- Convenient accessor for the elements of b'. - b :: Int -> a - b k = b' !!! (k, 0) - - -- Convenient accessor for the elements of m'. - m :: Int -> Int -> a - m i j = m' !!! (i, j) + | otherwise = ifoldl2 f zero pairs + where + -- Pairs (m_ii, b_i) needed at each step. + pairs :: Col (S m) (a,a) + pairs = zip2 (diagonal matrix) b - -- Convenient accessor for the elements of x'. - x :: Int -> a - x k = x' !!! (k, 0) + f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a + f i _ x (mii, bi) = set_idx x (i,0) newval + where + newval = (bi - (x `dot` (transpose $ row matrix i))) / mii - -- The second argument to lambda should always be zero here, so we - -- ignore it. - lambda :: Int -> Int -> a - lambda 0 _ = (b 0) / (m 0 0) - lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) | - j <- [0..k-1] ]) / (m k k) -- | Solve the system m*x = b, where m is upper-triangular. Will @@ -133,9 +128,9 @@ forward_substitute m' b' -- ((0.0),(0.0),(1.0)) -- backward_substitute :: forall m a. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a backward_substitute m' b' | not (is_upper_triangular m') = error "backward substitution on non-upper-triangular matrix" @@ -204,9 +199,9 @@ backward_substitute m' b' -- True -- solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a) - => Mat m m a - -> Col m a - -> Col m a + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a solve_positive_definite m b = x where r = cholesky m