]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/System.hs
Fix backward_substitute.
[numerical-analysis.git] / src / Linear / System.hs
1 {-# LANGUAGE RebindableSyntax #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE TypeFamilies #-}
4
5 module Linear.System (
6 backward_substitute,
7 forward_substitute,
8 solve_positive_definite )
9 where
10
11 import qualified Algebra.Algebraic as Algebraic ( C )
12 import Data.Vector.Fixed ( Arity )
13 import NumericPrelude hiding ( (*), abs )
14 import qualified NumericPrelude as NP ( (*) )
15 import qualified Algebra.Field as Field ( C )
16
17 import Linear.Matrix (
18 Col,
19 Mat(..),
20 (!!!),
21 cholesky,
22 construct,
23 is_lower_triangular,
24 is_upper_triangular,
25 ncols,
26 transpose )
27
28
29 -- | Solve the system m' * x = b', where m' is lower-triangular. Will
30 -- probably crash if m' is non-singular. The result is the vector x.
31 --
32 -- Examples:
33 --
34 -- >>> import Linear.Matrix ( Mat2, Mat3, frobenius_norm, fromList )
35 -- >>> import Linear.Matrix ( vec2d, vec3d )
36 -- >>> import Naturals ( N7 )
37 --
38 -- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
39 -- >>> let b = vec3d (1, 2, 3::Double)
40 -- >>> forward_substitute identity b
41 -- ((1.0),(2.0),(3.0))
42 -- >>> (forward_substitute identity b) == b
43 -- True
44 --
45 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Double
46 -- >>> let b = vec2d (1, 1::Double)
47 -- >>> forward_substitute m b
48 -- ((1.0),(0.0))
49 --
50 -- >>> let m = fromList [[4,0],[0,2]] :: Mat2 Double
51 -- >>> let b = vec2d (2, 1.5 :: Double)
52 -- >>> forward_substitute m b
53 -- ((0.5),(0.75))
54 --
55 -- >>> let f1 = [0.0418]
56 -- >>> let f2 = [0.0805]
57 -- >>> let f3 = [0.1007]
58 -- >>> let f4 = [-0.0045]
59 -- >>> let f5 = [-0.0332]
60 -- >>> let f6 = [-0.0054]
61 -- >>> let f7 = [-0.0267]
62 -- >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
63 -- >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
64 -- >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
65 -- >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
66 -- >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
67 -- >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
68 -- >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
69 -- >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
70 -- >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
71 -- >>> let r = cholesky big_K
72 -- >>> let rt = transpose r
73 -- >>> let e1 = [0.0170647785413895] :: [Double]
74 -- >>> let e2 = [0.0338] :: [Double]
75 -- >>> let e3 = [0.07408] :: [Double]
76 -- >>> let e4 = [-0.00183711730708738] :: [Double]
77 -- >>> let e5 = [-0.0135538432434003] :: [Double]
78 -- >>> let e6 = [-0.00220454076850486] :: [Double]
79 -- >>> let e7 = [-0.00689391035624920] :: [Double]
80 -- >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
81 -- >>> let actual = forward_substitute rt big_F
82 -- >>> frobenius_norm (actual - expected) < 1e-10
83 -- True
84 --
85 forward_substitute :: forall a m. (Eq a, Field.C a, Arity m)
86 => Mat m m a
87 -> Col m a
88 -> Col m a
89 forward_substitute m' b'
90 | not (is_lower_triangular m') =
91 error "forward substitution on non-lower-triangular matrix"
92 | otherwise = x'
93 where
94 x' = construct lambda
95
96 -- Convenient accessor for the elements of b'.
97 b :: Int -> a
98 b k = b' !!! (k, 0)
99
100 -- Convenient accessor for the elements of m'.
101 m :: Int -> Int -> a
102 m i j = m' !!! (i, j)
103
104 -- Convenient accessor for the elements of x'.
105 x :: Int -> a
106 x k = x' !!! (k, 0)
107
108 -- The second argument to lambda should always be zero here, so we
109 -- ignore it.
110 lambda :: Int -> Int -> a
111 lambda 0 _ = (b 0) / (m 0 0)
112 lambda k _ = ((b k) - sum [ (m k j) NP.* (x j) |
113 j <- [0..k-1] ]) / (m k k)
114
115
116 -- | Solve the system m*x = b, where m is upper-triangular. Will
117 -- probably crash if m is non-singular. The result is the vector x.
118 --
119 -- Examples:
120 --
121 -- >>> import Linear.Matrix ( Mat3, fromList, vec3d )
122 --
123 -- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
124 -- >>> let b = vec3d (1, 2, 3::Double)
125 -- >>> backward_substitute identity b
126 -- ((1.0),(2.0),(3.0))
127 -- >>> (backward_substitute identity b) == b
128 -- True
129 --
130 -- >>> let m1 = fromList [[1,1,1], [0,1,1], [0,0,1]] :: Mat3 Double
131 -- >>> let b = vec3d (1,1,1::Double)
132 -- >>> backward_substitute m1 b
133 -- ((0.0),(0.0),(1.0))
134 --
135 backward_substitute :: forall m a. (Eq a, Field.C a, Arity m)
136 => Mat m m a
137 -> Col m a
138 -> Col m a
139 backward_substitute m' b'
140 | not (is_upper_triangular m') =
141 error "backward substitution on non-upper-triangular matrix"
142 | otherwise = x'
143 where
144 x' = construct lambda
145
146 -- Convenient accessor for the elements of b'.
147 b :: Int -> a
148 b k = b' !!! (k, 0)
149
150 -- Convenient accessor for the elements of m'.
151 m :: Int -> Int -> a
152 m i j = m' !!! (i, j)
153
154 -- Convenient accessor for the elements of x'.
155 x :: Int -> a
156 x k = x' !!! (k, 0)
157
158 -- The second argument to lambda should always be zero here, so we
159 -- ignore it.
160 lambda :: Int -> Int -> a
161 lambda k _
162 | k == n = (b k) / (m k k)
163 | otherwise = ((b k) - sum [ (m k j) NP.* (x j) |
164 j <- [k+1..n] ]) / (m k k)
165 where
166 n = (ncols m') - 1
167
168
169 -- | Solve the linear system m*x = b where m is positive definite.
170 --
171 -- Examples:
172 --
173 -- >>> import Linear.Matrix ( Col4, frobenius_norm, fromList )
174 -- >>> import Naturals ( N7 )
175 --
176 -- >>> let f1 = [0.0418]
177 -- >>> let f2 = [0.0805]
178 -- >>> let f3 = [0.1007]
179 -- >>> let f4 = [-0.0045]
180 -- >>> let f5 = [-0.0332]
181 -- >>> let f6 = [-0.0054]
182 -- >>> let f7 = [-0.0267]
183 -- >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
184 --
185 -- >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
186 -- >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
187 -- >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
188 -- >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
189 -- >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
190 -- >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
191 -- >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
192 -- >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
193 --
194 -- >>> let e1 = [1871/75000] :: [Double]
195 -- >>> let e2 = [899/25000] :: [Double]
196 -- >>> let e3 = [463/15625] :: [Double]
197 -- >>> let e4 = [-3/4000] :: [Double]
198 -- >>> let e5 = [-83/15000] :: [Double]
199 -- >>> let e6 = [-9/10000] :: [Double]
200 -- >>> let e7 = [-89/50000] :: [Double]
201 -- >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
202 -- >>> let actual = solve_positive_definite big_K big_F
203 -- >>> frobenius_norm (actual - expected) < 1e-12
204 -- True
205 --
206 solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a)
207 => Mat m m a
208 -> Col m a
209 -> Col m a
210 solve_positive_definite m b = x
211 where
212 r = cholesky m
213 -- Now, r^T*r*x = b. Let r*x = y, so the system looks like
214 -- r^T * y = b. We can solve this for y.
215 y = forward_substitute (transpose r) b
216 -- Now solve r*x = y to find the value of x.
217 x = backward_substitute r y