]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Add the Frobenius norm to Linear.Matrix.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6 {-# LANGUAGE ScopedTypeVariables #-}
7 {-# LANGUAGE TypeFamilies #-}
8 {-# LANGUAGE RebindableSyntax #-}
9
10 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
11 -- assume that the underlying representation is
12 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
13 -- generality and failed.
14 --
15 module Linear.Matrix
16 where
17
18 import Data.List (intercalate)
19
20 import Data.Vector.Fixed (
21 (!),
22 N1,
23 N2,
24 N3,
25 N4,
26 N5,
27 S,
28 Z,
29 generate,
30 mk1,
31 mk2,
32 mk3,
33 mk4,
34 mk5
35 )
36 import qualified Data.Vector.Fixed as V (
37 and,
38 foldl,
39 fromList,
40 head,
41 length,
42 map,
43 maximum,
44 replicate,
45 toList,
46 zipWith
47 )
48 import Data.Vector.Fixed.Boxed (Vec)
49 import Data.Vector.Fixed.Cont (Arity, arity)
50 import Linear.Vector
51 import Normed
52
53 import NumericPrelude hiding ((*), abs)
54 import qualified NumericPrelude as NP ((*))
55 import qualified Algebra.Algebraic as Algebraic
56 import Algebra.Algebraic (root)
57 import qualified Algebra.Additive as Additive
58 import qualified Algebra.Ring as Ring
59 import qualified Algebra.Module as Module
60 import qualified Algebra.RealRing as RealRing
61 import qualified Algebra.ToRational as ToRational
62 import qualified Algebra.Transcendental as Transcendental
63 import qualified Prelude as P
64
-- | A dense @m@-by-@n@ matrix, stored as a boxed vector of @m@ rows,
-- each row being a boxed vector of @n@ entries. The constructor carries
-- the 'Arity' dictionaries for both dimensions, so pattern matching on
-- 'Mat' brings them into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of the small fixed sizes 1 through 5.
type Mat1 = Mat N1 N1
type Mat2 = Mat N2 N2
type Mat3 = Mat N3 N3
type Mat4 = Mat N4 N4
type Mat5 = Mat N5 N5
71
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when every pair of corresponding entries
  -- is equal. Rows are compared in order; within a row, entries are
  -- compared column by column.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  Mat xss == Mat yss = V.and (V.zipWith rows_match xss yss)
    where
      -- Entry-wise comparison of a single pair of rows.
      rows_match xs ys = V.and (V.zipWith (==) xs ys)
90
91
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples, e.g. @((1,2),(3,4))@.
  -- A column vector therefore renders as @((1),(2))@. This does not
  -- follow the convention that 'show' output be valid Haskell. It is
  -- used because it is compact and readable at the REPL, which is where
  -- these results are mostly displayed.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = tupled (V.toList (V.map render_row rows))
    where
      -- Join already-rendered pieces with commas and wrap them in parens.
      tupled pieces = "(" ++ intercalate "," pieces ++ ")"

      -- Render one row as a parenthesized, comma-separated tuple.
      render_row r = tupled (P.map show (V.toList r))
114
115
-- | Flatten a matrix into a list of its rows. Each row is a list of
-- its entries in column order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = V.toList (V.map V.toList rows)
119
-- | Build a matrix from a list of rows, where each inner list becomes
-- one row. The conversion goes through 'V.fromList', so the lengths of
-- the lists are expected to match the type-level dimensions @m@
-- (number of rows) and @n@ (length of each row).
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
123
124
-- | Unsafe indexing: the entry in row @i@, column @j@ (both
-- zero-based). The indices are not validated before use; see '!!?'
-- for a checked alternative.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
128
-- | Safe indexing. Returns the entry in row @i@, column @j@ (both
-- zero-based) wrapped in 'Just', or 'Nothing' if either index is out
-- of range.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,2)
-- Nothing
-- >>> m !!? (2,0)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid row indices are 0 .. length-1, so the length itself is out.
  | i >= V.length rows = Nothing
  -- Check the column against row i (the row we are about to read),
  -- not row j, which might not even exist.
  | j >= V.length (row m i) = Nothing
  | otherwise = Just (row m i ! j)
137
138
-- | The number of rows in the matrix. It is read off the type-level
-- dimension @m@, so the matrix argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
142
-- | The number of columns in the matrix, which is the length of every
-- row. Like 'nrows', it comes from the type-level dimension (here @n@)
-- rather than from the argument. The technique is borrowed from
-- Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
147
148
-- | Row @i@ (zero-based) of a matrix, as a vector. Unsafe: the index
-- is not validated before use.
row :: Mat m n a -> Int -> Vec n a
row (Mat rows) = (rows !)
152
153
-- | Column @j@ (zero-based) of a matrix, as a vector. It is built by
-- taking entry @j@ from every row in order. Unsafe: @j@ is not
-- validated before use.
column :: Mat m n a -> Int -> Vec m a
column (Mat rows) j = V.map (! j) rows
160
161
162
163
-- | Transpose @m@, exchanging its rows and columns: entry (i,j) of the
-- result is entry (j,i) of @m@.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
180
181
-- | Is @m@ symmetric, i.e. equal to its own transpose?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
197
198
-- | Build a matrix from a function of its indices: the entry in row
-- @i@, column @j@ (both zero-based) is @lambda i j@. The dimensions
-- come from the result type.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate build_row)
  where
    -- Row i holds lambda i j for each column index j.
    build_row :: Int -> Vec n a
    build_row = generate . lambda
216
217
-- | The identity matrix whose size is taken from the result type: ones
-- on the diagonal and zeros everywhere else.
--
-- Examples:
--
-- >>> identity_matrix :: Mat3 Int
-- ((1,0,0),(0,1,0),(0,0,1))
-- >>> identity_matrix :: Mat3 Double
-- ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct kronecker
  where
    -- One on the diagonal, zero off it.
    kronecker i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
230
-- | Given a positive-definite matrix @m@, computes the
-- upper-triangular matrix @r@ with (transpose r)*r == m and all
-- values on the diagonal of @r@ positive.
--
-- Each entry of @r@ is defined in terms of earlier entries: those in
-- previous rows, and the diagonal entry of its own row. Those earlier
-- entries are read back out of the lazily built result matrix instead
-- of being recomputed. Each entry is therefore evaluated only once,
-- whereas direct recursion re-derives them and takes time exponential
-- in the dimension.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = result
  where
    -- The answer, whose entries are lazy thunks of r i j.
    result :: Mat m n a
    result = construct r

    -- An entry of the result, shared rather than recomputed.
    r' :: Int -> Int -> a
    r' i j = result !!! (i,j)

    -- The defining recurrence for entry (i,j) of the factor.
    r :: Int -> Int -> a
    r i j
      | i == j = sqrt (m !!! (i,j) - sum [ (r' k i)^2 | k <- [0..i-1] ])
      | i < j =
          (m !!! (i,j) - sum [ (r' k i) NP.* (r' k j) | k <- [0..i-1] ])
            / (r' i i)
      | otherwise = 0
252
253
-- | Is every entry strictly below the main diagonal (row index greater
-- than column index) equal to zero?
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  -- Walk column by column, and down each column, checking only the
  -- below-diagonal positions.
  and [ m !!! (i,j) == 0
      | j <- [0 .. ncols m - 1]
      , i <- [0 .. nrows m - 1]
      , i > j ]
278
279
-- | Is every entry strictly above the main diagonal equal to zero? A
-- matrix is lower-triangular exactly when its transpose is
-- upper-triangular, which is how this is checked.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
300
301
-- | Is the matrix either upper- or lower-triangular? The lower check is
-- only performed when the upper one fails.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
              => Mat m n a -> Bool
is_triangular m
  | is_upper_triangular m = True
  | otherwise = is_lower_triangular m
326
327
-- | The (i,j) minor of @m@, taken here to be the submatrix left after
-- deleting row @i@ and column @j@ (not its determinant).
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r,
          n ~ S t,
          Arity r,
          Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j = Mat (V.map drop_column (delete rows i))
  where
    -- Remove column j from a single row.
    drop_column r = delete r j
350
351
-- | Containers that have a determinant. @p@ is the container type
-- constructor, e.g. a square 'Mat' of a particular size. @a@ is the
-- type of its entries, which is also the type of the determinant.
class (Eq a, Ring.C a) => Determined p a where
  determinant :: p a -> a
354
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
357
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | The determinant of a square matrix of size two or more. For a
  -- triangular matrix it is the product of the diagonal entries.
  -- Otherwise it is computed by cofactor expansion along the first
  -- row, recursing on the minors.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -2
  --
  -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
  -- >>> determinant m
  -- 3
  --
  determinant m
    | is_triangular m = product diagonal_entries
    | otherwise = sum (map cofactor_term [0 .. ncols m - 1])
    where
      -- Entries on the main diagonal, top-left to bottom-right.
      diagonal_entries = [ m !!! (k,k) | k <- [0 .. nrows m - 1] ]

      -- Term j of the expansion along row 0: the signed entry (0,j)
      -- times the determinant of the minor without row 0 and column j.
      cofactor_term j =
        (-1)^(toInteger j) NP.* (m !!! (0,j)) NP.* determinant (minor m 0 j)
382
383
384
-- | Matrix multiplication. Entry (i,j) of the product is the sum, over
-- k, of entry (i,k) of @m1@ times entry (k,j) of @m2@. The shared
-- dimension @n@ is checked at the type level.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    -- Inner product of row i of m1 with column j of m2.
    entry i j =
      sum (P.zipWith (NP.*) (V.toList (row m1 i)) (V.toList (column m2 j)))
403
404
405
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  -- Addition and subtraction act entry by entry.
  Mat xss + Mat yss = Mat (V.zipWith (V.zipWith (+)) xss yss)

  Mat xss - Mat yss = Mat (V.zipWith (V.zipWith (-)) xss yss)

  -- The zero matrix: every entry is zero.
  zero = Mat (V.replicate (V.replicate (fromInteger 0)))
415
416
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- Ring multiplication is matrix multiplication. The * on the right
  -- is this module's top-level matrix product; numeric-prelude's
  -- operator is hidden by the import list. So this is not a recursive
  -- call to the method being defined.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix. numeric-prelude
  -- requires an instance to define 'one' or 'fromInteger', whose
  -- defaults are written in terms of each other. Without this line,
  -- evaluating 'one' (or 'fromInteger') for matrices does not terminate.
  one = identity_matrix
421
422
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- Scale a matrix by a scalar of its own entry type. Every entry e is
  -- replaced by e * x (the scalar is multiplied on the right).
  x *> (Mat rows) = Mat (V.map scale_row rows)
    where
      scale_row = V.map (NP.* x)
427
428
instance (Algebraic.C a,
          ToRational.C a,
          Arity m)
         => Normed (Mat (S m) N1 a) where
  -- | Generic p-norms, (sum over i of |x_i|^p)^(1/p). The usual norm in
  -- R^n is (norm_p 2). We treat all matrices as big vectors.
  --
  -- Each entry is converted to a 'Rational' and its absolute value is
  -- taken before raising it to the power p. Negative entries therefore
  -- contribute correctly even when p is odd.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> let v2 = vec2d (-3,4)
  -- >>> norm_p 1 v2
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ fromRational' (magnitude (toRational x))^p' | x <- xs ]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- Absolute value of a rational number.
      magnitude q = if q < 0 then negate q else q

  -- | The infinity norm: the largest absolute value of any entry.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> let v2 = vec3d (1,-5,2)
  -- >>> norm_infty v2
  -- 5
  --
  norm_infty (Mat rows) =
    fromRational' $ V.maximum $ V.map (V.maximum . V.map magnitude_q) rows
    where
      -- Absolute value of an entry, as a rational number.
      magnitude_q x = let q = toRational x in if q < 0 then negate q else q
460
461
-- | The Frobenius norm of a matrix: the square root of the sum of the
-- squares of all its entries. This is the same as treating the matrix
-- as one long vector of entries and taking its Euclidean length.
--
-- Examples:
--
-- >>> let m = fromList [[1, 2, 3],[4,5,6],[7,8,9]] :: Mat3 Double
-- >>> frobenius_norm m == sqrt 285
-- True
--
-- >>> let m = fromList [[1, -1, 1],[-1,1,-1],[1,-1,1]] :: Mat3 Double
-- >>> frobenius_norm m == 3
-- True
--
frobenius_norm :: (Algebraic.C a, Ring.C a) => Mat m n a -> a
frobenius_norm (Mat rows) = sqrt (V.foldl add_row (fromInteger 0) rows)
  where
    -- fixed-vector's own "sum" needs a Prelude 'Num' constraint. We
    -- want the numeric-prelude classes, so we fold explicitly. Each row
    -- is reduced to its sum of squares (starting from zero), and that
    -- total is added to the running total.
    add_row total r =
      total + V.foldl (\acc x -> acc + x^2) (fromInteger 0) r
487
488
-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Convenient constructor for 1D vectors (1x1 matrices).
--
-- Examples:
--
-- >>> vec1d 5 :: Mat1 Int
-- ((5))
--
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D vectors.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D vectors.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D vectors.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D vectors.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
522
-- | Wrap a single value in a 1x1 matrix. This module's '*' is matrix
-- multiplication, so a lone value must be wrapped like this before it
-- can take part in a product with other matrices.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
527
-- | The dot product of two column vectors of the same length. It is
-- computed as the single entry of the 1x1 product (transpose v1) * v2.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
dot v1 v2 = only_entry (transpose v1 * v2)
  where
    -- The (0,0) entry of a 1x1 matrix.
    only_entry = (!!! (0, 0))
533
534
-- | The angle, in radians, between @v1@ and @v2@ in Euclidean space.
-- It is acos of the dot product divided by the product of the two
-- 'norm's.
--
-- NOTE(review): a zero vector makes the denominator zero. Rounding can
-- also push the ratio slightly outside [-1, 1]. Either case may give a
-- NaN result.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 = acos (recip (norm v1 NP.* norm v2) NP.* (v1 `dot` v2))
558
559
560
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its on-diagonal entries. Every off-diagonal entry is zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> diagonal_part m
-- ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix = construct keep
  where
    keep i j
      | i == j = matrix !!! (i,j)
      | otherwise = 0
578
579
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps its on- and below-diagonal entries. Entries above the diagonal
-- become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part m
-- ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix = construct keep
  where
    keep i j
      | j <= i = matrix !!! (i,j)
      | otherwise = 0
597
598
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its strictly below-diagonal entries. Entries on and above
-- the diagonal become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part_strict m
-- ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix = construct keep
  where
    keep i j
      | j < i = matrix !!! (i,j)
      | otherwise = 0
616
617
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps its on- and above-diagonal entries. Entries below the diagonal
-- become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part m
-- ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix = construct keep
  where
    keep i j
      | i <= j = matrix !!! (i,j)
      | otherwise = 0
632
633
-- | Given a square @matrix@, return a matrix of the same size that
-- keeps only its strictly above-diagonal entries. Entries on and below
-- the diagonal become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part_strict m
-- ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix = construct keep
  where
    keep i j
      | i < j = matrix !!! (i,j)
      | otherwise = 0