]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Iteration.hs
Clean up imports everywhere.
[numerical-analysis.git] / src / Linear / Iteration.hs
1 {-# LANGUAGE NoImplicitPrelude #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE TypeFamilies #-}
4
5 -- | Classical iterative methods to solve the system Ax = b.
6
7 module Linear.Iteration (
8 gauss_seidel_iteration,
9 gauss_seidel_iterations,
10 gauss_seidel_method,
11 jacobi_iteration,
12 jacobi_iterations,
13 jacobi_method,
14 rayleigh_quotient,
15 sor_iteration,
16 sor_iterations,
17 sor_method )
18 where
19
20 import Data.List ( find )
21 import Data.Maybe ( fromJust )
22 import Data.Vector.Fixed ( Arity, N1, S )
23 import NumericPrelude hiding ( (*) )
24 import qualified Algebra.Algebraic as Algebraic ( C )
25 import qualified Algebra.Field as Field ( C )
26 import qualified Algebra.RealField as RealField ( C )
27 import qualified Algebra.ToRational as ToRational ( C )
28
29 import Linear.Matrix (
30 Mat(..),
31 (!!!),
32 (*),
33 diagonal_part,
34 dot,
35 lt_part_strict,
36 transpose )
37 import Linear.System ( forward_substitute )
38 import Normed ( Normed(..) )
39
40
-- | A generalized implementation for Jacobi, Gauss-Seidel, SOR, etc.
-- Each of these methods splits the matrix as A = M - N and differs
-- only in how M is constructed, so we take the function that builds
-- M as an argument. One step computes
--
-- x_next = M^(-1) * (N*x_current + b)
--
classical_iteration :: (Field.C a, Arity m)
                    => (Mat m m a -> Mat m m a) -- ^ Builds M from A
                    -> Mat m m a                -- ^ The matrix A
                    -> Mat m N1 a               -- ^ The vector b
                    -> Mat m N1 a               -- ^ The current iterate
                    -> Mat m N1 a               -- ^ The next iterate
classical_iteration m_function matrix b x_current =
  -- TODO: Should be solve here! M might not be lower-triangular.
  -- Every M built in this module (the diagonal part, or a scaled
  -- diagonal plus the strict lower-triangular part) is lower
  -- triangular, so forward substitution is enough for those callers.
  forward_substitute split_m (split_n*x_current + b)
  where
    split_m = m_function matrix
    -- N is whatever makes A = M - N hold.
    split_n = split_m - matrix
58
59
-- | Perform one iteration of successive over-relaxation (SOR) with
-- relaxation parameter omega. SOR uses the splitting
--
-- M = (1/omega)*D + L
--
-- where D is the diagonal part of A and L is its strictly
-- lower-triangular part.
--
sor_iteration :: forall m a.
                 (Field.C a, Arity m)
              => a          -- ^ Omega
              -> Mat m m a  -- ^ Matrix A
              -> Mat m N1 a -- ^ Vector b
              -> Mat m N1 a -- ^ Vector x_current
              -> Mat m N1 a -- ^ Output vector x_next
sor_iteration omega = classical_iteration sor_m
  where
    -- Builds M = (1/omega)*D + L from A.
    sor_m :: Mat m m a -> Mat m m a
    sor_m mat =
      ((recip omega) *> (diagonal_part mat)) + (lt_part_strict mat)
78
79
-- | Compute an infinite list of SOR iterations, with relaxation
-- parameter omega, starting with the vector x0. The first element
-- of the list is x0 itself.
sor_iterations :: (Field.C a, Arity m)
               => a          -- ^ Omega
               -> Mat m m a  -- ^ Matrix A
               -> Mat m N1 a -- ^ Vector b
               -> Mat m N1 a -- ^ Initial guess x0
               -> [Mat m N1 a]
sor_iterations omega matrix b x0 =
  iterate step x0
  where
    step = sor_iteration omega matrix b
90
91
-- | Perform one iteration of Gauss-Seidel. Gauss-Seidel is SOR with
-- omega = 1, i.e. the splitting M = D + L of the diagonal and strictly
-- lower-triangular parts of A.
gauss_seidel_iteration :: (Field.C a, Arity m)
                       => Mat m m a  -- ^ Matrix A
                       -> Mat m N1 a -- ^ Vector b
                       -> Mat m N1 a -- ^ Vector x_current
                       -> Mat m N1 a -- ^ Output vector x_next
gauss_seidel_iteration matrix b x_current =
  sor_iteration one matrix b x_current
99
100
-- | Compute an infinite list of Gauss-Seidel iterations starting with
-- the vector x0. Since Gauss-Seidel is SOR with omega = 1, this is
-- simply the SOR sequence for that choice of omega.
gauss_seidel_iterations :: (Field.C a, Arity m)
                        => Mat m m a  -- ^ Matrix A
                        -> Mat m N1 a -- ^ Vector b
                        -> Mat m N1 a -- ^ Initial guess x0
                        -> [Mat m N1 a]
gauss_seidel_iterations = sor_iterations one
110
111
-- | Perform one Jacobi iteration:
--
-- x1 = M^(-1) * (N*x0 + b)
--
-- where M is the diagonal part of A and N = M - A.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> jacobi_iteration m b x0
-- ((0.25),(0.5))
-- >>> let x1 = jacobi_iteration m b x0
-- >>> jacobi_iteration m b x1
-- ((0.0),(0.25))
--
jacobi_iteration :: (Field.C a, Arity m)
                 => Mat m m a  -- ^ Matrix A
                 -> Mat m N1 a -- ^ Vector b
                 -> Mat m N1 a -- ^ Vector x0
                 -> Mat m N1 a -- ^ Output vector x1
jacobi_iteration matrix b x_current =
  classical_iteration diagonal_part matrix b x_current
136
137
-- | Compute an infinite list of Jacobi iterations starting with the
-- vector x0. The first element of the list is x0 itself.
jacobi_iterations :: (Field.C a, Arity m)
                  => Mat m m a  -- ^ Matrix A
                  -> Mat m N1 a -- ^ Vector b
                  -> Mat m N1 a -- ^ Initial guess x0
                  -> [Mat m N1 a]
jacobi_iterations matrix b = iterate step
  where
    step = jacobi_iteration matrix b
147
148
-- | Solve the system Ax = b using the Jacobi method. Iteration stops
-- at the first iterate whose distance (in norm) from its predecessor
-- is less than epsilon. This will run forever if the iterations do
-- not converge.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-6)
-- >>> jacobi_method m b x0 epsilon
-- ((0.0),(0.4999995231628418))
--
jacobi_method :: (RealField.C a,
                  Algebraic.C a,  -- Normed instance
                  ToRational.C a, -- Normed instance
                  Algebraic.C b,
                  RealField.C b,
                  Arity m,
                  Arity n,        -- Normed instance
                  m ~ S n)
              => Mat m m a  -- ^ Matrix A
              -> Mat m N1 a -- ^ Vector b
              -> Mat m N1 a -- ^ Initial guess x0
              -> b          -- ^ Tolerance epsilon
              -> Mat m N1 a
jacobi_method matrix b x0 epsilon =
  classical_method jacobi_iterations matrix b x0 epsilon
178
179
-- | Solve the system Ax = b using the Gauss-Seidel method. Iteration
-- stops at the first iterate whose distance (in norm) from its
-- predecessor is less than epsilon. This will run forever if the
-- iterations do not converge.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-12)
-- >>> gauss_seidel_method m b x0 epsilon
-- ((4.547473508864641e-13),(0.49999999999954525))
--
gauss_seidel_method :: (RealField.C a,
                        Algebraic.C a,  -- Normed instance
                        ToRational.C a, -- Normed instance
                        Algebraic.C b,
                        RealField.C b,
                        Arity m,
                        Arity n,        -- Normed instance
                        m ~ S n)
                    => Mat m m a  -- ^ Matrix A
                    -> Mat m N1 a -- ^ Vector b
                    -> Mat m N1 a -- ^ Initial guess x0
                    -> b          -- ^ Tolerance epsilon
                    -> Mat m N1 a
gauss_seidel_method matrix b x0 epsilon =
  classical_method gauss_seidel_iterations matrix b x0 epsilon
209
210
-- | Solve the system Ax = b using the Successive Over-Relaxation
-- (SOR) method with relaxation parameter omega. Iteration stops at
-- the first iterate whose distance (in norm) from its predecessor is
-- less than epsilon. This will run forever if the iterations do not
-- converge.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-12)
-- >>> sor_method 1.5 m b x0 epsilon
-- ((6.567246746413957e-13),(0.4999999999993727))
--
sor_method :: (RealField.C a,
               Algebraic.C a,  -- Normed instance
               ToRational.C a, -- Normed instance
               Algebraic.C b,
               RealField.C b,
               Arity m,
               Arity n,        -- Normed instance
               m ~ S n)
           => a          -- ^ Omega
           -> Mat m m a  -- ^ Matrix A
           -> Mat m N1 a -- ^ Vector b
           -> Mat m N1 a -- ^ Initial guess x0
           -> b          -- ^ Tolerance epsilon
           -> Mat m N1 a
sor_method omega matrix b x0 epsilon =
  classical_method (sor_iterations omega) matrix b x0 epsilon
242
243
-- | General implementation for all classical iteration methods. For
-- its first argument, it takes a function which generates the
-- sequence of iterates when supplied with the remaining arguments
-- (except for the tolerance).
--
-- The result is the first iterate x_(k+1) such that
-- norm (x_(k+1) - x_k) < epsilon. If the sequence is infinite and
-- never meets this criterion, the search runs forever. If the
-- supplied function instead returns a finite list that runs out
-- before the criterion is met, a descriptive 'error' is raised. This
-- replaces the opaque failure of a partial 'fromJust' / 'tail'.
--
classical_method :: forall m n a b.
                    (RealField.C a,
                     Algebraic.C a,  -- Normed instance
                     ToRational.C a, -- Normed instance
                     Algebraic.C b,
                     RealField.C b,
                     Arity m,
                     Arity n,        -- Normed instance
                     m ~ S n)
                 => (Mat m m a -> Mat m N1 a -> Mat m N1 a -> [Mat m N1 a])
                 -> Mat m m a
                 -> Mat m N1 a
                 -> Mat m N1 a
                 -> b
                 -> Mat m N1 a
classical_method iterations_function matrix b x0 epsilon =
  maybe exhausted fst (find close_enough pairs)
  where
    x_n :: [Mat m N1 a]
    x_n = iterations_function matrix b x0

    -- Pairs (x_(k+1), x_k) of successive iterates. 'drop 1' (unlike
    -- 'tail') is total, so an empty iterate list yields no pairs.
    pairs :: [(Mat m N1 a, Mat m N1 a)]
    pairs = zip (drop 1 x_n) x_n

    close_enough :: (Mat m N1 a, Mat m N1 a) -> Bool
    close_enough (cur, prev) = norm (cur - prev) < epsilon

    -- Only reachable when iterations_function returns a finite list.
    exhausted :: Mat m N1 a
    exhausted =
      error ("Linear.Iteration.classical_method: the sequence of "
             ++ "iterates ended before two successive iterates were "
             ++ "within the tolerance")
287
288 error_small_enough :: (Mat m N1 a, Mat m N1 a, b)-> Bool
289 error_small_enough (_,_,err) = err < epsilon
290
291
292
-- | Compute the Rayleigh quotient of @matrix@ and @vector@, that is,
-- (v . (A*v)) / (v^T * v). The denominator is zero for the zero
-- vector, so the quotient is not meaningful there.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, fromList, vec2d )
--
-- >>> let m = fromList [[3,1],[1,2]] :: Mat2 Rational
-- >>> let v = vec2d (1, 1::Rational)
-- >>> rayleigh_quotient m v
-- 7 % 2
--
rayleigh_quotient :: (RealField.C a,
                      Arity m,
                      Arity n,
                      m ~ S n)
                  => (Mat m m a)
                  -> (Mat m N1 a)
                  -> a
rayleigh_quotient matrix vector =
  numerator / denominator
  where
    numerator = vector `dot` (matrix * vector)
    -- v^T * v is a 1x1 matrix; its only entry is the squared norm.
    -- The norm function is avoided so that no algebraic requirement
    -- is imposed on the field.
    denominator = ((transpose vector) * vector) !!! (0,0)
314 -- We don't use the norm function here to avoid the algebraic
315 -- requirement on our field.
316 norm_squared v = ((transpose v) * v) !!! (0,0)