From 7304f41e81fe97d40afe18b8215fb00a58702502 Mon Sep 17 00:00:00 2001 From: Michael Orlitzky Date: Sat, 23 Feb 2013 19:05:10 -0500 Subject: [PATCH] Add the Linear.System module. --- src/Linear/System.hs | 104 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 src/Linear/System.hs diff --git a/src/Linear/System.hs b/src/Linear/System.hs new file mode 100644 index 0000000..e0fdf1c --- /dev/null +++ b/src/Linear/System.hs @@ -0,0 +1,104 @@ +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +module Linear.System +where + +import Data.Vector.Fixed (Dim, N1, Vector) + +import Linear.Matrix + +import NumericPrelude hiding ((*), abs) +import qualified NumericPrelude as NP ((*)) +import qualified Algebra.Field as Field + +import Debug.Trace (trace, traceShow) + +-- | Solve the system m' * x = b', where m' is upper-triangular. Will +-- probably crash if m' is non-singular. The result is the vector x. +-- +-- Examples: +-- +-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double +-- >>> let b = vec3d (1,2,3) +-- >>> forward_substitute identity b +-- ((1.0),(2.0),(3.0)) +-- >>> (forward_substitute identity b) == b +-- True +-- +-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Double +-- >>> let b = vec2d (1,1) +-- >>> forward_substitute m b +-- ((1.0),(0.0)) +-- +forward_substitute :: forall a v w z. + (Show a, Field.C a, + Vector z a, + Vector w (z a), + Vector w a, + Dim z ~ N1, + v ~ w) + => Mat v w a + -> Mat w z a + -> Mat w z a +forward_substitute m' b' = x' + where + x' = construct lambda + + -- Convenient accessor for the elements of b'. + b :: Int -> a + b k = b' !!! (k, 0) + + -- Convenient accessor for the elements of m'. + m :: Int -> Int -> a + m i j = m' !!! (i, j) + + -- Convenient accessor for the elements of x'. + x :: Int -> a + x k = x' !!! (k, 0) + + -- The second argument to lambda should always be zero here, so we + -- ignore it. + lambda :: Int -> Int -> a + lambda 0 _ = (b 0) / (m 0 0) + lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) | + j <- [0..k-1] ]) / (m k k) + + +-- | Solve the system m*x = b, where m is lower-triangular. Will +-- probably crash if m is non-singular. The result is the vector x. +-- +-- Examples: +-- +-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double +-- >>> let b = vec3d (1,2,3) +-- >>> backward_substitute identity b +-- ((1.0),(2.0),(3.0)) +-- >>> (backward_substitute identity b) == b +-- True +-- +backward_substitute :: (Show a, Field.C a, + Vector z a, + Vector v (w a), + Vector w (z a), + Vector w a, + Dim z ~ N1, + v ~ w) + => Mat v w a + -> Mat w z a + -> Mat w z a +backward_substitute m b = + forward_substitute (transpose m) b + + +-- | Solve the linear system m*x = b where m is positive definite. +{- +solve_positive_definite :: Mat v w a -> Mat w z a +solve_positive_definite m b = x + where + r = cholesky m + -- First we solve r^T * y == b for y. Then let y=r*x + rx = forward_substitute (transpose r) b + -- Now solve r*x == b. +-} -- 2.44.2