]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
New function: Linear.Matrix.identity_matrix.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
8
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
12 -- generality and failed.
13 --
14 module Linear.Matrix
15 where
16
17 import Data.List (intercalate)
18
19 import Data.Vector.Fixed (
20 (!),
21 N1,
22 N2,
23 N3,
24 N4,
25 N5,
26 S,
27 Z,
28 generate,
29 mk1,
30 mk2,
31 mk3,
32 mk4,
33 mk5
34 )
35 import qualified Data.Vector.Fixed as V (
36 and,
37 fromList,
38 head,
39 length,
40 map,
41 maximum,
42 replicate,
43 toList,
44 zipWith
45 )
46 import Data.Vector.Fixed.Boxed (Vec)
47 import Data.Vector.Fixed.Cont (Arity, arity)
48 import Linear.Vector
49 import Normed
50
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
62
-- | An @m@-by-@n@ matrix with entries of type @a@, stored as a boxed
--   vector of @m@ rows, each of which is a boxed vector of @n@
--   entries. The constructor packages the 'Arity' evidence for both
--   dimensions, so pattern matching on 'Mat' brings @Arity m@ and
--   @Arity n@ into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of dimension one through five.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
69
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when every pair of corresponding rows
  --   agrees entry-by-entry.
  --
  --   Examples:
  --
  --   >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  --   >>> m1 == m2
  --   True
  --   >>> m1 == m3
  --   False
  --
  (Mat xss) == (Mat yss) =
    V.and (V.zipWith (\xs ys -> V.and (V.zipWith (==) xs ys)) xss yss)
88
89
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix (or column vector) as a tuple of row tuples.
  --   The output is not Haskell source that would rebuild the 'Mat'.
  --   It is meant for convenient interactive display.
  --
  --   Examples:
  --
  --   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> m
  --   ((1,2),(3,4))
  --
  show (Mat rows) = tuple (map show_row (V.toList rows))
    where
      -- Join the pieces with commas and wrap them in parentheses.
      tuple pieces = "(" ++ intercalate "," pieces ++ ")"

      -- A single row becomes a tuple of its shown entries.
      show_row r = tuple (map show (V.toList r))
112
113
-- | Flatten a matrix into a list of its rows, each row being a list
--   of entries.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
117
-- | Build a matrix from a list of rows. Each list is converted with
--   'V.fromList', so the lengths must match the @m@ and @n@ in the
--   result type.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList vs = Mat (V.fromList [ V.fromList v | v <- vs ])
121
122
-- | Unsafe indexing: the @(i,j)@th entry of the matrix, where @i@ is
--   the row and @j@ the column. No bounds checking is performed here.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
126
-- | Safe indexing. Returns @Nothing@ when either index is negative or
--   past the end of its dimension, and otherwise @Just@ the entry in
--   row @i@, column @j@.
--
--   Examples:
--
--   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
--   >>> m !!? (1,0)
--   Just 3
--   >>> m !!? (2,0)
--   Nothing
--   >>> m !!? (0,2)
--   Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, so length itself is out of
  -- range.
  | i >= V.length rows = Nothing
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j
135
136
-- | The number of rows in the matrix. This is read off the type (the
--   @m@ in @Mat m n a@), so the matrix argument itself is never
--   inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const $ arity (undefined :: m)
140
-- | The number of columns in the matrix. Like 'nrows', this is read
--   off the type (the @n@ in @Mat m n a@), and the matrix argument is
--   never inspected. The approach mirrors Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const $ arity (undefined :: n)
145
146
-- | The row of @m@ at (zero-based) index @i@. Unsafe: no bounds
--   checking is done here.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rs) idx = (!) rs idx
150
151
-- | The column of @m@ at (zero-based) index @j@, formed by taking the
--   @j@th entry of every row. Unsafe: no bounds checking is done here.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rs) j = V.map (! j) rs
158
159
160
161
-- | Transpose @m@; switch its columns and its rows. The @(i,j)@th
--   entry of the result is the @(j,i)@th entry of @m@. The result is
--   built directly with 'construct' rather than going through
--   'V.fromList', so there is no runtime length check that could fail.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
--   >>> let m = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
--   >>> transpose m
--   ((1,4),(2,5),(3,6))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
178
179
-- | Is @m@ symmetric? That is, is @m@ equal to its own transpose?
--
--   Examples:
--
--   >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
--   >>> symmetric m1
--   True
--
--   >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
--   >>> symmetric m2
--   False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m =
  let mt = transpose m
  in m == mt
195
196
-- | Build a matrix from a function of its indices: the entry in row
--   @i@, column @j@ of the result is @lambda i j@.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate build_row)
  where
    -- Row i is generated from the partially-applied lambda i.
    build_row :: Int -> Vec n a
    build_row = generate . lambda
214
215
-- | The identity matrix of whatever square size the type demands:
--   ones on the diagonal and zeros everywhere else.
--
--   Examples:
--
--   >>> identity_matrix :: Mat3 Int
--   ((1,0,0),(0,1,0),(0,0,1))
--   >>> identity_matrix :: Mat3 Double
--   ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct kronecker
  where
    -- Kronecker delta: one on the diagonal, zero elsewhere.
    kronecker i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
228
-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   Each entry of @r@ is defined in terms of entries in earlier rows
--   (and the diagonal entry of its own row). Rather than recomputing
--   those entries from scratch on every reference, which takes
--   exponential time, they are read back out of the result matrix
--   itself. This relies on the boxed vectors holding their entries
--   lazily, so each entry is computed at most once.
--
--   Examples:
--
--   >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
--   >>> cholesky m1
--   ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
--   >>> (transpose (cholesky m1)) * (cholesky m1)
--   ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = result
  where
    result :: Mat m n a
    result = construct entry

    -- Look up an entry of the (lazily-built) result; memoizes the
    -- recursion below.
    r :: Int -> Int -> a
    r i j = result !!! (i,j)

    entry :: Int -> Int -> a
    entry i j
      | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
      | i < j =
          (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
      | otherwise = 0
250
251
-- | Is the given matrix upper-triangular? That is, is every entry
--   strictly below the diagonal (row index greater than column index)
--   equal to zero?
--
--   Examples:
--
--   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
--   >>> is_upper_triangular m
--   False
--
--   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
--   >>> is_upper_triangular m
--   True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ entry_ok i j | j <- [0..(ncols m)-1], i <- [0..(nrows m)-1] ]
  where
    -- Entries on or above the diagonal are unconstrained.
    entry_ok :: Int -> Int -> Bool
    entry_ok i j = i <= j || m !!! (i,j) == 0
276
277
-- | Is the given matrix lower-triangular? A matrix is lower-triangular
--   exactly when its transpose is upper-triangular.
--
--   Examples:
--
--   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
--   >>> is_lower_triangular m
--   True
--
--   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
--   >>> is_lower_triangular m
--   False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Arity m,
                        Arity n)
                    => Mat m n a
                    -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
298
299
-- | Is the given matrix triangular, either upper or lower? The
--   lower-triangular test is only performed if the upper one fails.
--
--   Examples:
--
--   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
--   >>> is_triangular m
--   True
--
--   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
--   >>> is_triangular m
--   True
--
--   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
--   >>> is_triangular m
--   False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Arity m,
                  Arity n)
              => Mat m n a
              -> Bool
is_triangular m
  | is_upper_triangular m = True
  | otherwise = is_lower_triangular m
324
325
-- | The @(i,j)@th minor of @m@: the matrix obtained by deleting row
--   @i@ and column @j@ from @m@.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> minor m 0 0 :: Mat2 Int
--   ((5,6),(8,9))
--   >>> minor m 1 1 :: Mat2 Int
--   ((1,3),(7,9))
--
minor :: (m ~ S r,
          n ~ S t,
          Arity r,
          Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j =
  -- Drop row i first, then drop entry j from each surviving row.
  Mat $ V.map (\r -> delete r j) (delete rows i)
348
349
-- | Containers @p a@ that have a determinant with values in @a@. The
--   instances in this module are for square matrices.
class (Eq a, Ring.C a) => Determined p a where
  determinant :: (p a) -> a

-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
355
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | Cofactor expansion along the first row, recursing on the
  --   (n-1)x(n-1) minors. Triangular matrices are special-cased: their
  --   determinant is the product of the diagonal entries.
  --
  --   Examples:
  --
  --   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> determinant m
  --   -2
  --
  --   >>> let m = fromList [[2,1],[0,3]] :: Mat2 Int
  --   >>> determinant m
  --   6
  --
  --   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,10]] :: Mat3 Int
  --   >>> determinant m
  --   -3
  --
  determinant m
    | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
    | otherwise = determinant_recursive
    where
      m' i j = m !!! (i,j)

      det_minor i j = determinant (minor m i j)

      -- The (0,j) cofactor carries the sign (-1)^j.
      determinant_recursive =
        sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
              | j <- [0..(ncols m)-1] ]
380
381
382
-- | Matrix multiplication. Entry @(i,j)@ of the product is the sum of
--   the pairwise products of row @i@ of @m1@ with column @j@ of @m2@.
--
--   Examples:
--
--   >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
--   >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
--   >>> m1 * m2
--   ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    entry i j =
      sum $ V.toList $ V.zipWith (NP.*) (row m1 i) (column m2 j)
401
402
403
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  -- Addition and subtraction act entry-by-entry.
  (Mat xss) + (Mat yss) = Mat (V.zipWith (V.zipWith (+)) xss yss)

  (Mat xss) - (Mat yss) = Mat (V.zipWith (V.zipWith (-)) xss yss)

  -- Every entry of the zero matrix is zero.
  zero = construct (\_ _ -> fromInteger 0)
413
414
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication, the second is matrix
  -- multiplication.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix.
  -- NOTE(review): NumericPrelude's Algebra.Ring defaults 'one' and
  -- 'fromInteger' in terms of each other, so leaving both undefined
  -- (as before) makes them, and the default (^), diverge. Confirm
  -- against the library version in use.
  one = identity_matrix

  -- The integer k is embedded as k times the identity matrix.
  fromInteger k =
    construct (\i j -> if i == j then fromInteger k else fromInteger 0)
419
420
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- We can multiply a matrix by a scalar of the same type as its
  -- elements. The scalar is multiplied on the left of each entry, so
  -- that x *> (y *> m) == (x*y) *> m holds even when multiplication
  -- in @a@ is not commutative. For commutative @a@ this agrees with
  -- right multiplication.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
425
426
instance (Algebraic.C a,
          ToRational.C a,
          Arity m)
         => Normed (Mat (S m) N1 a) where
  -- | Generic p-norms, (sum |x|^p)^(1/p). The usual norm in R^n is
  --   (norm_p 2). We treat all matrices as big vectors. The absolute
  --   value of each entry is taken before raising it to the power p,
  --   so odd p (in particular p = 1) is correct for vectors with
  --   negative entries.
  --
  --   Examples:
  --
  --   >>> let v1 = vec2d (3,4)
  --   >>> norm_p 1 v1
  --   7.0
  --   >>> norm_p 2 v1
  --   5.0
  --   >>> let v2 = vec2d (-3,4)
  --   >>> norm_p 1 v2
  --   7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [fromRational' (magnitude x)^p' | x <- xs]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- |x| computed on the exact rational value.
      magnitude x = let r = toRational x in if r < 0 then negate r else r

  -- | The infinity norm: the largest absolute value of any entry.
  --
  --   Examples:
  --
  --   >>> let v1 = vec3d (1,5,2)
  --   >>> norm_infty v1
  --   5
  --   >>> let v2 = vec3d (1,-5,2)
  --   >>> norm_infty v2
  --   5
  --
  norm_infty (Mat rows) =
    fromRational' $ V.maximum $ V.map (V.maximum . V.map magnitude) rows
    where
      -- |x| computed on the exact rational value.
      magnitude x = let r = toRational x in if r < 0 then negate r else r
458
459
460
461
462
-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Convenient constructor for 1D vectors (1x1 matrices).
--
--   Examples:
--
--   >>> vec1d 1 :: Mat1 Int
--   ((1))
--
vec1d :: (a) -> Mat N1 N1 a
vec1d (x) = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D vectors.
--
--   Examples:
--
--   >>> import Roots.Simple
--   >>> let fst m = m !!! (0,0)
--   >>> let snd m = m !!! (1,0)
--   >>> let h = 0.5 :: Double
--   >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
--   >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
--   >>> let g u = vec2d ((g1 u), (g2 u))
--   >>> let u0 = vec2d (1.0, 1.0)
--   >>> let eps = 1/(10^9)
--   >>> fixed_point g eps u0
--   ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D vectors.
--
--   Examples:
--
--   >>> vec3d (1,2,3) :: Mat N3 N1 Int
--   ((1),(2),(3))
--
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D vectors.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D vectors.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
496
-- | Wrap a single value as a 1x1 matrix. Since this module's (*) is
--   matrix multiplication, scalars must be lifted this way before
--   they can be multiplied with it.
scalar :: a -> Mat N1 N1 a
scalar x = Mat $ mk1 $ mk1 x
501
-- | The dot product of two column vectors, computed as the single
--   entry of the 1x1 matrix (transpose v1) * v2.
--
--   Examples:
--
--   >>> let v1 = vec2d (1.0, 2.0) :: Mat N2 N1 Double
--   >>> let v2 = vec2d (3.0, 4.0) :: Mat N2 N1 Double
--   >>> v1 `dot` v2
--   11.0
--
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
dot v1 v2 = inner_product !!! (0, 0)
  where
    inner_product = transpose v1 * v2
507
508
-- | The angle between @v1@ and @v2@ in Euclidean space: the arccosine
--   of their dot product divided by the product of their norms.
--
--   Examples:
--
--   >>> let v1 = vec2d (1.0, 0.0)
--   >>> let v2 = vec2d (0.0, 1.0)
--   >>> angle v1 v2 == pi/2.0
--   True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 =
  let cos_theta = (recip ((norm v1) NP.* (norm v2))) NP.* (v1 `dot` v2)
  in acos cos_theta
532
533
534
-- | Given a square @matrix@, return a matrix of the same size that
--   keeps only the diagonal entries of @matrix@, with every
--   off-diagonal entry replaced by zero.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> diagonal_part m
--   ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix = construct keep
  where
    keep i j
      | i == j = matrix !!! (i,j)
      | otherwise = 0
552
553
-- | Given a square @matrix@, return a matrix of the same size that
--   keeps the on- and below-diagonal entries of @matrix@. Every entry
--   above the diagonal is replaced by zero.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> lt_part m
--   ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix = construct keep
  where
    keep i j
      | i >= j = matrix !!! (i,j)
      | otherwise = 0
571
572
-- | Given a square @matrix@, return a matrix of the same size that
--   keeps only the strictly-below-diagonal entries of @matrix@. Every
--   on- or above-diagonal entry is replaced by zero.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> lt_part_strict m
--   ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix = construct keep
  where
    keep i j
      | j < i = matrix !!! (i,j)
      | otherwise = 0
590
591
-- | Given a square @matrix@, return a matrix of the same size that
--   keeps the on- and above-diagonal entries of @matrix@. Every entry
--   below the diagonal is replaced by zero.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> ut_part m
--   ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix = construct keep
  where
    keep i j
      | i <= j = matrix !!! (i,j)
      | otherwise = 0
606
607
-- | Given a square @matrix@, return a matrix of the same size that
--   keeps only the strictly-above-diagonal entries of @matrix@. Every
--   on- or below-diagonal entry is replaced by zero.
--
--   Examples:
--
--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
--   >>> ut_part_strict m
--   ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix = construct keep
  where
    keep i j
      | i < j = matrix !!! (i,j)
      | otherwise = 0