X-Git-Url: http://gitweb.michael.orlitzky.com/?a=blobdiff_plain;f=src%2FLinear%2FSystem.hs;h=6d46a135c66def1b27bbdd44bb8ae15029d9eb92;hb=5c0366134e8e1c12772cb685ac14b70d22d6ffed;hp=d68805a61c88ab78bf60cb04393d4e09c40604b2;hpb=7221311858e4029c2f2d2de6bfdca2dd641548dc;p=numerical-analysis.git diff --git a/src/Linear/System.hs b/src/Linear/System.hs index d68805a..6d46a13 100644 --- a/src/Linear/System.hs +++ b/src/Linear/System.hs @@ -9,20 +9,23 @@ module Linear.System ( where import qualified Algebra.Algebraic as Algebraic ( C ) -import Data.Vector.Fixed ( Arity ) +import Data.Vector.Fixed ( Arity, S ) import NumericPrelude hiding ( (*), abs ) -import qualified NumericPrelude as NP ( (*) ) import qualified Algebra.Field as Field ( C ) import Linear.Matrix ( Col, Mat(..), - (!!!), cholesky, - construct, + diagonal, + dot, + ifoldl2, + ifoldr2, is_lower_triangular, is_upper_triangular, - ncols, + row, + set_idx, + zip2, transpose ) @@ -83,38 +86,29 @@ import Linear.Matrix ( -- True -- forward_substitute :: forall a m. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a -forward_substitute m' b' - | not (is_lower_triangular m') = + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a +forward_substitute matrix b + | not (is_lower_triangular matrix) = error "forward substitution on non-lower-triangular matrix" - | otherwise = x' - where - x' = construct lambda - - -- Convenient accessor for the elements of b'. - b :: Int -> a - b k = b' !!! (k, 0) - - -- Convenient accessor for the elements of m'. - m :: Int -> Int -> a - m i j = m' !!! (i, j) + | otherwise = ifoldl2 f zero pairs + where + -- Pairs (m_ii, b_i) needed at each step. + pairs :: Col (S m) (a,a) + pairs = zip2 (diagonal matrix) b - -- Convenient accessor for the elements of x'. - x :: Int -> a - x k = x' !!! (k, 0) + f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a + f i _ x (mii, bi) = set_idx x (i,0) newval + where + newval = (bi - (x `dot` (transpose $ row matrix i))) / mii - -- The second argument to lambda should always be zero here, so we - -- ignore it. - lambda :: Int -> Int -> a - lambda 0 _ = (b 0) / (m 0 0) - lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) | - j <- [0..k-1] ]) / (m k k) -- | Solve the system m*x = b, where m is upper-triangular. Will --- probably crash if m is non-singular. The result is the vector x. +-- probably crash if m is non-singular. The result is the vector +-- x. The implementation is identical to 'forward_substitute' except +-- with a right-fold. -- -- Examples: -- @@ -133,37 +127,22 @@ forward_substitute m' b' -- ((0.0),(0.0),(1.0)) -- backward_substitute :: forall m a. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a -backward_substitute m' b' - | not (is_upper_triangular m') = + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a +backward_substitute matrix b + | not (is_upper_triangular matrix) = error "backward substitution on non-upper-triangular matrix" - | otherwise = x' - where - x' = construct lambda - - -- Convenient accessor for the elements of b'. - b :: Int -> a - b k = b' !!! (k, 0) - - -- Convenient accessor for the elements of m'. - m :: Int -> Int -> a - m i j = m' !!! (i, j) - - -- Convenient accessor for the elements of x'. - x :: Int -> a - x k = x' !!! (k, 0) + | otherwise = ifoldr2 f zero pairs + where + -- Pairs (m_ii, b_i) needed at each step. + pairs :: Col (S m) (a,a) + pairs = zip2 (diagonal matrix) b - -- The second argument to lambda should always be zero here, so we - -- ignore it. - lambda :: Int -> Int -> a - lambda k _ - | k == n = (b k) / (m k k) - | otherwise = ((b k) - sum [ (m k j) NP.* (x j) | - j <- [k+1..n] ]) / (m k k) - where - n = (ncols m') - 1 + f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a + f i _ x (mii, bi) = set_idx x (i,0) newval + where + newval = (bi - (x `dot` (transpose $ row matrix i))) / mii -- | Solve the linear system m*x = b where m is positive definite. @@ -204,9 +183,9 @@ backward_substitute m' b' -- True -- solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a) - => Mat m m a - -> Col m a - -> Col m a + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a solve_positive_definite m b = x where r = cholesky m