]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Switch Linear.Vector to the numeric prelude and add the element_sum function to it.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6 {-# LANGUAGE ScopedTypeVariables #-}
7 {-# LANGUAGE TypeFamilies #-}
8 {-# LANGUAGE RebindableSyntax #-}
9
10 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
11 -- assume that the underlying representation is
12 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
13 -- generality and failed.
14 --
15 module Linear.Matrix
16 where
17
18 import Data.List (intercalate)
19
20 import Data.Vector.Fixed (
21 (!),
22 N1,
23 N2,
24 N3,
25 N4,
26 N5,
27 S,
28 Z,
29 generate,
30 mk1,
31 mk2,
32 mk3,
33 mk4,
34 mk5
35 )
36 import qualified Data.Vector.Fixed as V (
37 and,
38 fromList,
39 head,
40 length,
41 map,
42 maximum,
43 replicate,
44 toList,
45 zipWith
46 )
47 import Data.Vector.Fixed.Cont (Arity, arity)
48 import Linear.Vector
49 import Normed
50
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
62
-- | An @m@-by-@n@ matrix with entries of type @a@, stored as a boxed
-- vector of @m@ rows, each row a boxed vector of @n@ entries. The
-- 'Arity' evidence for both dimensions is packed into the constructor,
-- so pattern matching on 'Mat' brings it back into scope.
data Mat m n a =
  (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of sizes one through five.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
69
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  -- equal, rows themselves being compared entry by entry.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  (Mat xss) == (Mat yss) =
    V.and (V.zipWith (\xs ys -> V.and (V.zipWith (==) xs ys)) xss yss)
88
89
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples. This is not a faithful
  -- 'Show' (it cannot be read back as a 'Mat'), but the output is
  -- mostly viewed interactively, where the compact form is handier.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> show m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = tuple (V.toList (V.map render_row rows))
    where
      -- Comma-separate the pieces and wrap them in parentheses.
      tuple items = "(" ++ intercalate "," items ++ ")"
      render_row r = tuple (P.map show (V.toList r))
112
113
-- | Flatten a matrix into a list of rows, each row a list of its
-- entries in column order.
--
-- Examples:
--
-- >>> toList (fromList [[1,2],[3,4]] :: Mat2 Int)
-- [[1,2],[3,4]]
--
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
117
-- | Build a matrix from a list of rows; each inner list becomes one
-- row. No length check is done here: lists whose lengths disagree with
-- @m@ or @n@ are handed straight to 'V.fromList'.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
121
122
-- | Unchecked indexing: @m !!! (i, j)@ is the entry in row @i@, column
-- @j@ (both zero-based). Indices are not validated here; use '!!?' for
-- a total alternative.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
(!!!) (Mat rows) (i, j) = (rows ! i) ! j
126
-- | Safe indexing: @m !!? (i, j)@ is @Just@ the entry in row @i@ and
-- column @j@ (both zero-based), or @Nothing@ when either index falls
-- outside the matrix.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,1)
-- Just 2
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
-- >>> m !!? (-1,0)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, so length itself is out of range.
  | i >= V.length rows = Nothing
  -- Row i is known to exist here, so it is safe to inspect its length.
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j
135
136
-- | The number of rows of a matrix. It is read off the type-level
-- dimension @m@, so the argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
140
-- | The number of columns of a matrix, i.e. the type-level dimension
-- @n@ shared by every row. The argument itself is never inspected.
-- Implementation borrowed from Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
145
146
-- | The @i@th row (zero-based) of a matrix, as a vector of length @n@.
-- Unsafe: the index is not bounds-checked here.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) i = (!) rows i
150
151
-- | The @j@th column (zero-based) of a matrix, gathered by taking entry
-- @j@ from every row. Unsafe: the index is not bounds-checked here.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map (! j) rows
158
159
160
161
-- | Transpose @m@, switching its rows and its columns: entry @(i, j)@
-- of the result is entry @(j, i)@ of @m@. The result is generated
-- directly with 'construct', so no intermediate list (and no
-- 'V.fromList' length check) is involved.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
-- >>> let m = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> transpose m
-- ((1,4),(2,5),(3,6))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
178
179
-- | Is @m@ symmetric, i.e. equal to its own transpose?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == mirrored
  where
    mirrored = transpose m
195
196
-- | Build a matrix from a function @lambda@ of the (zero-based) row
-- and column indices: entry @(i, j)@ of the result is @lambda i j@.
-- The dimensions come from the requested result type.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate build_row)
  where
    -- Row i is the vector (lambda i 0, lambda i 1, ...).
    build_row :: Int -> Vec n a
    build_row = generate . lambda
214
215
-- | The square identity matrix whose size is taken from the requested
-- type: ones on the diagonal, zeros elsewhere.
--
-- Examples:
--
-- >>> identity_matrix :: Mat3 Int
-- ((1,0,0),(0,1,0),(0,0,1))
-- >>> identity_matrix :: Mat3 Double
-- ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct kronecker
  where
    kronecker i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
228
-- | Given a positive-definite matrix @m@, compute the upper-triangular
-- matrix @r@ with (transpose r)*r == m and positive diagonal entries.
--
-- Entries of @r@ are defined recursively from earlier rows and are not
-- memoized, so the same entries are recomputed many times; the cost
-- grows very quickly with the dimension. The type permits non-square
-- input, but the recurrence is only meaningful for square @m@.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct entry
  where
    entry :: Int -> Int -> a
    entry i j
      -- Strictly below the diagonal: zero.
      | i > j = 0
      -- On the diagonal: square root of what the earlier rows leave over.
      | i == j = sqrt (m !!! (i,i) - sum [ (entry k i)^2 | k <- [0..i-1] ])
      -- Above the diagonal: scaled residual.
      | otherwise =
          (m !!! (i,j) - sum [ entry k i NP.* entry k j | k <- [0..i-1] ])
            / entry i i
250
251
-- | Returns True if the given matrix is upper-triangular, and False
-- otherwise. Only the entries strictly below the main diagonal (row
-- index greater than column index) are inspected, and they must all
-- be zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ m !!! (i,j) == 0 | j <- col_indices, i <- row_indices, i > j ]
  where
    row_indices = [0..(nrows m)-1]
    col_indices = [0..(ncols m)-1]
276
277
-- | Returns True if the given matrix is lower-triangular, and False
-- otherwise. A matrix is lower-triangular exactly when its transpose
-- is upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Arity m,
                        Arity n)
                    => Mat m n a
                    -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
298
299
-- | Returns True if the given matrix is either upper- or
-- lower-triangular, and False otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Arity m,
                  Arity n)
              => Mat m n a
              -> Bool
is_triangular m
  | is_upper_triangular m = True
  | otherwise = is_lower_triangular m
324
325
-- | The submatrix of @m@ obtained by deleting row @i@ and column @j@
-- (zero-based). Its determinant is the (i,j) minor of @m@ in the usual
-- sense. The result is one smaller in each dimension.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r,
          n ~ S t,
          Arity r,
          Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j =
  Mat (V.map (\r -> delete r j) (delete rows i))
348
349
-- | Containers @p@ (applied to an entry type @a@) that have a
-- determinant. In this module, square 'Mat's are given instances
-- inductively on their dimension.
class (Eq a, Ring.C a) => Determined p a where
  determinant :: (p a) -> a
352
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
355
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | Triangular matrices take the shortcut of multiplying the diagonal.
  -- Every other matrix uses cofactor (Laplace) expansion along row 0,
  -- recursing into the one-size-smaller instance for each minor.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -1
  --
  determinant m
    | is_triangular m = product diagonal
    | otherwise = sum (map cofactor_term columns)
    where
      diagonal = [ m !!! (k,k) | k <- [0..(nrows m)-1] ]
      columns = [0..(ncols m)-1]
      -- Signed contribution of entry (0, j) and its minor.
      cofactor_term j =
        (-1)^(toInteger j) NP.* (m !!! (0,j)) NP.* determinant (minor m 0 j)
380
381
382
-- | Matrix multiplication. Entry @(i, j)@ of @m1 * m2@ is the sum over
-- @k@ of @m1(i,k) * m2(k,j)@. The shared dimension @n@ in the type
-- guarantees the shapes are compatible.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    inner = [0..(ncols m1)-1]
    entry i j = sum [ (m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- inner ]
401
402
403
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  -- Addition and subtraction act entry by entry.
  (Mat xss) + (Mat yss) =
    Mat (V.zipWith (\xs ys -> V.zipWith (+) xs ys) xss yss)

  (Mat xss) - (Mat yss) =
    Mat (V.zipWith (\xs ys -> V.zipWith (-) xs ys) xss yss)

  -- The m-by-n matrix of zeros.
  zero = Mat (V.replicate (V.replicate (fromInteger 0)))
413
414
-- | Square matrices form a ring under entrywise addition and matrix
-- multiplication, with the identity matrix as multiplicative unit.
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication, the second is this module's
  -- matrix multiplication (NumericPrelude's * is hidden here).
  m1 * m2 = m1 * m2

  -- NumericPrelude's default 'one' and 'fromInteger' are defined in
  -- terms of each other (one of them must be supplied). Leaving both
  -- out made 'one', integer literals at matrix type, and m^0 diverge.
  one = identity_matrix

  -- The integer k is embedded as k times the identity matrix.
  fromInteger k =
    construct (\i j -> if i == j then fromInteger k else fromInteger 0)
419
420
-- | Matrices are a module over their entry type: a scalar multiplies
-- every entry.
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- The scalar is applied on the left of each entry (x * e, not e * x).
  -- This keeps the left-module law x *> (y *> m) == (x*y) *> m valid
  -- even for non-commutative entry rings. For commutative entries the
  -- result is unchanged.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
425
426
instance (Algebraic.C a,
          ToRational.C a,
          Arity m)
         => Normed (Mat (S m) N1 a) where
  -- | Generic p-norms for vectors in R^n that are represented as nx1
  -- matrices: the p-th root of the sum of |x|^p over all entries x.
  -- Taking absolute values matters: without them, odd powers let
  -- negative entries cancel positive ones.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> let v2 = vec2d (3,-4)
  -- >>> norm_p 1 v2
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ fromRational' (magnitude x)^p' | x <- xs ]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- |x|, computed on the exact rational value of x.
      magnitude x = let q = toRational x in P.max q (negate q)

  -- | The infinity norm: the largest absolute value among the entries.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> let v2 = vec3d (1,-5,2)
  -- >>> norm_infty v2
  -- 5
  --
  norm_infty (Mat rows) =
    -- The vector has S m >= 1 entries, so the maximum is well defined.
    fromRational' $ P.maximum [ magnitude x | x <- xs ]
    where
      xs = concat $ V.toList $ V.map V.toList rows
      magnitude x = let q = toRational x in P.max q (negate q)
458
459
-- | The Frobenius norm of a matrix: the square root of the sum of the
-- squares of all its entries. This treats the matrix as one long
-- vector of its entries. Squares are summed per row, and the row
-- totals are then summed.
--
-- Examples:
--
-- >>> let m = fromList [[1, 2, 3],[4,5,6],[7,8,9]] :: Mat3 Double
-- >>> frobenius_norm m == sqrt 285
-- True
--
-- >>> let m = fromList [[1, -1, 1],[-1,1,-1],[1,-1,1]] :: Mat3 Double
-- >>> frobenius_norm m == 3
-- True
--
frobenius_norm :: (Algebraic.C a, Ring.C a) => Mat m n a -> a
frobenius_norm (Mat rows) = sqrt total
  where
    total = element_sum (V.map squares_in rows)
    -- Sum of the squared entries of a single row.
    squares_in r = element_sum (V.map (\x -> x^2) r)
480
481
-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Build a 1-dimensional column vector (a 1x1 matrix) from its single
-- component.
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D column vectors (2x1 matrices).
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Build a 3-dimensional column vector (a 3x1 matrix) from a triple.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Build a 4-dimensional column vector (a 4x1 matrix) from a 4-tuple.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Build a 5-dimensional column vector (a 5x1 matrix) from a 5-tuple.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
515
-- | Wrap a single value as a 1x1 matrix. This module's '*' is matrix
-- multiplication, so a plain scalar has to be lifted this way before it
-- can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
520
-- | Dot product of two column vectors of the same length, taken as the
-- single entry of the 1x1 product (transpose v1) * v2.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
dot v1 v2 = product_1x1 !!! (0, 0)
  where
    product_1x1 = transpose v1 * v2
526
527
-- | The angle between @v1@ and @v2@ in Euclidean space: the arccosine
-- of (v1 `dot` v2) / (norm v1 * norm v2).
--
-- Rounding can push that quotient slightly outside [-1, 1] (e.g. for
-- nearly parallel vectors), where 'acos' has no real value. The
-- quotient is therefore clamped into [-1, 1] first. For floating-point
-- types, a NaN quotient (such as from a zero vector) is left as NaN.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 =
  acos (clamp cos_theta)
  where
    cos_theta = (recip norms) NP.* (v1 `dot` v2)
    norms = (norm v1) NP.* (norm v2)
    -- Guards rather than min/max, so that NaN falls through unchanged.
    clamp c
      | c > 1 = 1
      | c < negate 1 = negate 1
      | otherwise = c
551
552
553
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its diagonal entries; every off-diagonal entry is zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> diagonal_part m
-- ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix = construct keep
  where
    keep i j
      | i == j = matrix !!! (i,j)
      | otherwise = 0
571
572
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps its on- and below-diagonal entries; entries above the diagonal
-- are zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part m
-- ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix = construct keep
  where
    keep i j
      | j <= i = matrix !!! (i,j)
      | otherwise = 0
590
591
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its strictly below-diagonal entries; entries on and above
-- the diagonal are zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part_strict m
-- ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix = construct keep
  where
    keep i j
      | j < i = matrix !!! (i,j)
      | otherwise = 0
609
610
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps its on- and above-diagonal entries; entries below the diagonal
-- are zero. Entrywise this equals transpose . lt_part . transpose.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part m
-- ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix = construct keep
  where
    keep i j
      | i <= j = matrix !!! (i,j)
      | otherwise = 0
625
626
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its strictly above-diagonal entries; entries on and below
-- the diagonal are zero. Entrywise this equals
-- transpose . lt_part_strict . transpose.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part_strict m
-- ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix = construct keep
  where
    keep i j
      | i < j = matrix !!! (i,j)
      | otherwise = 0