]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/System.hs
Use a better implementation of backward_substitute.
[numerical-analysis.git] / src / Linear / System.hs
1 {-# LANGUAGE RebindableSyntax #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE TypeFamilies #-}
4
5 module Linear.System (
6 backward_substitute,
7 forward_substitute,
8 solve_positive_definite )
9 where
10
11 import qualified Algebra.Algebraic as Algebraic ( C )
12 import Data.Vector.Fixed ( Arity, S )
13 import NumericPrelude hiding ( (*), abs )
14 import qualified Algebra.Field as Field ( C )
15
16 import Linear.Matrix (
17 Col,
18 Mat(..),
19 cholesky,
20 diagonal,
21 dot,
22 ifoldl2,
23 ifoldr2,
24 is_lower_triangular,
25 is_upper_triangular,
26 row,
27 set_idx,
28 zip2,
29 transpose )
30
31
-- | Solve the system @matrix * x = b@ by forward substitution, where
-- @matrix@ is lower-triangular. The result is the vector @x@.
--
-- Rows are solved from the top down. Calls 'error' if @matrix@ is not
-- lower-triangular. The diagonal entries are used directly as divisors
-- (there is no pivoting). A singular matrix therefore causes a division
-- by zero; a triangular matrix is singular exactly when its diagonal
-- contains a zero. What that division produces depends on the element
-- type's 'Field.C' instance.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat2, Mat3, frobenius_norm, fromList )
-- >>> import Linear.Matrix ( vec2d, vec3d )
-- >>> import Naturals ( N7 )
--
-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
-- >>> let b = vec3d (1, 2, 3::Double)
-- >>> forward_substitute identity b
-- ((1.0),(2.0),(3.0))
-- >>> (forward_substitute identity b) == b
-- True
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Double
-- >>> let b = vec2d (1, 1::Double)
-- >>> forward_substitute m b
-- ((1.0),(0.0))
--
-- >>> let m = fromList [[4,0],[0,2]] :: Mat2 Double
-- >>> let b = vec2d (2, 1.5 :: Double)
-- >>> forward_substitute m b
-- ((0.5),(0.75))
--
-- >>> let f1 = [0.0418]
-- >>> let f2 = [0.0805]
-- >>> let f3 = [0.1007]
-- >>> let f4 = [-0.0045]
-- >>> let f5 = [-0.0332]
-- >>> let f6 = [-0.0054]
-- >>> let f7 = [-0.0267]
-- >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
-- >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
-- >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
-- >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
-- >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
-- >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
-- >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
-- >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
-- >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
-- >>> let r = cholesky big_K
-- >>> let rt = transpose r
-- >>> let e1 = [0.0170647785413895] :: [Double]
-- >>> let e2 = [0.0338] :: [Double]
-- >>> let e3 = [0.07408] :: [Double]
-- >>> let e4 = [-0.00183711730708738] :: [Double]
-- >>> let e5 = [-0.0135538432434003] :: [Double]
-- >>> let e6 = [-0.00220454076850486] :: [Double]
-- >>> let e7 = [-0.00689391035624920] :: [Double]
-- >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
-- >>> let actual = forward_substitute rt big_F
-- >>> frobenius_norm (actual - expected) < 1e-10
-- True
--
forward_substitute :: forall a m. (Eq a, Field.C a, Arity m)
                   => Mat (S m) (S m) a
                   -> Col (S m) a
                   -> Col (S m) a
forward_substitute matrix b
  | is_lower_triangular matrix = ifoldl2 solve_row zero diag_and_rhs
  | otherwise =
      error "forward substitution on non-lower-triangular matrix"
  where
    -- Entry i is the pair (m_ii, b_i) consumed while solving row i.
    diag_and_rhs :: Col (S m) (a,a)
    diag_and_rhs = zip2 (diagonal matrix) b

    -- Store x_i = (b_i - sum_{j<i} m_ij * x_j) / m_ii in the accumulator.
    -- The accumulator starts as 'zero' and rows are filled top-down, so
    -- entries i and later are still zero at this point. That lets us dot
    -- with the whole row and still pick up only the already-solved
    -- components. The second index (the column, always 0 for a Col) is
    -- not needed.
    solve_row :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a
    solve_row i _ acc (pivot, rhs) =
      let known = acc `dot` transpose (row matrix i)
      in set_idx acc (i, 0) ((rhs - known) / pivot)
105
106
107
-- | Solve the system @matrix * x = b@ by backward substitution, where
-- @matrix@ is upper-triangular. The result is the vector @x@.
--
-- This mirrors 'forward_substitute', but uses a right fold so that rows
-- are solved from the bottom up. Calls 'error' if @matrix@ is not
-- upper-triangular. The diagonal entries are used directly as divisors
-- (there is no pivoting). A singular matrix therefore causes a division
-- by zero; a triangular matrix is singular exactly when its diagonal
-- contains a zero. What that division produces depends on the element
-- type's 'Field.C' instance.
--
-- Examples:
--
-- >>> import Linear.Matrix ( Mat3, fromList, vec3d )
--
-- >>> let identity = fromList [[1,0,0],[0,1,0],[0,0,1]] :: Mat3 Double
-- >>> let b = vec3d (1, 2, 3::Double)
-- >>> backward_substitute identity b
-- ((1.0),(2.0),(3.0))
-- >>> (backward_substitute identity b) == b
-- True
--
-- >>> let m1 = fromList [[1,1,1], [0,1,1], [0,0,1]] :: Mat3 Double
-- >>> let b = vec3d (1,1,1::Double)
-- >>> backward_substitute m1 b
-- ((0.0),(0.0),(1.0))
--
backward_substitute :: forall m a. (Eq a, Field.C a, Arity m)
                    => Mat (S m) (S m) a
                    -> Col (S m) a
                    -> Col (S m) a
backward_substitute matrix b
  | is_upper_triangular matrix = ifoldr2 solve_row zero diag_and_rhs
  | otherwise =
      error "backward substitution on non-upper-triangular matrix"
  where
    -- Entry i is the pair (m_ii, b_i) consumed while solving row i.
    diag_and_rhs :: Col (S m) (a,a)
    diag_and_rhs = zip2 (diagonal matrix) b

    -- Store x_i = (b_i - sum_{j>i} m_ij * x_j) / m_ii in the accumulator.
    -- Rows are filled bottom-up starting from 'zero', so entries i and
    -- earlier are still zero here. That lets us dot with the whole row and
    -- still pick up only the already-solved components below row i. The
    -- second index (the column, always 0 for a Col) is not needed.
    solve_row :: Int -> Int -> Col (S m) a -> (a, a) -> Col (S m) a
    solve_row i _ acc (pivot, rhs) =
      let known = acc `dot` transpose (row matrix i)
      in set_idx acc (i, 0) ((rhs - known) / pivot)
146
147
-- | Solve the linear system @m * x = b@, where @m@ is positive definite,
-- using its Cholesky factorization.
--
-- With @r = cholesky m@, the system becomes @r^T * r * x = b@. Here @r@
-- is upper-triangular and @r^T@ is lower-triangular. Set @y = r * x@.
-- Forward substitution then solves @r^T * y = b@ for @y@, and backward
-- substitution solves @r * x = y@ for @x@.
--
-- Examples:
--
-- >>> import Linear.Matrix ( frobenius_norm, fromList )
-- >>> import Naturals ( N7 )
--
-- >>> let f1 = [0.0418]
-- >>> let f2 = [0.0805]
-- >>> let f3 = [0.1007]
-- >>> let f4 = [-0.0045]
-- >>> let f5 = [-0.0332]
-- >>> let f6 = [-0.0054]
-- >>> let f7 = [-0.0267]
-- >>> let big_F = fromList [f1,f2,f3,f4,f5,f6,f7] :: Col N7 Double
--
-- >>> let k1 = [6, -3, 0, 0, 0, 0, 0] :: [Double]
-- >>> let k2 = [-3, 10.5, -7.5, 0, 0, 0, 0] :: [Double]
-- >>> let k3 = [0, -7.5, 12.5, 0, 0, 0, 0] :: [Double]
-- >>> let k4 = [0, 0, 0, 6, 0, 0, 0] :: [Double]
-- >>> let k5 = [0, 0, 0, 0, 6, 0, 0] :: [Double]
-- >>> let k6 = [0, 0, 0, 0, 0, 6, 0] :: [Double]
-- >>> let k7 = [0, 0, 0, 0, 0, 0, 15] :: [Double]
-- >>> let big_K = fromList [k1,k2,k3,k4,k5,k6,k7] :: Mat N7 N7 Double
--
-- >>> let e1 = [1871/75000] :: [Double]
-- >>> let e2 = [899/25000] :: [Double]
-- >>> let e3 = [463/15625] :: [Double]
-- >>> let e4 = [-3/4000] :: [Double]
-- >>> let e5 = [-83/15000] :: [Double]
-- >>> let e6 = [-9/10000] :: [Double]
-- >>> let e7 = [-89/50000] :: [Double]
-- >>> let expected = fromList [e1,e2,e3,e4,e5,e6,e7] :: Col N7 Double
-- >>> let actual = solve_positive_definite big_K big_F
-- >>> frobenius_norm (actual - expected) < 1e-12
-- True
--
solve_positive_definite :: (Arity m, Algebraic.C a, Eq a, Field.C a)
                        => Mat (S m) (S m) a
                        -> Col (S m) a
                        -> Col (S m) a
solve_positive_definite m b =
  -- The inner call solves r^T * y = b; the outer call solves r * x = y.
  backward_substitute r (forward_substitute r_transposed b)
  where
    r = cholesky m
    r_transposed = transpose r