]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Get determinants working.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
8
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
12 -- generality and failed.
13 --
14 module Linear.Matrix
15 where
16
17 import Data.List (intercalate)
18
19 import Data.Vector.Fixed (
20 N1,
21 N2,
22 N3,
23 N4,
24 N5,
25 S,
26 Z,
27 mk1,
28 mk2,
29 mk3,
30 mk4,
31 mk5
32 )
33 import qualified Data.Vector.Fixed as V (
34 and,
35 fromList,
36 head,
37 length,
38 map,
39 maximum,
40 replicate,
41 toList,
42 zipWith
43 )
44 import Data.Vector.Fixed.Boxed (Vec)
45 import Data.Vector.Fixed.Internal (Arity, arity)
46 import Linear.Vector
47 import Normed
48
49 import NumericPrelude hiding ((*), abs)
50 import qualified NumericPrelude as NP ((*))
51 import qualified Algebra.Algebraic as Algebraic
52 import Algebra.Algebraic (root)
53 import qualified Algebra.Additive as Additive
54 import qualified Algebra.Ring as Ring
55 import qualified Algebra.Module as Module
56 import qualified Algebra.RealRing as RealRing
57 import qualified Algebra.ToRational as ToRational
58 import qualified Algebra.Transcendental as Transcendental
59 import qualified Prelude as P
60
-- | An @m@-by-@n@ matrix with entries of type @a@: a boxed vector of
-- @m@ rows, each of which is a boxed vector of @n@ entries. The
-- constructor carries the 'Arity' constraints for both dimensions,
-- so pattern matching on 'Mat' brings them into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- Aliases for square matrices of sizes 1 through 5.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
67
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  -- equal, where two rows are equal when all of their corresponding
  -- entries are.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  Mat xss == Mat yss = V.and (V.zipWith same_row xss yss)
    where
      -- Entry-by-entry comparison of a single pair of rows.
      same_row xs ys = V.and (V.zipWith (==) xs ys)
86
87
instance (Show a) => Show (Mat m n a) where
  -- | Display matrices and vectors as ordinary (nested) tuples. This is
  -- poor practice, but these results are primarily displayed
  -- interactively and convenience trumps correctness (said the guy
  -- who insists his vector lengths be statically checked at
  -- compile-time).
  --
  -- Each row is rendered as a parenthesized, comma-separated list of
  -- its entries (using the entries' own 'show'), and the rows are
  -- joined the same way.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> m
  -- ((1,2),(3,4))
  -- >>> show m
  -- "((1,2),(3,4))"
  --
  show (Mat rows) = tuple (V.toList (V.map show_row rows))
    where
      -- Wrap already-rendered pieces in parentheses, comma separated.
      tuple parts = "(" ++ intercalate "," parts ++ ")"
      show_row r = tuple (P.map show (V.toList r))
110
111
-- | Flatten a matrix into a list of its rows, where each row is the
-- list of its entries in order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = V.toList (V.map V.toList rows)
115
-- | Build a matrix from a list of rows, each row a list of entries.
-- Every list is converted with 'V.fromList', so the number of rows
-- and the row lengths are expected to match the type-level
-- dimensions @m@ and @n@.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
119
120
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@ and
-- column @j@ of @m@. No bounds checking is performed here.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
124
-- | Safe indexing: @m !!? (i, j)@ is @Just@ the entry in row @i@ and
-- column @j@ (both zero-based), or @Nothing@ when either index is
-- negative or past the end of its dimension.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,1)
-- Just 2
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length - 1, so an index equal to the
  -- length is already out of range.
  | i >= V.length rows = Nothing
  | j >= V.length (row m i) = Nothing
  | otherwise = Just (row m i ! j)
133
134
-- | How many rows @m@ has. The count is read off the type-level
-- dimension; the matrix argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
138
-- | How many columns @m@ has (i.e. the length of its first row).
-- Like 'nrows', the count comes from the type-level dimension, a
-- trick borrowed from Data.Vector.Fixed.length; the matrix argument
-- itself is never inspected.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
143
144
-- | The row of @m@ at index @i@. Unsafe: no bounds checking.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) i = (!) rows i
148
149
-- | The column of @m@ at index @j@, assembled by taking entry @j@
-- from every row. Unsafe: no bounds checking.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map (! j) rows
156
157
158
159
-- | Transpose @m@, switching its rows and its columns: row @i@ of the
-- result is column @i@ of @m@. This is a somewhat dirty
-- implementation that goes through an intermediate list; it would be
-- a little cleaner to use imap, but that doesn't seem to work.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = Mat (V.fromList (map (column m) [0 .. ncols m - 1]))
176
177
-- | Is the square matrix @m@ symmetric, i.e. equal to its own
-- transpose?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
193
194
-- | Build a matrix entry by entry from a function @lambda@: the entry
-- in row @i@ and column @j@ of the result is @lambda i j@, with @i@
-- and @j@ counting up from zero.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (V.fromList (map build_row [0 .. last_row]))
  where
    -- The dimensions come from the type, via the same arity trick
    -- that Data.Vector.Fixed.length uses.
    last_row = arity (undefined :: m) - 1
    last_col = arity (undefined :: n) - 1
    build_row i = V.fromList (map (lambda i) [0 .. last_col])
217
218
-- | Given a positive-definite matrix @m@, computes the
-- upper-triangular matrix @r@ with (transpose r)*r == m and all
-- values on the diagonal of @r@ positive.
--
-- Entries follow the Cholesky recurrences. A diagonal entry is the
-- square root of the matching entry of @m@ minus the squares of the
-- entries above it. An entry right of the diagonal is the matching
-- entry of @m@ minus an inner product of earlier entries, divided by
-- the diagonal entry of its row. Entries below the diagonal are zero.
-- Entries are recomputed on demand; nothing is memoized.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct r
  where
    r :: Int -> Int -> a
    r i j =
      case compare i j of
        EQ -> sqrt (m !!! (i,j) - sum [ (r k i)^2 | k <- above ])
        LT -> (m !!! (i,j) - sum [ (r k i) NP.* (r k j) | k <- above ]) / r i i
        GT -> 0
      where
        -- Row indices strictly above row i.
        above = [0 .. i - 1]
240
241
-- | Returns True if every entry strictly below the main diagonal of
-- the given matrix is zero (i.e. it is upper-triangular), and False
-- otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ i <= j || m !!! (i,j) == 0
      | j <- [0 .. ncols m - 1]
      , i <- [0 .. nrows m - 1] ]
266
267
-- | Returns True if every entry strictly above the main diagonal of
-- the given matrix is zero (i.e. it is lower-triangular), and False
-- otherwise. This holds exactly when the transpose is
-- upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Arity m,
                        Arity n)
                    => Mat m n a
                    -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
288
289
-- | Returns True if the given matrix is upper- or lower-triangular,
-- and False otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Arity m,
                  Arity n)
              => Mat m n a
              -> Bool
is_triangular m
  | is_upper_triangular m = True
  | otherwise = is_lower_triangular m
314
315
-- | The (i,j)th minor of @m@: the matrix left over after removing row
-- @i@ and column @j@ from @m@. It is one smaller in each dimension.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r,
          n ~ S t,
          Arity r,
          Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j = Mat (V.map drop_column (delete rows i))
  where
    -- Remove entry j from one of the surviving rows.
    drop_column r = delete r j
338
339
-- | Containers @p a@ that have a determinant with entries of type
-- @a@. In this module the instances are square 'Mat's: the 1x1 case
-- is handled directly, and larger sizes recurse on a smaller size.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given value.
  determinant :: (p a) -> a
342
-- The base case of the recursion: the determinant of a 1x1 matrix is
-- its one and only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
345
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | The recursive (cofactor) expansion along the first row, with a
  -- special case for triangular matrices, whose determinant is just
  -- the product of the diagonal entries.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -2
  --
  -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,10]] :: Mat3 Int
  -- >>> determinant m
  -- -3
  --
  determinant m
    | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
    | otherwise = determinant_recursive
    where
      m' i j = m !!! (i,j)

      -- The minors are one size smaller, so this uses the instance for
      -- Mat (S n) (S n) from the context.
      det_minor i j = determinant (minor m i j)

      -- Sum over the first row of sign * entry * minor, with the sign
      -- alternating (+, -, +, ...) across the columns.
      determinant_recursive =
        sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
            | j <- [0..(ncols m)-1] ]
370
371
372
-- | Matrix multiplication: entry (i,j) of the product is the sum, over
-- k, of entry (i,k) of the left factor times entry (k,j) of the right.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    inner = [0 .. ncols m1 - 1]
    entry i j = sum (map (\k -> (m1 !!! (i,k)) NP.* (m2 !!! (k,j))) inner)
391
392
393
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  -- Addition and subtraction act entry by entry: corresponding rows
  -- are zipped together, then corresponding entries within each row.
  Mat xss + Mat yss = Mat (V.zipWith (\xs ys -> V.zipWith (+) xs ys) xss yss)

  Mat xss - Mat yss = Mat (V.zipWith (\xs ys -> V.zipWith (-) xs ys) xss yss)

  -- Every entry of the zero matrix is (fromInteger 0).
  zero = Mat (V.replicate zero_row)
    where
      zero_row = V.replicate (fromInteger 0)
403
404
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication, the second is matrix
  -- multiplication: NumericPrelude's (*) is hidden by this module's
  -- import list, so the unqualified (*) on the right-hand side is the
  -- top-level matrix product defined in this module.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix. NumericPrelude's
  -- Ring.C needs 'one' or 'fromInteger' to be defined, because their
  -- default definitions are written in terms of each other. Without
  -- this, 'one' and 'fromInteger' at matrix type would never terminate.
  one = construct identity_entry
    where
      identity_entry i j
        | i == j = Ring.one
        | otherwise = Additive.zero
409
410
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- Scaling by a scalar of the same type as the entries: every entry
  -- e is replaced by e * x.
  x *> Mat rows = Mat (V.map scale_row rows)
    where
      scale_row = V.map (\e -> e NP.* x)
415
416
instance (Algebraic.C a,
          ToRational.C a,
          Arity m,
          Arity n)
         => Normed (Mat (S m) (S n) a) where
  -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
  -- all matrices as big vectors: the result is the @p@th root of the
  -- sum of @|x|^p@ over every entry @x@. Absolute values are taken
  -- first so that odd @p@ (such as @p = 1@) is correct when some
  -- entries are negative.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> let v2 = vec2d (3,-4)
  -- >>> norm_p 1 v2
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ (fromRational' $ abs_rational $ toRational x)^p' | x <- xs ]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- Absolute value of a rational, written out because 'abs' is
      -- hidden from the NumericPrelude import in this module.
      abs_rational q
        | q < Additive.zero = Additive.negate q
        | otherwise = q

  -- | The infinity norm: the largest absolute value among the entries.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> let v2 = vec3d (1,-5,2)
  -- >>> norm_infty v2
  -- 5
  --
  norm_infty (Mat rows) =
    fromRational' $ P.maximum [ abs_rational (toRational x) | x <- xs ]
    where
      -- Non-empty, since both dimensions are of the form S _.
      xs = concat $ V.toList $ V.map V.toList rows
      abs_rational q
        | q < Additive.zero = Additive.negate q
        | otherwise = q
449
450
451
452
453
-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Convenient constructor for 1D vectors (1x1 matrices).
vec1d :: (a) -> Mat N1 N1 a
vec1d (x) = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D (column) vectors.
--
-- Examples:
--
-- >>> vec2d (1,2) :: Mat N2 N1 Int
-- ((1),(2))
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D (column) vectors.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D (column) vectors.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D (column) vectors.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
487
-- | Wrap a single value in a 1x1 matrix. Since we commandeered
-- multiplication for matrices, scalars have to become 1x1 matrices
-- before they can be multiplied with other matrices.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
492
-- | The dot product of two column vectors: the product
-- @transpose v1 * v2@ is a 1x1 matrix whose single entry is the
-- result.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
dot v1 v2 = inner !!! (0, 0)
  where
    inner = transpose v1 * v2
498
499
-- | The angle between @v1@ and @v2@ in Euclidean space: the arc cosine
-- of their dot product scaled by the reciprocal of the product of
-- their norms.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 = acos (recip (norm v1 NP.* norm v2) NP.* (v1 `dot` v2))