]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Fix implementation of Linear.Matrix.construct.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
8
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
12 -- generality and failed.
13 --
14 module Linear.Matrix
15 where
16
17 import Data.List (intercalate)
18
19 import Data.Vector.Fixed (
20 (!),
21 N1,
22 N2,
23 N3,
24 N4,
25 N5,
26 S,
27 Z,
28 generate,
29 mk1,
30 mk2,
31 mk3,
32 mk4,
33 mk5
34 )
35 import qualified Data.Vector.Fixed as V (
36 and,
37 fromList,
38 head,
39 length,
40 map,
41 maximum,
42 replicate,
43 toList,
44 zipWith
45 )
46 import Data.Vector.Fixed.Boxed (Vec)
47 import Data.Vector.Fixed.Cont (Arity, arity)
48 import Linear.Vector
49 import Normed
50
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
62
-- | An @m@-by-@n@ matrix, stored as a boxed vector of @m@ rows where
--   each row is a boxed vector of @n@ entries. The 'Arity' constraints
--   live inside the constructor, so pattern matching on 'Mat' brings
--   them back into scope for the vector operations.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of the small fixed sizes used throughout.
type Mat1 = Mat N1 N1
type Mat2 = Mat N2 N2
type Mat3 = Mat N3 N3
type Mat4 = Mat N4 N4
type Mat5 = Mat N5 N5
69
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when every pair of corresponding rows
  --   agrees entry by entry.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  (Mat xss) == (Mat yss) = V.and row_matches
    where
      -- One Bool per row: do the two rows agree in every column?
      row_matches =
        V.zipWith (\xs ys -> V.and (V.zipWith (==) xs ys)) xss yss
88
89
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples. Not a faithful
  --   serialization, but convenient to read interactively (even though
  --   the vector lengths themselves are checked at compile time).
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> show m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = tupled (V.toList (V.map show_row rows))
    where
      -- A single row, e.g. "(1,2)".
      show_row r = tupled (P.map show (V.toList r))

      -- Comma-separate the pieces and wrap them in parentheses.
      tupled parts = "(" ++ intercalate "," parts ++ ")"
112
113
-- | Flatten a matrix into a list of rows, each row being the list of
--   its entries in column order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
117
-- | Build a matrix from a list of rows; each inner list becomes one
--   row. The list lengths are handed to 'V.fromList', which presumably
--   rejects a mismatch with @m@ or @n@ at runtime — TODO confirm.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
121
122
-- | Unsafe indexing: the entry in row @i@, column @j@ (zero-based).
--   Neither index is bounds-checked here; see '!!?' for a total version.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
126
-- | Safe indexing: @Just@ the entry in row @i@, column @j@ (both
--   zero-based), or @Nothing@ when either index is out of range.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,1)
-- Just 2
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid row indices are 0 .. length - 1, so the length itself is
  -- already out of range.
  | i >= V.length rows = Nothing
  -- Check the column index against row i, the row we will read from.
  | j >= V.length (row m i) = Nothing
  | otherwise = Just (row m i ! j)
135
136
-- | The number of rows of a matrix. It is read off the type-level row
--   count @m@; the matrix argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = \_ -> arity (undefined :: m)
140
-- | The number of columns of a matrix. Like 'nrows', it comes from the
--   type (here the column count @n@) and never inspects the argument;
--   the trick is borrowed from Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = \_ -> arity (undefined :: n)
145
146
-- | The @i@th row (zero-based) of a matrix. Unsafe: @i@ is not
--   bounds-checked.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) = (rows !)
150
151
-- | The @j@th column (zero-based) of a matrix, gathered by taking the
--   @j@th entry of every row. Unsafe: @j@ is not bounds-checked.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map (! j) rows
158
159
160
161
-- | Transpose @m@, switching its columns and its rows: row @i@ of the
--   result is column @i@ of @m@.
--
-- The result is generated directly from the columns of @m@, rather than
-- by collecting them into a list and converting back with 'V.fromList'.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
-- >>> let m = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> transpose m
-- ((1,4),(2,5),(3,6))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = Mat (generate (column m))
178
179
-- | Is @m@ symmetric, i.e. equal to its own transpose?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
195
196
-- | Build a matrix entry by entry from a function @lambda@: the entry
--   in row @i@, column @j@ (both zero-based) is @lambda i j@.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat rows
  where
    -- All m rows, row i produced by build_row i.
    rows :: Vec m (Vec n a)
    rows = generate build_row

    -- Row i: its j-th entry is lambda i j.
    build_row :: Int -> Vec n a
    build_row i = generate (lambda i)
214
215
216
-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
-- Entries of @r@ are defined recursively in terms of entries in earlier
-- rows. They are kept in a lazily built table, so each entry is computed
-- once instead of being recomputed along every recursive path (which
-- made the naive recursion exponential in the dimension).
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct r
  where
    -- Memo table of entries of r, one lazy list per row. Each entry only
    -- refers to entries in earlier rows, or to the diagonal of its own
    -- row, so laziness ties the knot without looping.
    table :: [[a]]
    table = [ [ entry i j | j <- [0..(ncols m)-1] ]
            | i <- [0..(nrows m)-1] ]

    -- Memoized lookup into the table.
    r :: Int -> Int -> a
    r i j = (table P.!! i) P.!! j

    -- The defining recurrence, unchanged from the direct formulation.
    entry :: Int -> Int -> a
    entry i j
      | i == j =
          sqrt (m !!! (i,j) - sum [ (r k i)^2 | k <- [0..i-1] ])
      | i < j =
          ((m !!! (i,j)) - sum [ (r k i) NP.* (r k j) | k <- [0..i-1] ])
            / (r i i)
      | otherwise = 0
238
239
-- | Is every entry strictly below the main diagonal equal to zero?
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ m !!! (i,j) == 0 | j <- [0..(ncols m)-1],
                           i <- [0..(nrows m)-1],
                           i > j ]
264
265
-- | Is every entry strictly above the main diagonal equal to zero?
--   Checked by asking whether the transpose is upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
286
287
-- | Is the matrix either upper- or lower-triangular?
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
              => Mat m n a -> Bool
is_triangular m = P.or [is_upper_triangular m, is_lower_triangular m]
312
313
-- | The (i,j)th minor of a matrix: the submatrix left after removing
--   row @i@ and then column @j@ of every remaining row (via 'delete').
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r, n ~ S t, Arity r, Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j =
  Mat (V.map (`delete` j) (delete rows i))
336
337
-- | Containers with a determinant: @p@ is the (square) matrix type
--   constructor and @a@ its entry type.
class (Eq a, Ring.C a) => Determined p a where
  determinant :: p a -> a
340
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
343
instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | Cofactor expansion along the first row, short-circuited to the
  --   product of the diagonal when the matrix is triangular.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -1
  --
  determinant m
    | is_triangular m = product diagonal
    | otherwise = sum (map cofactor_term [0..(ncols m)-1])
    where
      diagonal = [ m !!! (i,i) | i <- [0..(nrows m)-1] ]

      -- Signed contribution of the (0,j) entry and its minor.
      cofactor_term j =
        (-1)^(toInteger j) NP.* (m !!! (0,j)) NP.* determinant (minor m 0 j)
368
369
370
-- | Matrix multiplication: entry (i,j) of the product is the dot
--   product of row @i@ of the left factor with column @j@ of the right.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
m1 * m2 = construct entry
  where
    -- Pairwise products, summed in column order of m1.
    entry i j =
      sum (V.toList (V.zipWith (NP.*) (row m1 i) (column m2 j)))
389
390
391
-- | Matrices of a fixed shape add and subtract entry by entry; the
--   additive identity is the all-zero matrix.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  (Mat xss) + (Mat yss) = Mat (V.zipWith (V.zipWith (+)) xss yss)

  (Mat xss) - (Mat yss) = Mat (V.zipWith (V.zipWith (-)) xss yss)

  zero = Mat (V.replicate (V.replicate 0))
401
402
-- | Square matrices form a ring under matrix multiplication.
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication; the second is this module's
  -- matrix multiplication (the NumericPrelude one is hidden).
  m1 * m2 = m1 * m2

  -- Embed the integer k as k times the identity matrix.
  -- NOTE(review): previously neither 'fromInteger' nor 'one' was
  -- defined. numeric-prelude's Ring.C expects one of them as part of a
  -- minimal definition, so 'one', integer literals and '^' on matrices
  -- had no base case. With this definition the class defaults for
  -- 'one' and '^' (presumably built on it) become usable.
  fromInteger k = construct (\i j -> if i == j then fromInteger k else 0)
407
408
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- | Scale a matrix by a scalar of its own entry type: every entry
  --   @e@ becomes @e * x@.
  x *> (Mat rows) = Mat (V.map scale_row rows)
    where
      scale_row = V.map (\e -> e NP.* x)
413
414
instance (Algebraic.C a,
          ToRational.C a,
          Arity m)
         => Normed (Mat (S m) N1 a) where
  -- | Generic p-norms: the @p@th root of the sum of the @p@th powers of
  --   the absolute values of the entries. The usual norm in R^n is
  --   (norm_p 2). All matrices are treated as big vectors.
  --
  -- Absolute values are essential: without them, odd @p@ would let
  -- negative entries cancel positive ones.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> norm_p 1 (vec2d (-3,4))
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ fromRational' (magnitude (toRational x))^p' | x <- xs ]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- Absolute value of a rational number.
      magnitude q = if q < 0 then negate q else q

  -- | The infinity norm: the largest absolute value among the entries.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> norm_infty (vec3d (1,-7,2))
  -- 7
  --
  norm_infty (Mat rows) =
    fromRational' $ P.maximum [ magnitude (toRational x) | x <- xs ]
    where
      -- Non-empty: there are (S m) >= 1 rows of exactly one entry each.
      xs = concat $ V.toList $ V.map V.toList rows
      -- Absolute value of a rational number.
      magnitude q = if q < 0 then negate q else q
446
447
448
449
450
-- Vector helpers. Column vectors are represented as n-by-1 matrices,
-- and the constructors below make the low-dimensional ones easy to
-- build.

-- | Build a one-dimensional column vector, i.e. a 1x1 matrix.
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat $ mk1 $ mk1 x

-- | Convenient constructor for 2D vectors.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat $ mk2 (mk1 x) (mk1 y)

-- | Build a three-dimensional column vector from a triple.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat $ mk3 (mk1 x) (mk1 y) (mk1 z)

-- | Build a four-dimensional column vector from a 4-tuple.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat $ mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z)

-- | Build a five-dimensional column vector from a 5-tuple.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat $ mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z)
484
-- | Wrap a single value as a 1x1 matrix. Because this module's '*' is
--   matrix multiplication, scalars must be lifted to 1x1 matrices before
--   they can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
489
-- | Inner product of two column vectors, read off as the single entry
--   of the 1x1 product (transpose v1) * v2.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
dot v1 v2 = inner !!! (0, 0)
  where
    inner = transpose v1 * v2
495
496
-- | The angle between @v1@ and @v2@ in Euclidean space: the arccosine
--   of their dot product divided by the product of their norms.
--
-- Floating-point rounding in the norms can push that quotient slightly
-- outside [-1, 1] (e.g. for parallel vectors). There 'acos' is undefined
-- (NaN for Double), so the quotient is clamped to [-1, 1] first.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 =
  acos (clamp theta)
  where
    theta = (recip norms) NP.* (v1 `dot` v2)
    norms = (norm v1) NP.* (norm v2)
    -- Keep the cosine inside the domain of acos.
    clamp x = P.max (negate 1) (P.min 1 x)
520
521
522
-- | Keep only the main diagonal of a square @matrix@; every
--   off-diagonal entry of the result is zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> diagonal_part m
-- ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix = construct keep_diagonal
  where
    keep_diagonal i j
      | i == j = matrix !!! (i,j)
      | otherwise = 0
540
541
-- | Keep the diagonal and everything below it in a square @matrix@;
--   the entries above the diagonal become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part m
-- ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix = construct keep_lower
  where
    keep_lower i j
      | i >= j = matrix !!! (i,j)
      | otherwise = 0
559
560
-- | Keep only the entries strictly below the diagonal of a square
--   @matrix@; the diagonal and everything above it become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part_strict m
-- ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix = construct keep_strictly_lower
  where
    keep_strictly_lower i j
      | i > j = matrix !!! (i,j)
      | otherwise = 0
578
579
-- | Keep the diagonal and everything above it in a square @matrix@;
--   the entries below the diagonal become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part m
-- ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix = construct keep_upper
  where
    keep_upper i j
      | i <= j = matrix !!! (i,j)
      | otherwise = 0
594
595
-- | Keep only the entries strictly above the diagonal of a square
--   @matrix@; the diagonal and everything below it become zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part_strict m
-- ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix = construct keep_strictly_upper
  where
    keep_strictly_upper i j
      | i < j = matrix !!! (i,j)
      | otherwise = 0