From 64e43e504a8716cb1784de5fc33d7f02e915e2ac Mon Sep 17 00:00:00 2001 From: Michael Orlitzky Date: Tue, 11 Feb 2014 16:10:53 -0500 Subject: [PATCH] Implement forward substitute in terms of a fold. --- src/Linear/Iteration.hs | 27 ++++++++++--------- src/Linear/System.hs | 59 +++++++++++++++++++---------------------- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/src/Linear/Iteration.hs b/src/Linear/Iteration.hs index 6fb279c..6117770 100644 --- a/src/Linear/Iteration.hs +++ b/src/Linear/Iteration.hs @@ -27,6 +27,7 @@ import qualified Algebra.RealField as RealField ( C ) import qualified Algebra.ToRational as ToRational ( C ) import Linear.Matrix ( + Col, Mat(..), (!!!), (*), @@ -41,12 +42,12 @@ import Normed ( Normed(..) ) -- | A generalized implementation for Jacobi, Gauss-Seidel, etc. All -- that we really need to know is how to construct the matrix M, so we -- take a function that does it as an argument. -classical_iteration :: (Eq a, Field.C a, Arity m) - => (Mat m m a -> Mat m m a) - -> Mat m m a - -> Mat m N1 a - -> Mat m N1 a - -> Mat m N1 a +classical_iteration :: (Eq a, Field.C a, m ~ S n, Arity n) + => (Mat m m a -> Mat m m a) + -> Mat m m a + -> Col m a + -> Col m a + -> Col m a classical_iteration m_function matrix b x_current = x_next where @@ -59,8 +60,8 @@ classical_iteration m_function matrix b x_current = -- | Perform one iteration of successive over-relaxation. -- -sor_iteration :: forall m a. - (Eq a, Field.C a, Arity m) +sor_iteration :: forall m n a. + (Eq a, Field.C a, m ~ S n, Arity n) => a -- ^ Omega -> Mat m m a -- ^ Matrix A -> Mat m N1 a -- ^ Vector b @@ -79,7 +80,7 @@ sor_iteration omega = -- | Compute an infinite list of SOR iterations starting with the -- vector x0. -sor_iterations :: (Eq a, Field.C a, Arity m) +sor_iterations :: (Eq a, Field.C a, m ~ S n, Arity n) => a -> Mat m m a -> Mat m N1 a @@ -90,7 +91,7 @@ sor_iterations omega matrix b = -- | Perform one iteration of Gauss-Seidel. -gauss_seidel_iteration :: (Eq a, Field.C a, Arity m) +gauss_seidel_iteration :: (Eq a, Field.C a, m ~ S n, Arity n) => Mat m m a -> Mat m N1 a -> Mat m N1 a @@ -100,7 +101,7 @@ gauss_seidel_iteration = sor_iteration one -- | Compute an infinite list of Gauss-Seidel iterations starting with -- the vector x0. -gauss_seidel_iterations :: (Eq a, Field.C a, Arity m) +gauss_seidel_iterations :: (Eq a, Field.C a, m ~ S n, Arity n) => Mat m m a -> Mat m N1 a -> Mat m N1 a @@ -126,7 +127,7 @@ gauss_seidel_iterations matrix b = -- >>> jacobi_iteration m b x1 -- ((0.0),(0.25)) -- -jacobi_iteration :: (Eq a, Field.C a, Arity m) +jacobi_iteration :: (Eq a, Field.C a, m ~ S n, Arity n) => Mat m m a -> Mat m N1 a -> Mat m N1 a @@ -137,7 +138,7 @@ jacobi_iteration = -- | Compute an infinite list of Jacobi iterations starting with the -- vector x0. -jacobi_iterations :: (Eq a, Field.C a, Arity m) +jacobi_iterations :: (Eq a, Field.C a, m ~ S n, Arity n) => Mat m m a -> Mat m N1 a -> Mat m N1 a diff --git a/src/Linear/System.hs b/src/Linear/System.hs index d68805a..82e1e4b 100644 --- a/src/Linear/System.hs +++ b/src/Linear/System.hs @@ -9,7 +9,7 @@ module Linear.System ( where import qualified Algebra.Algebraic as Algebraic ( C ) -import Data.Vector.Fixed ( Arity ) +import Data.Vector.Fixed ( Arity, S ) import NumericPrelude hiding ( (*), abs ) import qualified NumericPrelude as NP ( (*) ) import qualified Algebra.Field as Field ( C ) @@ -20,9 +20,15 @@ import Linear.Matrix ( (!!!), cholesky, construct, + diagonal, + dot, + ifoldl2, is_lower_triangular, is_upper_triangular, ncols, + row, + set_idx, + zip2, transpose ) @@ -83,34 +89,23 @@ import Linear.Matrix ( -- True -- forward_substitute :: forall a m. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a -forward_substitute m' b' - | not (is_lower_triangular m') = + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a +forward_substitute matrix b + | not (is_lower_triangular matrix) = error "forward substitution on non-lower-triangular matrix" - | otherwise = x' - where - x' = construct lambda - - -- Convenient accessor for the elements of b'. - b :: Int -> a - b k = b' !!! (k, 0) - - -- Convenient accessor for the elements of m'. - m :: Int -> Int -> a - m i j = m' !!! (i, j) + | otherwise = ifoldl2 f zero pairs + where + -- Pairs (m_ii, b_i) needed at each step. + pairs :: Col (S m) (a,a) + pairs = zip2 (diagonal matrix) b - -- Convenient accessor for the elements of x'. - x :: Int -> a - x k = x' !!! (k, 0) + f :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a + f i _ x (mii, bi) = set_idx x (i,0) newval + where + newval = (bi - (x `dot` (transpose $ row matrix i))) / mii - -- The second argument to lambda should always be zero here, so we - -- ignore it. - lambda :: Int -> Int -> a - lambda 0 _ = (b 0) / (m 0 0) - lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) | - j <- [0..k-1] ]) / (m k k) -- | Solve the system m*x = b, where m is upper-triangular. Will @@ -133,9 +128,9 @@ forward_substitute m' b' -- ((0.0),(0.0),(1.0)) -- backward_substitute :: forall m a. (Eq a, Field.C a, Arity m) - => Mat m m a - -> Col m a - -> Col m a + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a backward_substitute m' b' | not (is_upper_triangular m') = error "backward substitution on non-upper-triangular matrix" @@ -204,9 +199,9 @@ backward_substitute m' b' -- True -- solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a) - => Mat m m a - -> Col m a - -> Col m a + => Mat (S m) (S m) a + -> Col (S m) a + -> Col (S m) a solve_positive_definite m b = x where r = cholesky m -- 2.43.2