]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/System.hs
2d75f611de1dc85825677ce59e72fe008bdbd21d
[numerical-analysis.git] / src / Linear / System.hs
1 {-# LANGUAGE RebindableSyntax #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE TypeFamilies #-}
4
5 module Linear.System (
6 backward_substitute,
7 forward_substitute )
8 where
9
import Data.Vector.Fixed ( Arity, N1 )
import Data.Vector.Fixed.Cont ( arity )
import NumericPrelude hiding ( (*), abs )
import qualified NumericPrelude as NP ( (*) )
import qualified Algebra.Field as Field ( C )

import Linear.Matrix ( Mat(..), (!!!), construct, transpose )
16
17
-- | Solve the system m' * x = b', where m' is lower-triangular, using
-- forward substitution. The result is the vector x.
--
-- Only the diagonal of m' and the entries below it are read; anything
-- above the diagonal is ignored. Entry k of x is computed from entries
-- 0..k-1 of x, so the solution is built from the top down.
--
-- If m' is singular (for a triangular matrix: it has a zero on its
-- diagonal), the computation divides by zero. For 'Double' this yields
-- infinities/NaNs; other field types may throw an exception instead.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, Mat3, fromList, vec2d, vec3d )
--
-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
-- >>> let b = vec3d (1, 2, 3::Double)
-- >>> forward_substitute identity b
-- ((1.0),(2.0),(3.0))
-- >>> (forward_substitute identity b) == b
-- True
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Double
-- >>> let b = vec2d (1, 1::Double)
-- >>> forward_substitute m b
-- ((1.0),(0.0))
--
-- >>> let m = fromList [[4,0],[0,2]] :: Mat2 Double
-- >>> let b = vec2d (2, 1.5 :: Double)
-- >>> forward_substitute m b
-- ((0.5),(0.75))
--
forward_substitute :: forall a m. (Field.C a, Arity m)
                   => Mat m m a
                   -> Mat m N1 a
                   -> Mat m N1 a
forward_substitute m' b' = x'
  where
    -- x' is defined in terms of its own earlier entries (through x),
    -- so this needs 'construct' to be lazy in the elements it builds.
    x' = construct lambda

    -- Convenient accessor for the elements of b'.
    b :: Int -> a
    b k = b' !!! (k, 0)

    -- Convenient accessor for the elements of m'.
    m :: Int -> Int -> a
    m i j = m' !!! (i, j)

    -- Convenient accessor for the elements of x'.
    x :: Int -> a
    x k = x' !!! (k, 0)

    -- The second argument to lambda is the column index, which is
    -- always zero for an (m x 1) column vector, so we ignore it. For
    -- k = 0 the sum is empty (zero), which gives x_0 = b_0 / m_00.
    lambda :: Int -> Int -> a
    lambda k _ = (b k - sum [ m k j NP.* x j | j <- [0..k-1] ]) / m k k
68
69
-- | Solve the system m' * x = b', where m' is upper-triangular, using
-- backward substitution. The result is the vector x.
--
-- Only the diagonal of m' and the entries above it are read; anything
-- below the diagonal is ignored. Entry k of x is computed from entries
-- k+1..n-1 of x, so the solution is built from the bottom up.
--
-- If m' is singular (for a triangular matrix: it has a zero on its
-- diagonal), the computation divides by zero. For 'Double' this yields
-- infinities/NaNs; other field types may throw an exception instead.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, Mat3, fromList, vec2d, vec3d )
--
-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
-- >>> let b = vec3d (1, 2, 3::Double)
-- >>> backward_substitute identity b
-- ((1.0),(2.0),(3.0))
-- >>> (backward_substitute identity b) == b
-- True
--
-- >>> let m = fromList [[1,1],[0,1]] :: Mat2 Double
-- >>> let b = vec2d (2, 1::Double)
-- >>> backward_substitute m b
-- ((1.0),(1.0))
--
-- >>> let m = fromList [[2,1],[0,4]] :: Mat2 Double
-- >>> let b = vec2d (2, 2 :: Double)
-- >>> backward_substitute m b
-- ((0.75),(0.5))
--
backward_substitute :: forall a m. (Field.C a, Arity m)
                    => Mat m m a
                    -> Mat m N1 a
                    -> Mat m N1 a
backward_substitute m' b' = x'
  where
    -- Note: this is NOT forward substitution on the transpose; that
    -- would solve (transpose m') * x = b', a different system.
    --
    -- x' is defined in terms of its own later entries (through x), so
    -- this needs 'construct' to be lazy in the elements it builds.
    x' = construct lambda

    -- The dimension of the (square) matrix m'.
    n :: Int
    n = arity (undefined :: m)

    -- Convenient accessor for the elements of b'.
    b :: Int -> a
    b k = b' !!! (k, 0)

    -- Convenient accessor for the elements of m'.
    m :: Int -> Int -> a
    m i j = m' !!! (i, j)

    -- Convenient accessor for the elements of x'.
    x :: Int -> a
    x k = x' !!! (k, 0)

    -- The second argument to lambda is the column index, which is
    -- always zero for an (m x 1) column vector, so we ignore it. For
    -- the last row the sum is empty (zero), which gives
    -- x_{n-1} = b_{n-1} / m_{n-1,n-1}.
    lambda :: Int -> Int -> a
    lambda k _ = (b k - sum [ m k j NP.* x j | j <- [k+1..n-1] ]) / m k k
90
91
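{-
-- | Solve the linear system m*x = b where m is positive definite.
--
-- Sketch (not yet implemented): take the Cholesky factorization
-- m = r^T * r with r upper-triangular, solve r^T * y = b for y by
-- forward substitution, and then solve r * x = y for x by backward
-- substitution.
solve_positive_definite :: (Field.C a, Arity m)
                        => Mat m m a
                        -> Mat m N1 a
                        -> Mat m N1 a
solve_positive_definite m b = x
  where
    r = cholesky m
    -- First we solve r^T * y == b for y, where y = r*x.
    rx = forward_substitute (transpose r) b
    -- Now solve r*x == y.
    x = backward_substitute r rx
-}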