1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE TypeFamilies #-}
4 -- | Classical iterative methods to solve the system Ax = b.
6 module Linear.Iteration
9 import Data.List (find)
10 import Data.Maybe (fromJust)
11 import Data.Vector.Fixed (Arity, N1, S)
12 import NumericPrelude hiding ((*))
13 import qualified Algebra.Algebraic as Algebraic
14 import qualified Algebra.Field as Field
15 import qualified Algebra.RealField as RealField
16 import qualified Algebra.ToRational as ToRational
17 import qualified Prelude as P
-- | Perform a single Jacobi step. Writing A = M - N, where M is the
--   diagonal part of A, the next iterate is
--
--   x1 = M^(-1) * (N*x0 + b)
--
--   Rather than forming M^(-1), the system M*x1 = N*x0 + b is solved
--   with 'forward_substitute'.
--
--   Examples:
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> jacobi_iteration m b x0
--   ((0.25),(0.5))
--   >>> let x1 = jacobi_iteration m b x0
--   >>> jacobi_iteration m b x1
--   ((0.0),(0.25))
--
jacobi_iteration :: (Field.C a, Arity m)
                 => Mat m m a   -- ^ The coefficient matrix A
                 -> Mat m N1 a  -- ^ The right-hand side b
                 -> Mat m N1 a  -- ^ The current iterate x0
                 -> Mat m N1 a  -- ^ The next iterate x1
jacobi_iteration matrix b x_current =
  let
    -- M: the part of A returned by 'diagonal' (the Jacobi splitting).
    d = diagonal matrix
    -- N = M - A, so that A = M - N.
    off_diagonal = d - matrix
  in
    -- M is diagonal (in particular lower-triangular), so forward
    -- substitution solves M*x1 = N*x0 + b directly.
    forward_substitute d (off_diagonal*x_current + b)
-- | Compute the infinite list of Jacobi iterates @[x0, x1, x2, ...]@
--   for the system Ax = b. The list starts with the initial guess x0,
--   and each later element is 'jacobi_iteration' applied to its
--   predecessor.
--
jacobi_iterations :: (Field.C a, Arity m)
                  => Mat m m a     -- ^ The coefficient matrix A
                  -> Mat m N1 a    -- ^ The right-hand side b
                  -> Mat m N1 a    -- ^ The initial guess x0
                  -> [Mat m N1 a]
jacobi_iterations matrix b x0 =
  iterate step x0
  where
    -- A and b are fixed; only the iterate changes from step to step.
    step = jacobi_iteration matrix b
-- | Solve the system Ax = b using the Jacobi method. This will run
--   forever if the iterations do not converge.
--
--   Starting from x0, successive Jacobi iterates are computed. The
--   first iterate x_{k+1} with @norm (x_{k+1} - x_k) < epsilon@ is
--   returned. If the iterates never get that close, for example
--   because they diverge or become NaN, the loop never terminates.
--
--   Examples:
--
--   >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
--   >>> let x0 = vec2d (0, 0::Double)
--   >>> let b = vec2d (1, 1::Double)
--   >>> let epsilon = 10**(-6)
--   >>> jacobi_method m b x0 epsilon
--   ((0.0),(0.4999995231628418))
--
jacobi_method :: forall m n a b.
                 (RealField.C a,
                  Algebraic.C a, -- Normed instance
                  ToRational.C a, -- Normed instance
                  Algebraic.C b,
                  RealField.C b,
                  Arity m,
                  Arity n, -- Normed instance
                  m ~ S n)
              => Mat m m a   -- ^ The coefficient matrix A
              -> Mat m N1 a  -- ^ The right-hand side b
              -> Mat m N1 a  -- ^ The initial guess x0
              -> b           -- ^ The tolerance epsilon
              -> Mat m N1 a
jacobi_method matrix b x0 epsilon =
  x_converged
  where
    -- Walk (current, previous) pairs with 'until'. This gives the same
    -- stopping rule as searching the infinite list of iterates, but
    -- without the partial functions 'fromJust' and 'tail'.
    (x_converged, _) = P.until converged step (next x0, x0)

    next :: Mat m N1 a -> Mat m N1 a
    next = jacobi_iteration matrix b

    -- Shift the window by one iterate: (x_{k+1}, x_k) -> (x_{k+2}, x_{k+1}).
    step :: (Mat m N1 a, Mat m N1 a) -> (Mat m N1 a, Mat m N1 a)
    step (cur, _) = (next cur, cur)

    -- Stop once two successive iterates are closer than epsilon.
    converged :: (Mat m N1 a, Mat m N1 a) -> Bool
    converged (cur, prev) = (norm (cur - prev) :: b) < epsilon