]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Iteration.hs
Implement forward substitute in terms of a fold.
[numerical-analysis.git] / src / Linear / Iteration.hs
1 {-# LANGUAGE NoImplicitPrelude #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE TypeFamilies #-}
4
5 -- | Classical iterative methods to solve the system Ax = b.
6
7 module Linear.Iteration (
8 gauss_seidel_iteration,
9 gauss_seidel_iterations,
10 gauss_seidel_method,
11 jacobi_iteration,
12 jacobi_iterations,
13 jacobi_method,
14 rayleigh_quotient,
15 sor_iteration,
16 sor_iterations,
17 sor_method )
18 where
19
20 import Data.List ( find )
21 import Data.Maybe ( fromJust )
22 import Data.Vector.Fixed ( Arity, N1, S )
23 import NumericPrelude hiding ( (*) )
24 import qualified Algebra.Algebraic as Algebraic ( C )
25 import qualified Algebra.Field as Field ( C )
26 import qualified Algebra.RealField as RealField ( C )
27 import qualified Algebra.ToRational as ToRational ( C )
28
29 import Linear.Matrix (
30 Col,
31 Mat(..),
32 (!!!),
33 (*),
34 diagonal_part,
35 dot,
36 lt_part_strict,
37 transpose )
38 import Linear.System ( forward_substitute )
39 import Normed ( Normed(..) )
40
41
-- | A generalized implementation for Jacobi, Gauss-Seidel, SOR, etc.
--
--   Each classical method splits the matrix as A = M - N and obtains
--   the next iterate by solving
--
--   M*x_next = N*x_current + b.
--
--   The methods differ only in how M is built from A, so that
--   construction is taken as the first argument; N is then recovered
--   as M - A.
--
--   The system is solved with 'forward_substitute', so M is presumably
--   required to be lower-triangular. The splittings used in this
--   module ('diagonal_part' for Jacobi, a scaled 'diagonal_part' plus
--   'lt_part_strict' for SOR/Gauss-Seidel) appear to satisfy that.
--
classical_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
                    => (Mat m m a -> Mat m m a) -- ^ Builds M from A
                    -> Mat m m a                -- ^ Matrix A
                    -> Col m a                  -- ^ Vector b
                    -> Col m a                  -- ^ Vector x_current
                    -> Col m a                  -- ^ Output vector x_next
classical_iteration m_function matrix b x_current =
  -- TODO: Should be solve below! M might not be lower-triangular.
  forward_substitute splitting_m ((splitting_n * x_current) + b)
  where
    splitting_m = m_function matrix
    splitting_n = splitting_m - matrix
59
60
-- | Perform one iteration of successive over-relaxation (SOR).
--
--   This is the classical iteration whose splitting matrix is
--
--   M = (1/omega)*D + L,
--
--   where D is the diagonal part of A and L is its strictly
--   lower-triangular part. With omega = 1 this is exactly the
--   Gauss-Seidel iteration. The reciprocal of omega is taken, so omega
--   must be nonzero.
--
sor_iteration :: forall m n a.
                 (Eq a, Field.C a, m ~ S n, Arity n)
              => a           -- ^ Omega
              -> Mat m m a   -- ^ Matrix A
              -> Mat m N1 a  -- ^ Vector b
              -> Mat m N1 a  -- ^ Vector x_current
              -> Mat m N1 a  -- ^ Output vector x_next
sor_iteration omega = classical_iteration sor_splitting
  where
    -- Build M = (1/omega)*D + L from the matrix A.
    sor_splitting :: Mat m m a -> Mat m m a
    sor_splitting a_matrix =
      ((recip omega) *> (diagonal_part a_matrix)) + (lt_part_strict a_matrix)
79
80
-- | Produce the infinite stream of SOR iterates @[x0, x1, x2, ...]@
--   for the system Ax = b, beginning with the starting vector x0.
sor_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
               => a           -- ^ Omega
               -> Mat m m a   -- ^ Matrix A
               -> Mat m N1 a  -- ^ Vector b
               -> Mat m N1 a  -- ^ Starting vector x0
               -> [Mat m N1 a]
sor_iterations omega matrix b x0 =
  iterate next_iterate x0
  where
    next_iterate = sor_iteration omega matrix b
91
92
-- | Perform one Gauss-Seidel iteration. This is simply an SOR
--   iteration with relaxation parameter omega = 1.
gauss_seidel_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
                       => Mat m m a   -- ^ Matrix A
                       -> Mat m N1 a  -- ^ Vector b
                       -> Mat m N1 a  -- ^ Vector x_current
                       -> Mat m N1 a  -- ^ Output vector x_next
gauss_seidel_iteration matrix b x_current =
  sor_iteration one matrix b x_current
100
101
-- | Produce the infinite stream of Gauss-Seidel iterates
--   @[x0, x1, x2, ...]@ for the system Ax = b, beginning with the
--   starting vector x0.
gauss_seidel_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
                        => Mat m m a   -- ^ Matrix A
                        -> Mat m N1 a  -- ^ Vector b
                        -> Mat m N1 a  -- ^ Starting vector x0
                        -> [Mat m N1 a]
gauss_seidel_iterations matrix b x0 =
  iterate next_iterate x0
  where
    next_iterate = gauss_seidel_iteration matrix b
111
112
-- | Perform one Jacobi iteration:
--
--   x1 = M^(-1) * (N*x0 + b)
--
--   where M is the diagonal part of A and N = M - A.
--
--   Examples:
--
--   >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> jacobi_iteration m b x0
--   ((0.25),(0.5))
--   >>> let x1 = jacobi_iteration m b x0
--   >>> jacobi_iteration m b x1
--   ((0.0),(0.25))
--
jacobi_iteration :: (Eq a, Field.C a, m ~ S n, Arity n)
                 => Mat m m a   -- ^ Matrix A
                 -> Mat m N1 a  -- ^ Vector b
                 -> Mat m N1 a  -- ^ Vector x_current
                 -> Mat m N1 a  -- ^ Output vector x_next
jacobi_iteration matrix b x_current =
  classical_iteration diagonal_part matrix b x_current
137
138
-- | Produce the infinite stream of Jacobi iterates @[x0, x1, x2, ...]@
--   for the system Ax = b, beginning with the starting vector x0.
jacobi_iterations :: (Eq a, Field.C a, m ~ S n, Arity n)
                  => Mat m m a   -- ^ Matrix A
                  -> Mat m N1 a  -- ^ Vector b
                  -> Mat m N1 a  -- ^ Starting vector x0
                  -> [Mat m N1 a]
jacobi_iterations matrix b x0 =
  iterate next_iterate x0
  where
    next_iterate = jacobi_iteration matrix b
148
149
-- | Solve the system Ax = b with the Jacobi method, starting from x0.
--
--   Iterates until two consecutive iterates differ (in norm) by less
--   than the given tolerance, and returns the later one. If the
--   iterates never converge, this never returns.
--
--   Examples:
--
--   >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> let epsilon = 10**(-6)
--   >>> jacobi_method m b x0 epsilon
--   ((0.0),(0.4999995231628418))
--
jacobi_method :: (RealField.C a,
                  Algebraic.C a,  -- Normed instance
                  ToRational.C a, -- Normed instance
                  Algebraic.C b,
                  RealField.C b,
                  Arity m,
                  Arity n,        -- Normed instance
                  m ~ S n)
              => Mat m m a   -- ^ Matrix A
              -> Mat m N1 a  -- ^ Vector b
              -> Mat m N1 a  -- ^ Starting vector x0
              -> b           -- ^ Tolerance epsilon
              -> Mat m N1 a
jacobi_method matrix b x0 epsilon =
  classical_method jacobi_iterations matrix b x0 epsilon
179
180
-- | Solve the system Ax = b with the Gauss-Seidel method, starting
--   from x0.
--
--   Iterates until two consecutive iterates differ (in norm) by less
--   than the given tolerance, and returns the later one. If the
--   iterates never converge, this never returns.
--
--   Examples:
--
--   >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> let epsilon = 10**(-12)
--   >>> gauss_seidel_method m b x0 epsilon
--   ((4.547473508864641e-13),(0.49999999999954525))
--
gauss_seidel_method :: (RealField.C a,
                        Algebraic.C a,  -- Normed instance
                        ToRational.C a, -- Normed instance
                        Algebraic.C b,
                        RealField.C b,
                        Arity m,
                        Arity n,        -- Normed instance
                        m ~ S n)
                    => Mat m m a   -- ^ Matrix A
                    -> Mat m N1 a  -- ^ Vector b
                    -> Mat m N1 a  -- ^ Starting vector x0
                    -> b           -- ^ Tolerance epsilon
                    -> Mat m N1 a
gauss_seidel_method matrix b x0 epsilon =
  classical_method gauss_seidel_iterations matrix b x0 epsilon
210
211
-- | Solve the system Ax = b with the Successive Over-Relaxation (SOR)
--   method, using relaxation parameter omega and starting from x0.
--
--   Iterates until two consecutive iterates differ (in norm) by less
--   than the given tolerance, and returns the later one. If the
--   iterates never converge, this never returns.
--
--   Examples:
--
--   >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> let epsilon = 10**(-12)
--   >>> sor_method 1.5 m b x0 epsilon
--   ((6.567246746413957e-13),(0.4999999999993727))
--
sor_method :: (RealField.C a,
               Algebraic.C a,  -- Normed instance
               ToRational.C a, -- Normed instance
               Algebraic.C b,
               RealField.C b,
               Arity m,
               Arity n,        -- Normed instance
               m ~ S n)
           => a           -- ^ Omega
           -> Mat m m a   -- ^ Matrix A
           -> Mat m N1 a  -- ^ Vector b
           -> Mat m N1 a  -- ^ Starting vector x0
           -> b           -- ^ Tolerance epsilon
           -> Mat m N1 a
sor_method omega matrix b x0 epsilon =
  classical_method (sor_iterations omega) matrix b x0 epsilon
243
244
-- | General implementation for all classical iteration methods.
--
--   The first argument generates the sequence of iterates
--   @[x0, x1, x2, ...]@ when supplied with the matrix A, the vector b,
--   and the starting vector x0 (but not the tolerance). We return the
--   first iterate @x_(k+1)@ for which @norm (x_(k+1) - x_k) < epsilon@.
--
--   The iterate generators in this module are all built with
--   'iterate', so their lists are infinite. If the method does not
--   converge (or @epsilon@ is not positive), the search never ends.
--   If a finite list is ever supplied and exhausted before the
--   tolerance is met, we fail with a descriptive error. Before, the
--   failure came from the anonymous crash inside 'fromJust'.
--
classical_method :: forall m n a b.
                    (RealField.C a,
                     Algebraic.C a,  -- Normed instance
                     ToRational.C a, -- Normed instance
                     Algebraic.C b,
                     RealField.C b,
                     Arity m,
                     Arity n,        -- Normed instance
                     m ~ S n)
                 => (Mat m m a -> Mat m N1 a -> Mat m N1 a -> [Mat m N1 a])
                 -> Mat m m a   -- ^ Matrix A
                 -> Mat m N1 a  -- ^ Vector b
                 -> Mat m N1 a  -- ^ Starting vector x0
                 -> b           -- ^ Tolerance epsilon
                 -> Mat m N1 a
classical_method iterations_function matrix b x0 epsilon =
  case find small_enough steps of
    Just (x_next, _) -> x_next
    -- Only reachable if iterations_function returned a finite list.
    Nothing -> error "Linear.Iteration.classical_method: iterates ended before convergence"
  where
    iterates :: [Mat m N1 a]
    iterates = iterations_function matrix b x0

    -- Pair each iterate x_(k+1) with the size of the step that produced
    -- it, norm (x_(k+1) - x_k). Using 'drop' rather than 'tail' keeps
    -- this total even for an empty list of iterates.
    steps :: [(Mat m N1 a, b)]
    steps = zipWith step (drop 1 iterates) iterates

    step :: Mat m N1 a -> Mat m N1 a -> (Mat m N1 a, b)
    step cur prev = (cur, norm (cur - prev))

    small_enough :: (Mat m N1 a, b) -> Bool
    small_enough (_, err) = err < epsilon
291
292
293
-- | Compute the Rayleigh quotient of @matrix@ and @vector@,
--
--   (v^T * A * v) / (v^T * v).
--
--   No special handling is done for the zero vector, whose denominator
--   is zero.
--
--   Examples:
--
--   >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
--   >>> let m = fromList [[3,1],[1,2]] :: Mat2 Rational
--   >>> let v = vec2d (1, 1::Rational)
--   >>> rayleigh_quotient m v
--   7 % 2
--
rayleigh_quotient :: (RealField.C a,
                      Arity m,
                      Arity n,
                      m ~ S n)
                  => (Mat m m a)
                  -> (Mat m N1 a)
                  -> a
rayleigh_quotient matrix vector =
  numerator / denominator
  where
    numerator = vector `dot` (matrix * vector)

    -- v^T * v is a 1x1 matrix holding the squared 2-norm of v. It is
    -- computed this way, instead of via 'norm', to avoid requiring an
    -- Algebraic instance on the field.
    denominator = ((transpose vector) * vector) !!! (0,0)