]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Add triangular functions, determinant, minor (all half-baked) to Linear.Matrix.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
8
9 module Linear.Matrix
10 where
11
12 import Data.List (intercalate)
13
14 import Data.Vector.Fixed (
15 Dim,
16 N1,
17 Vector
18 )
19 import qualified Data.Vector.Fixed as V (
20 and,
21 fromList,
22 length,
23 map,
24 replicate,
25 toList,
26 zipWith
27 )
28 import Data.Vector.Fixed.Internal (Arity, arity, S)
29 import Linear.Vector
30 import Normed
31
32 import NumericPrelude hiding ((*), abs)
33 import qualified NumericPrelude as NP ((*))
34 import qualified Algebra.Algebraic as Algebraic
35 import Algebra.Algebraic (root)
36 import qualified Algebra.Additive as Additive
37 import qualified Algebra.Ring as Ring
38 import qualified Algebra.Module as Module
39 import qualified Algebra.RealRing as RealRing
40 import qualified Algebra.ToRational as ToRational
41 import qualified Algebra.Transcendental as Transcendental
42 import qualified Prelude as P
43
-- | A matrix stored as an outer fixed-size vector (of type @v@) of
--   rows, each row being a fixed-size vector @w a@. The constructor
--   captures the 'Vector' dictionaries for both levels, so pattern
--   matching on 'Mat' brings them back into scope.
data Mat v w a =
  (Vector v (w a), Vector w a) => Mat (v (w a))

-- | Square matrices of size 1x1 through 4x4.
type Mat1 a = Mat D1 D1 a
type Mat2 a = Mat D2 D2 a
type Mat3 a = Mat D3 D3 a
type Mat4 a = Mat D4 D4 a
49
-- We can't just declare that every instance of Vector is also an
-- instance of Eq, unfortunately: doing so produces an overlapping
-- instance for w (w a). Matrices therefore get their own instance.
instance (Eq a, Vector v Bool, Vector w Bool) => Eq (Mat v w a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  --   equal. Two rows are equal when each pair of corresponding
  --   entries is equal.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  (==) (Mat xss) (Mat yss) = V.and (V.zipWith same_row xss yss)
    where
      -- Entry-by-entry comparison of a single pair of rows.
      same_row xs ys = V.and (V.zipWith (==) xs ys)
71
72
instance (Show a, Vector v String, Vector w String) => Show (Mat v w a) where
  -- | Render a matrix as a tuple of row tuples. This is poor practice,
  --   but these results are primarily displayed interactively, and
  --   convenience trumps correctness (said the guy who insists his
  --   vector lengths be statically checked at compile-time).
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> show m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = tuple (V.toList (V.map render_row rows))
    where
      -- Join the pieces with commas and wrap them in parentheses.
      tuple parts = "(" ++ (intercalate "," parts) ++ ")"

      -- One row, rendered as a flat tuple of its shown entries.
      render_row r = tuple (P.map show (V.toList r))
95
96
97
-- | Flatten a matrix into a list of rows, each row itself being a
--   list of that row's entries.
toList :: Mat v w a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
101
-- | Build a matrix from a nested list. Each inner list becomes one row
--   via 'V.fromList', and the list of rows becomes the outer vector.
fromList :: (Vector v (w a), Vector w a, Vector v a) => [[a]] -> Mat v w a
fromList xss = Mat (V.fromList [ V.fromList xs | xs <- xss ])
105
106
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@, column
--   @j@ of @m@. No bounds checking is done here; see '!!?' for a
--   checked variant.
(!!!) :: (Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
110
-- | Safe indexing. Returns @Just@ the entry in row @i@, column @j@ of
--   @m@, or @Nothing@ if either index is out of range.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (0,2)
-- Nothing
-- >>> m !!? (2,0)
-- Nothing
--
(!!?) :: Mat v w a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, so an index equal to the
  -- length is already out of range.
  | i >= V.length rows = Nothing
  | j >= V.length r = Nothing
  | otherwise = Just (r ! j)
  where
    -- Row i (not row j) is the one containing entry (i, j). It is only
    -- used after the row index has passed its bounds check.
    r = row m i
119
120
-- | The number of rows in the matrix, i.e. the length of its outer
--   vector.
nrows :: Mat v w a -> Int
nrows m = case m of
  Mat rows -> V.length rows
124
-- | The number of columns in the matrix. It is read off the type-level
--   dimension of the row type @w@, so the matrix argument itself is
--   never inspected. This is the same trick used by
--   Data.Vector.Fixed.length.
ncols :: forall v w a. (Vector w a) => Mat v w a -> Int
ncols = const (arity (undefined :: Dim w))
129
-- | Return the @i@th row of @m@. Unsafe: the index is not checked
--   here.
row :: Mat v w a -> Int -> w a
row m i = case m of
  Mat rows -> rows ! i
133
134
-- | Return the @j@th column of @m@, built by picking entry @j@ out of
--   every row. Unsafe: the index is not checked here.
column :: (Vector v a) => Mat v w a -> Int -> v a
column (Mat rows) j = V.map (! j) rows
141
142
-- | Transpose @m@, switching its columns and its rows. Column @j@ of
--   @m@ becomes row @j@ of the result. This is a dirty implementation;
--   it would be a little cleaner to use imap, but that doesn't seem to
--   work.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Vector w (v a),
              Vector v a,
              Vector w a)
          => Mat v w a
          -> Mat w v a
transpose m = Mat (V.fromList columns)
  where
    columns = map (column m) [0 .. ncols m - 1]
163
164
-- | Is @m@ symmetric, i.e. equal to its own transpose?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Vector v (w a),
              Vector w a,
              v ~ w,
              Vector w Bool,
              Eq a)
          => Mat v w a
          -> Bool
symmetric m = m == mt
  where
    mt = transpose m
186
187
-- | Build a matrix from a function @lambda@ of two indices: the (i,j)
--   entry of the result is @lambda i j@. The dimensions come from the
--   matrix type, not from any argument.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall v w a.
             (Vector v (w a),
              Vector w a)
          => (Int -> Int -> a)
          -> Mat v w a
construct lambda = Mat (V.fromList (map make_row [0 .. last_row]))
  where
    -- Read the dimensions off the types (the arity trick from
    -- Data.Vector.Fixed.length).
    last_row = arity (undefined :: Dim v) - 1
    last_col = arity (undefined :: Dim w) - 1

    make_row i = V.fromList (map (lambda i) [0 .. last_col])
213
-- | Given a positive-definite matrix @m@, computes the upper-triangular
--   matrix @r@ with (transpose r)*r == m and all values on the
--   diagonal of @r@ positive.
--
--   Each entry of @r@ depends on entries in earlier rows. Entries are
--   therefore memoized in a lazily evaluated table, so each one is
--   computed only once instead of being recomputed along every
--   recursive path (which would cost time exponential in the size).
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall a v w.
            (Algebraic.C a,
             Vector v (w a),
             Vector w a,
             Vector v a)
         => (Mat v w a)
         -> (Mat v w a)
cholesky m = construct r
  where
    -- Every entry of the factor. The list is lazy, so an entry is
    -- evaluated the first time it is looked up and then shared.
    table :: [[a]]
    table = [ [ entry i j | j <- [0..(ncols m)-1] ]
            | i <- [0..(nrows m)-1] ]

    -- Memoized lookup of the (i,j) entry of the factor.
    r :: Int -> Int -> a
    r i j = (table !! i) !! j

    -- The defining recurrence. Diagonal entries take a square root,
    -- entries above the diagonal divide by the diagonal entry of their
    -- row, and entries below the diagonal are zero.
    entry :: Int -> Int -> a
    entry i j
      | i == j = sqrt (m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
      | i < j =
          ((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]])
            / (r i i)
      | otherwise = 0
240
241
-- | Returns True if every entry strictly below the diagonal of the
--   given matrix is zero (i.e. it is upper-triangular), and False
--   otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Vector w a) => Mat v w a -> Bool
is_upper_triangular m =
  and [ acceptable i j | j <- [0 .. ncols m - 1], i <- [0 .. nrows m - 1] ]
  where
    -- Entries on or above the diagonal may be anything; entries below
    -- it must be zero.
    acceptable :: Int -> Int -> Bool
    acceptable i j = i <= j || m !!! (i,j) == 0
265
266
-- | Returns True if the given matrix is lower-triangular, and False
--   otherwise. A matrix is lower-triangular exactly when its transpose
--   is upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Vector w a,
                        Vector w (v a),
                        Vector v a)
                    => Mat v w a
                    -> Bool
is_lower_triangular m = is_upper_triangular (transpose m)
288
289
-- | Returns True if the given matrix is either upper- or
--   lower-triangular, and False otherwise. The upper-triangular test
--   runs first.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Vector w a,
                  Vector w (v a),
                  Vector v a)
              => Mat v w a
              -> Bool
is_triangular m = or [is_upper_triangular m, is_lower_triangular m]
315
316
-- | Return the (i,j)th minor of m: the matrix left after deleting row
--   @i@ and column @j@. The result type is one smaller than the input
--   type in each dimension.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (Dim v ~ S (Dim u),
          Dim w ~ S (Dim z),
          Vector z a,
          Vector u (w a),
          Vector u (z a))
      => Mat v w a
      -> Int
      -> Int
      -> Mat u z a
minor (Mat rows) i j = Mat (V.map drop_col (delete rows i))
  where
    -- Remove column j from a single remaining row.
    drop_col r = delete r j
340
341
-- | Compute the determinant of the square matrix @m@.
--
--   Triangular matrices take a shortcut: their determinant is the
--   product of the diagonal entries. Every other matrix uses cofactor
--   (Laplace) expansion along the first row of its list-of-rows form.
--   That costs O(n!) ring operations, so it only suits the small,
--   fixed-size matrices this module works with.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> determinant m1
-- -2
--
-- >>> let m2 = fromList [[1,2,3],[4,5,6],[7,8,10]] :: Mat3 Int
-- >>> determinant m2
-- -3
--
determinant :: (Eq a,
                Ring.C a,
                Vector w a,
                Vector w (v a),
                Vector v a,
                Dim v ~ S r,
                Dim w ~ S t)
            => Mat v w a
            -> a
determinant m@(Mat rows)
  | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
  | otherwise = laplace [ V.toList rw | rw <- V.toList rows ]
  where
    -- Cofactor expansion along the first row. Working on plain lists
    -- avoids the type-level dimension bookkeeping that a recursion
    -- through 'minor' would require.
    laplace :: Ring.C b => [[b]] -> b
    laplace [] = 1
    laplace [[x]] = x
    laplace (top:rest) =
      alternating_sum [ x NP.* laplace (drop_column j rest)
                      | (j, x) <- zip [0..] top ]

    -- Remove the jth entry from every row.
    drop_column :: Int -> [[b]] -> [[b]]
    drop_column j rs = [ take j rw ++ drop (j+1) rw | rw <- rs ]

    -- y0 - y1 + y2 - ..., which applies the cofactor signs (-1)^j.
    alternating_sum :: Ring.C b => [b] -> b
    alternating_sum [] = 0
    alternating_sum (y:ys) = y - alternating_sum ys
354
355 {-
356 determinant_recursive :: forall v w a r c.
357 (Eq a,
358 Ring.C a,
359 Vector w a)
360 => Mat (v r) (w c) a
361 -> a
362 determinant_recursive m
363 | (ncols m) == 0 || (nrows m) == 0 = error "don't do that"
364 | (ncols m) == 1 && (nrows m) == 1 = m !!! (0,0) -- Base case
365 | otherwise =
366 sum [ (-1)^(1+(toInteger j)) NP.* (m' 1 j) NP.* (det_minor 1 j)
367 | j <- [0..(ncols m)-1] ]
368 where
369 m' i j = m !!! (i,j)
370
371 det_minor :: Int -> Int -> a
372 det_minor i j = determinant (minor m i j)
373 -}
374
-- | Matrix multiplication of an (m x n) matrix by an (n x p) matrix.
--   Entry (i,j) of the result is the sum over k of m1(i,k)*m2(k,j).
--   This module hides NumericPrelude's (*) and defines this one in
--   its place, because the operand and result types differ and need
--   their own constraints.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat D2 D3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat D3 D2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a,
        Vector v a,
        Vector w a,
        Vector z a,
        Vector v (z a))
    => Mat v w a
    -> Mat w z a
    -> Mat v z a
m1 * m2 = construct entry
  where
    -- The shared dimension: columns of m1, which must equal rows of m2.
    inner_indices = [0 .. ncols m1 - 1]

    entry i j =
      sum [ (m1 !!! (i, k)) NP.* (m2 !!! (k, j)) | k <- inner_indices ]
398
399
400
-- | Matrices of the same shape form an additive group under entrywise
--   addition and subtraction; zero is the all-zero matrix.
instance (Ring.C a,
          Vector v (w a),
          Vector w a)
         => Additive.C (Mat v w a) where

  zero = Mat (V.replicate (V.replicate (fromInteger 0)))

  Mat xss + Mat yss = Mat (V.zipWith (V.zipWith (+)) xss yss)

  Mat xss - Mat yss = Mat (V.zipWith (V.zipWith (-)) xss yss)
413
414
-- | Square matrices form a ring under entrywise addition and matrix
--   multiplication, with the identity matrix as the unit.
instance (Ring.C a,
          Vector v (w a),
          Vector w a,
          v ~ w)
         => Ring.C (Mat v w a) where
  -- The left-hand * is the class method Ring.*. The unqualified * on
  -- the right-hand side is this module's top-level matrix product
  -- (NumericPrelude's * is hidden by the import list), so this line
  -- delegates to it rather than recursing.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix. numeric-prelude
  -- defaults one and fromInteger in terms of each other, so without
  -- this definition neither could produce a value. With it,
  -- fromInteger n presumably becomes n times the identity via the class
  -- default.
  one = construct identity_entry
    where
      identity_entry i j
        | i == j = 1
        | otherwise = 0
423
424
-- | Matrices form a module over their entry type: a scalar of the
--   same type as the entries scales every entry.
instance (Ring.C a,
          Vector v (w a),
          Vector w a)
         => Module.C a (Mat v w a) where
  -- Multiply each entry on the left (x * e), matching the left-hand
  -- position of the scalar in x *> m. This keeps the module law
  -- (a * b) *> m == a *> (b *> m) true even when the scalar ring is
  -- not commutative. For commutative rings such as Int or Double the
  -- result is the same as multiplying on the right.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
432
433
instance (Algebraic.C a,
          ToRational.C a,
          Vector v (w a),
          Vector w a,
          Vector v a,
          Vector v [a])
         => Normed (Mat v w a) where
  -- | Generic p-norms: the p-th root of the sum of the p-th powers of
  --   the absolute values of the entries. The usual norm in R^n is
  --   (norm_p 2). We treat all matrices as big vectors.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> let v2 = vec2d (-3,4)
  -- >>> norm_p 1 v2
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [ (fromRational' $ magnitude $ toRational x)^p'
                    | x <- xs ]
    where
      p' = toInteger p
      -- All entries, in row-major order.
      xs = [ x | r <- V.toList rows, x <- V.toList r ]
      -- Absolute value of a Rational. It is written with max/negate
      -- because abs is hidden from the NumericPrelude import.
      magnitude q = max q (negate q)

  -- | The infinity norm: the largest absolute value of any entry. We
  --   don't use V.maximum here because it relies on a type constraint
  --   that the vector be non-empty, and I don't know how to pattern
  --   match it away.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> let v2 = vec3d (1,-5,2)
  -- >>> norm_infty v2
  -- 5
  --
  norm_infty m@(Mat rows)
    | nrows m == 0 || ncols m == 0 = 0
    | otherwise =
        fromRational' $ P.maximum [ magnitude (toRational x)
                                  | r <- V.toList rows, x <- V.toList r ]
    where
      -- Absolute value of a Rational (abs is hidden by the import).
      magnitude q = max q (negate q)
473
474
475
476
477
478 -- Vector helpers. We want it to be easy to create low-dimension
479 -- column vectors, which are nx1 matrices.
480
-- | Convenient constructor for 1D vectors (1x1 matrices).
vec1d :: a -> Mat D1 D1 a
vec1d x = Mat (D1 (D1 x))

-- | Convenient constructor for 2D column vectors (2x1 matrices).
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let h = 0.5 :: Double
-- >>> let g1 (Mat (D2 (D1 x) (D1 y))) = 1.0 + h NP.* exp(-(x^2))/(1.0 + y^2)
-- >>> let g2 (Mat (D2 (D1 x) (D1 y))) = 0.5 + h NP.* atan(x^2 + y^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat D2 D1 a
vec2d (x,y) = Mat (D2 (D1 x) (D1 y))

-- | Convenient constructor for 3D column vectors (3x1 matrices).
vec3d :: (a,a,a) -> Mat D3 D1 a
vec3d (x,y,z) = Mat (D3 (D1 x) (D1 y) (D1 z))

-- | Convenient constructor for 4D column vectors (4x1 matrices).
vec4d :: (a,a,a,a) -> Mat D4 D1 a
vec4d (w,x,y,z) = Mat (D4 (D1 w) (D1 x) (D1 y) (D1 z))
506
-- | Wrap a plain value in a 1x1 matrix. This module's (*) is matrix
--   multiplication, so a scalar has to be a 1x1 matrix before it can
--   take part in a product.
scalar :: a -> Mat D1 D1 a
scalar = Mat . D1 . D1
511
-- | The dot product of two column vectors (matrices with a single
--   column). It is the only entry of the 1x1 matrix
--   (transpose v1) * v2.
dot :: (RealRing.C a,
        Dim w ~ N1,
        Dim v ~ S n,
        Vector v a,
        Vector w a,
        Vector w (v a),
        Vector w (w a))
    => Mat v w a
    -> Mat v w a
    -> a
dot v1 v2 = gram !!! (0, 0)
  where
    gram = transpose v1 * v2
523
524
-- | The angle between @v1@ and @v2@ in Euclidean space: the arccosine
--   of their dot product divided by the product of their norms.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          Dim w ~ N1,
          Dim v ~ S n,
          Vector w (w a),
          Vector v [a],
          Vector v a,
          Vector w a,
          Vector v (w a),
          Vector w (v a),
          ToRational.C a)
      => Mat v w a
      -> Mat v w a
      -> a
angle v1 v2 = acos cos_theta
  where
    -- Multiply by the reciprocal rather than dividing, exactly as
    -- before, so floating-point results are unchanged.
    cos_theta = recip norm_product NP.* (v1 `dot` v2)
    norm_product = norm v1 NP.* norm v2