]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Bump dependencies.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
8
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
12 -- generality and failed.
13 --
14 module Linear.Matrix
15 where
16
17 import Data.List (intercalate)
18
19 import Data.Vector.Fixed (
20 (!),
21 N1,
22 N2,
23 N3,
24 N4,
25 N5,
26 S,
27 Z,
28 mk1,
29 mk2,
30 mk3,
31 mk4,
32 mk5
33 )
34 import qualified Data.Vector.Fixed as V (
35 and,
36 fromList,
37 head,
38 length,
39 map,
40 maximum,
41 replicate,
42 toList,
43 zipWith
44 )
45 import Data.Vector.Fixed.Boxed (Vec)
46 import Data.Vector.Fixed.Cont (Arity, arity)
47 import Linear.Vector
48 import Normed
49
50 import NumericPrelude hiding ((*), abs)
51 import qualified NumericPrelude as NP ((*))
52 import qualified Algebra.Algebraic as Algebraic
53 import Algebra.Algebraic (root)
54 import qualified Algebra.Additive as Additive
55 import qualified Algebra.Ring as Ring
56 import qualified Algebra.Module as Module
57 import qualified Algebra.RealRing as RealRing
58 import qualified Algebra.ToRational as ToRational
59 import qualified Algebra.Transcendental as Transcendental
60 import qualified Prelude as P
61
-- | An @m@-by-@n@ matrix with entries of type @a@, stored as a boxed
-- vector of @m@ rows, each row a boxed vector of @n@ entries. The
-- constructor packs the 'Arity' evidence for both dimensions, so
-- pattern matching on 'Mat' brings @Arity m@ and @Arity n@ into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of dimension one through five.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
68
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  -- equal; rows themselves are compared entry by entry.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  Mat xss == Mat yss =
    V.and (V.zipWith same_row xss yss)
    where
      -- Entrywise comparison of a single pair of rows.
      same_row xs ys = V.and $ V.zipWith (==) xs ys
87
88
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples, e.g. @((1,2),(3,4))@.
  -- Strictly speaking this is not valid Haskell for a tuple of one
  -- element, but the output is meant for interactive display, where
  -- readability wins.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> show m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = as_tuple (map show_row (V.toList rows))
    where
      -- Join the pieces with commas and wrap them in parentheses.
      as_tuple pieces = "(" ++ intercalate "," pieces ++ ")"

      -- One row, rendered entry by entry.
      show_row r = as_tuple (map show (V.toList r))
111
112
-- | Flatten a matrix into a list of its rows, each row being the list
-- of its entries.
--
-- Examples:
--
-- >>> toList (fromList [[1,2],[3,4]] :: Mat2 Int)
-- [[1,2],[3,4]]
--
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
116
-- | Build a matrix from a list of rows, each row a list of entries.
--
-- NOTE(review): the outer list should have @m@ elements and each inner
-- list @n@; what happens otherwise is up to Data.Vector.Fixed.fromList
-- (presumably an error) — confirm against the fixed-vector version.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
120
121
-- | Unsafe indexing: entry @j@ of row @i@ (both zero-based). No bounds
-- check is performed here; use '!!?' when the indices may be out of
-- range.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
125
-- | Safe indexing. Returns 'Nothing' when either index is negative or
-- not smaller than the size of its dimension. Otherwise it returns
-- @'Just'@ entry @j@ of row @i@ (both zero-based).
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,1)
-- Just 2
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
-- >>> m !!? (-1,0)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) (Mat rows) (i, j)
  | i < 0 || j < 0           = Nothing
  -- Valid indices run from 0 to length - 1, so length itself is out
  -- of range.
  | i >= V.length rows       = Nothing
  | j >= V.length (rows ! i) = Nothing
  | otherwise                = Just $ (rows ! i) ! j
134
135
-- | The number of rows in the matrix. It is read off the type-level
-- dimension @m@ with the 'arity' trick from Data.Vector.Fixed.length,
-- so the matrix argument itself is never examined.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const $ arity (undefined :: m)
139
-- | The number of columns in the matrix, i.e. the length @n@ shared by
-- every row. Like 'nrows', it is read off the type with the 'arity'
-- trick from Data.Vector.Fixed.length, so the argument is never
-- examined.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const $ arity (undefined :: n)
144
145
-- | The @i@th row of a matrix (zero-based). Unsafe: the index is not
-- checked here.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) = (rows !)
149
150
-- | The @j@th column of a matrix (zero-based), built by taking entry
-- @j@ from every row. Unsafe: the index is not checked here.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map (! j) rows
157
158
159
160
-- | Transpose @m@, swapping its rows and columns: entry @(i,j)@ of the
-- result is entry @(j,i)@ of @m@.
--
-- TODO: Don't cheat with fromList ('construct' still goes through it).
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
177
178
-- | Whether the square matrix @m@ equals its own transpose.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == transpose m
194
195
-- | Build a matrix entry by entry from @lambda@. The entry in row @i@,
-- column @j@ (both zero-based) of the result is @lambda i j@. The
-- dimensions come from the requested result type.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (V.fromList (map make_row [0 .. row_count - 1]))
  where
    -- Dimensions taken from the type, with the same 'arity' trick
    -- used by Data.Vector.Fixed.length.
    row_count = arity (undefined :: m)
    col_count = arity (undefined :: n)

    make_row i = V.fromList (map (lambda i) [0 .. col_count - 1])
218
219
-- | Given a positive-definite matrix @m@, computes the
-- upper-triangular matrix @r@ with (transpose r)*r == m and all
-- values on the diagonal of @r@ positive.
--
-- Each entry of @r@ depends on entries in earlier rows of @r@. Those
-- entries are read back out of the result matrix itself, so each one
-- is computed once and then shared. Recursing directly on the entry
-- function instead recomputes the same entries over and over, and the
-- number of calls grows exponentially with the dimension.
--
-- NOTE(review): this relies on 'construct' (via V.fromList on boxed
-- 'Vec's) storing entries lazily, so that an entry may refer to
-- earlier entries of the same matrix. The dependencies only point to
-- earlier rows, or to the diagonal of the same row, so there is no
-- cycle.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = result
  where
    -- The factor being built; its entries are produced by 'r'.
    result :: Mat m n a
    result = construct r

    -- Entry (k,j) of the factor, evaluated on demand and then shared.
    done k j = result !!! (k, j)

    r :: Int -> Int -> a
    r i j
      | i == j = sqrt (m !!! (i,j) - sum [ (done k i)^2 | k <- [0..i-1] ])
      | i < j  =
          ((m !!! (i,j)) - sum [ (done k i) NP.* (done k j) | k <- [0..i-1] ])
            / (done i i)
      | otherwise = 0
241
242
-- | Whether every entry strictly below the main diagonal is zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ m !!! (i,j) == 0 | j <- [0 .. ncols m - 1],
                           i <- [0 .. nrows m - 1],
                           -- Only entries below the diagonal matter.
                           i > j ]
267
268
-- | Whether every entry strictly above the main diagonal is zero. This
-- holds exactly when the transpose is upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Arity m,
                        Arity n)
                    => Mat m n a
                    -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
289
290
-- | Whether the matrix is upper- or lower-triangular. The lower
-- check only runs when the upper check fails.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Arity m,
                  Arity n)
              => Mat m n a
              -> Bool
is_triangular m
  | is_upper_triangular m = True
  | otherwise             = is_lower_triangular m
315
316
-- | The @(i,j)@th minor of a matrix: the matrix obtained by deleting
-- row @i@ and column @j@ (both zero-based).
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r,
          n ~ S t,
          Arity r,
          Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j =
  -- Drop row i, then drop entry j from each remaining row.
  Mat $ V.map (\v -> delete v j) (delete rows i)
339
340
-- | Types @p a@ (here: square matrices) that have a determinant with
-- values in @a@.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the argument.
  determinant :: p a -> a
343
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
346
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | Cofactor expansion along the first row, recursing on the
  -- (one-smaller) minors. A triangular matrix short-circuits to the
  -- product of its diagonal.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -1
  --
  determinant m
    | is_triangular m = product diagonal
    | otherwise       = sum (map cofactor_term column_indices)
    where
      diagonal = [ m !!! (k, k) | k <- [0 .. nrows m - 1] ]

      column_indices = [0 .. ncols m - 1]

      -- Signed first-row entry times the determinant of its minor.
      cofactor_term j =
        (-1)^(toInteger j) NP.* (m !!! (0, j)) NP.* determinant (minor m 0 j)
371
372
373
-- | Matrix multiplication. Entry @(i,j)@ of @m1 * m2@ is the sum over
-- @k@ of @m1(i,k)@ times @m2(k,j)@.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    -- Indices along the shared inner dimension n.
    shared = [0 .. ncols m1 - 1]

    entry i j = sum [ (m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- shared ]
392
393
394
-- | Entrywise addition and subtraction. Only the additive structure of
-- the entries is used, so the context asks for 'Additive.C' rather
-- than 'Ring.C'. Every 'Ring.C' type is also 'Additive.C', so existing
-- uses are unaffected.
instance (Additive.C a, Arity m, Arity n) => Additive.C (Mat m n a) where

  (Mat rows1) + (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (+)) rows1 rows2

  (Mat rows1) - (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (-)) rows1 rows2

  -- The zero matrix: every entry is the additive identity of @a@.
  zero = Mat (V.replicate $ V.replicate zero)
404
405
-- | Square matrices form a ring under entrywise addition and matrix
-- multiplication.
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The @*@ on the right-hand side is this module's matrix
  -- multiplication. NumericPrelude's @*@ is hidden by the imports, so
  -- this is not a recursive call to the method being defined.
  m1 * m2 = m1 * m2

  -- | The identity matrix. Per the numeric-prelude documentation,
  -- 'Ring.C' needs 'one' or 'fromInteger' besides '*', because the
  -- class defaults of those two are written in terms of each other.
  -- Without one of them, 'one', 'fromInteger' and integer literals at
  -- matrix type would not terminate.
  --
  -- >>> one :: Mat2 Int
  -- ((1,0),(0,1))
  --
  one = construct (\i j -> if i == j then one else zero)

  -- | The integer @k@ denotes @k@ times the identity matrix.
  --
  -- >>> fromInteger 3 :: Mat2 Int
  -- ((3,0),(0,3))
  --
  fromInteger k = construct (\i j -> if i == j then fromInteger k else zero)
410
411
-- | Matrices over @a@ form a module over @a@: scaling by @x@ multiplies
-- every entry by @x@.
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- The scalar goes on the left of each entry (@x * e@), matching the
  -- left action @*>@. This keeps @(a * b) *> v == a *> (b *> v)@ true
  -- even when the entries' multiplication is not commutative; @e * x@
  -- would only satisfy it for commutative element types.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
416
417
instance (Algebraic.C a,
          ToRational.C a,
          Arity m)
         => Normed (Mat (S m) N1 a) where
  -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
  -- all matrices as big vectors. Each entry contributes its absolute
  -- value. Otherwise entries of opposite sign would cancel for odd @p@
  -- (for example @p = 1@).
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> let v2 = vec2d (3,-4)
  -- >>> norm_p 1 v2
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ fromRational' (magnitude x)^p' | x <- xs ]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows

      -- The absolute value of x, computed on its exact rational value.
      magnitude x
        | q < zero  = negate q
        | otherwise = q
        where
          q = toRational x

  -- | The infinity norm: the largest absolute value among the entries.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> let v2 = vec3d (1,-5,2)
  -- >>> norm_infty v2
  -- 5
  --
  norm_infty (Mat rows) =
    fromRational' $ V.maximum $ V.map (V.maximum . V.map magnitude) rows
    where
      -- The absolute value of x, computed on its exact rational value.
      magnitude x
        | q < zero  = negate q
        | otherwise = q
        where
          q = toRational x
449
450
451
452
453
454 -- Vector helpers. We want it to be easy to create low-dimension
455 -- column vectors, which are nx1 matrices.
456
-- | Convenient constructor for 1D vectors (1x1 matrices).
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D vectors.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D vectors.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D vectors.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D vectors.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
487
-- | Wrap a single value as a 1x1 matrix. Because '*' in this module is
-- matrix multiplication, a plain value has to be wrapped like this
-- before it can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
492
-- | The dot product of two column vectors (@m@-by-1 matrices). It is
-- computed as the single entry of @(transpose v1) * v2@.
--
-- Only ring operations are needed, so the constraint is 'Ring.C'. A
-- 'RealRing.C' type is in particular a 'Ring.C' type, so real-valued
-- callers are unaffected.
--
-- Examples:
--
-- >>> let v1 = vec3d (1,2,3) :: Mat N3 N1 Int
-- >>> let v2 = vec3d (4,5,6) :: Mat N3 N1 Int
-- >>> v1 `dot` v2
-- 32
--
dot :: (Ring.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
-- The parentheses matter: '!!!' (default infixl 9) binds tighter than
-- '*' (infixl 7).
v1 `dot` v2 = (transpose v1 * v2) !!! (0, 0)
498
499
-- | The angle between @v1@ and @v2@ in Euclidean space. It is computed
-- as the arccosine of their dot product divided by the product of
-- their norms.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 =
  acos (clamp theta)
  where
    norms = (norm v1) NP.* (norm v2)
    theta = (recip norms) NP.* (v1 `dot` v2)

    -- In exact arithmetic theta lies in [-1, 1]. Rounding can push it
    -- just outside that range (e.g. for parallel vectors), where acos
    -- has no real value (NaN for floating types). Clamp it back into
    -- the domain.
    clamp t
      | t > one        = one
      | t < negate one = negate one
      | otherwise      = t
523
524
525
-- | The diagonal part of a square @matrix@. Diagonal entries are
-- copied and every off-diagonal entry is replaced by zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> diagonal_part m
-- ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix = construct keep_diagonal
  where
    keep_diagonal i j
      | i == j    = matrix !!! (i,j)
      | otherwise = 0
543
544
-- | The lower-triangular part of a square @matrix@. Entries on and
-- below the diagonal are copied; entries above it become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part m
-- ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix = construct keep_lower
  where
    keep_lower i j
      | i >= j    = matrix !!! (i,j)
      | otherwise = 0
562
563
-- | The strictly lower-triangular part of a square @matrix@. Entries
-- below the diagonal are copied; diagonal and above-diagonal entries
-- become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part_strict m
-- ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix =
  construct (\i j -> if j < i then matrix !!! (i,j) else 0)
581
582
-- | The upper-triangular part of a square @matrix@. Entries on and
-- above the diagonal are copied; entries below it become zero. This
-- equals @transpose . lt_part . transpose@, but is built directly.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part m
-- ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix = construct keep_upper
  where
    keep_upper i j
      | i <= j    = matrix !!! (i,j)
      | otherwise = 0
597
598
-- | The strictly upper-triangular part of a square @matrix@. Entries
-- above the diagonal are copied; diagonal and below-diagonal entries
-- become zero. This equals @transpose . lt_part_strict . transpose@,
-- but is built directly.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part_strict m
-- ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix =
  construct (\i j -> if i < j then matrix !!! (i,j) else 0)