1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6 {-# LANGUAGE ScopedTypeVariables #-}
7 {-# LANGUAGE TypeFamilies #-}
8 {-# LANGUAGE RebindableSyntax #-}
-- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. For
-- simplicity, we assume that the underlying representation is
-- Data.Vector.Fixed.Boxed.Vec. A version generic over the vector
-- representation was attempted, but it did not work out.
18 import Data.List (intercalate)
20 import Data.Vector.Fixed (
36 import qualified Data.Vector.Fixed as V (
47 import Data.Vector.Fixed.Cont (Arity, arity)
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
-- | A matrix with @r@ rows and @c@ columns, stored as a boxed vector of
--   @r@ rows, each of which is a boxed vector of @c@ entries.
--
--   The constructor carries the 'Arity' evidence for both dimensions,
--   so pattern matching on 'Mat' brings @Arity r@ and @Arity c@ back
--   into scope even when a function's signature does not mention them.
data Mat r c a = (Arity r, Arity c) => Mat (Vec r (Vec c a))
-- | Shorthands for square matrices of sizes 1x1 through 5x5; for
--   example @Mat2 Int@ is @Mat N2 N2 Int@.
type Mat1 = Mat N1 N1
type Mat2 = Mat N2 N2
type Mat3 = Mat N3 N3
type Mat4 = Mat N4 N4
type Mat5 = Mat N5 N5
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  --   equal, where rows are compared entry by entry.
  --
  --   Examples:
  --
  --   >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  --   >>> m1 == m2
  --   True
  --   >>> m1 == m3
  --   False
  --
  (Mat xs) == (Mat ys) = V.and (V.zipWith rows_equal xs ys)
    where
      -- One row against another, column by column.
      rows_equal r1 r2 = V.and (V.zipWith (==) r1 r2)
90 instance (Show a) => Show (Mat m n a) where
91 -- | Display matrices and vectors as ordinary tuples. This is poor
92 -- practice, but these results are primarily displayed
93 -- interactively and convenience trumps correctness (said the guy
94 -- who insists his vector lengths be statically checked at
99 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
104 "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
106 row_strings = V.map show_vector rows
108 "(" ++ (intercalate "," element_strings) ++ ")"
111 element_strings = P.map show v1l
-- | Flatten a matrix into a list of its rows; each row is a list of
--   its entries in column order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Build a matrix from a list of rows; each inner list becomes one
--   row. The dimensions come from the result type. List lengths are
--   not checked here; that is left to 'V.fromList'.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
-- | Unsafe indexing: the entry at row @i@, column @j@ (both
--   zero-based). Indices are not bounds-checked here.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
matrix !!! (i, j) = selected_row ! j
  where
    selected_row = row matrix i
-- | Safe indexing: the entry at row @i@, column @j@ (both zero-based),
--   or 'Nothing' when either index is out of range.
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid row indices are 0 .. length-1, so i == length is already
  -- out of range.
  | i >= V.length rows = Nothing
  -- Bound the column against row i, the row we actually index into.
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j
-- | How many rows the matrix has. The count is read off the type-level
--   dimension @m@, so the argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = \_ -> arity (undefined :: m)
-- | How many columns the matrix has, i.e. the length of each of its
--   rows. Like 'nrows', this comes from the type-level dimension @n@
--   and ignores the argument. The technique is borrowed from
--   Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = \_ -> arity (undefined :: n)
-- | The @i@th row (zero-based) of a matrix. Unsafe: @i@ is not
--   bounds-checked.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) = (rows !)
152 -- | Return the @j@th column of @m@. Unsafe.
153 column :: Mat m n a -> Int -> (Vec m a)
154 column (Mat rows) j =
155 V.map (element j) rows
-- | Transpose @m@, swapping its rows and its columns: the (i,j) entry
--   of the result is the (j,i) entry of @m@.
--
--   The result is built directly with 'construct', so there is no
--   intermediate list of columns and no 'V.fromList' round-trip.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
180 -- | Is @m@ symmetric?
184 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
188 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
192 symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
-- | Build a matrix entry by entry from @lambda@. The entry in row @i@,
--   column @j@ (both zero-based) of the result is @lambda i j@. The
--   dimensions come from the result type.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate build_row)
  where
    -- Row i is generated from the partially applied function lambda i.
    build_row :: Int -> Vec n a
    build_row = generate . lambda
-- | The identity matrix of whatever size the result type requests:
--   ones on the diagonal, zeros everywhere else.
--
--   Examples:
--
--   >>> identity_matrix :: Mat3 Int
--   ((1,0,0),(0,1,0),(0,0,1))
--   >>> identity_matrix :: Mat3 Double
--   ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct kronecker
  where
    -- Kronecker delta: one on the diagonal, zero off it.
    kronecker i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
229 -- | Given a positive-definite matrix @m@, computes the
230 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
231 -- values on the diagonal of @r@ positive.
235 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
237 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
238 -- >>> (transpose (cholesky m1)) * (cholesky m1)
239 -- ((20.000000000000004,-1.0),(-1.0,20.0))
241 cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
242 => (Mat m n a) -> (Mat m n a)
243 cholesky m = construct r
246 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
248 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
252 -- | Returns True if the given matrix is upper-triangular, and False
257 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
258 -- >>> is_upper_triangular m
261 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
262 -- >>> is_upper_triangular m
265 is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
267 is_upper_triangular m =
270 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
272 test :: Int -> Int -> Bool
275 | otherwise = m !!! (i,j) == 0
278 -- | Returns True if the given matrix is lower-triangular, and False
283 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
284 -- >>> is_lower_triangular m
287 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
288 -- >>> is_lower_triangular m
291 is_lower_triangular :: (Eq a,
297 is_lower_triangular = is_upper_triangular . transpose
300 -- | Returns True if the given matrix is triangular, and False
305 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
306 -- >>> is_triangular m
309 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
310 -- >>> is_triangular m
313 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
314 -- >>> is_triangular m
317 is_triangular :: (Eq a,
323 is_triangular m = is_upper_triangular m || is_lower_triangular m
326 -- | Return the (i,j)th minor of m.
330 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
331 -- >>> minor m 0 0 :: Mat2 Int
333 -- >>> minor m 1 1 :: Mat2 Int
344 minor (Mat rows) i j = m
346 rows' = delete rows i
347 m = Mat $ V.map ((flip delete) j) rows'
-- | Containers of type @t a@ that have a determinant, which is itself
--   a value of the element type @a@.
class (Eq a, Ring.C a) => Determined t a where
  determinant :: t a -> a
-- | The determinant of a 1x1 matrix is its single entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
359 Determined (Mat (S n) (S n)) a)
360 => Determined (Mat (S (S n)) (S (S n))) a where
361 -- | The recursive definition with a special-case for triangular matrices.
365 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
370 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
371 | otherwise = determinant_recursive
375 det_minor i j = determinant (minor m i j)
377 determinant_recursive =
378 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
379 | j <- [0..(ncols m)-1] ]
383 -- | Matrix multiplication.
387 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
388 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
393 (*) :: (Ring.C a, Arity m, Arity n, Arity p)
397 (*) m1 m2 = construct lambda
400 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Entrywise addition and subtraction; 'zero' is the matrix whose
--   entries are all @fromInteger 0@.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  (Mat xss) + (Mat yss) = Mat (V.zipWith (V.zipWith (+)) xss yss)

  (Mat xss) - (Mat yss) = Mat (V.zipWith (V.zipWith (-)) xss yss)

  zero = Mat (V.replicate (V.replicate (fromInteger 0)))
415 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
416 -- The first * is ring multiplication, the second is matrix
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- Scale by a value of the matrix's own element type: each entry is
  -- multiplied by x, with the entry on the left and x on the right.
  x *> (Mat rows) = Mat (V.map scale_row rows)
    where
      scale_row = V.map (NP.* x)
427 instance (Algebraic.C a,
430 => Normed (Mat (S m) N1 a) where
431 -- | Generic p-norms for vectors in R^n that are represented as nx1
436 -- >>> let v1 = vec2d (3,4)
442 norm_p p (Mat rows) =
443 (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
446 xs = concat $ V.toList $ V.map V.toList rows
448 -- | The infinity norm.
452 -- >>> let v1 = vec3d (1,5,2)
456 norm_infty (Mat rows) =
457 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
-- | The Frobenius norm of a matrix: the square root of the sum of the
--   squares of all its entries. This is the Euclidean norm of the
--   matrix viewed as one long vector, so the order of the entries does
--   not matter.
--
--   Examples:
--
--   >>> let m = fromList [[1, 2, 3],[4,5,6],[7,8,9]] :: Mat3 Double
--   >>> frobenius_norm m == sqrt 285
--   True
--
--   >>> let m = fromList [[1, -1, 1],[-1,1,-1],[1,-1,1]] :: Mat3 Double
--   >>> frobenius_norm m == 3
--   True
--
frobenius_norm :: (Algebraic.C a, Ring.C a) => Mat m n a -> a
frobenius_norm (Mat rows) = sqrt (element_sum row_totals)
  where
    -- One total per row: the sum of that row's squared entries.
    row_totals = V.map squared_row_sum rows
    squared_row_sum r = element_sum (V.map (^2) r)
482 -- Vector helpers. We want it to be easy to create low-dimension
483 -- column vectors, which are nx1 matrices.
-- | Build a 1-dimensional column vector (a 1x1 matrix) from its single
--   entry.
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D column vectors (2x1 matrices): each
--   component of the pair becomes a one-entry row.
--
--   Examples:
--
--   >>> import Roots.Simple
--   >>> let fst m = m !!! (0,0)
--   >>> let snd m = m !!! (1,0)
--   >>> let h = 0.5 :: Double
--   >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
--   >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
--   >>> let g u = vec2d ((g1 u), (g2 u))
--   >>> let u0 = vec2d (1.0, 1.0)
--   >>> let eps = 1/(10^9)
--   >>> fixed_point g eps u0
--   ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Build a 3D column vector (a 3x1 matrix) from a triple.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Build a 4D column vector (a 4x1 matrix) from a 4-tuple.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Build a 5D column vector (a 5x1 matrix) from a 5-tuple.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
-- | Wrap a single value as a 1x1 matrix. This module's '*' is matrix
--   multiplication, so a plain value must be lifted this way before it
--   can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
521 dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
525 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
528 -- | The angle between @v1@ and @v2@ in Euclidean space.
532 -- >>> let v1 = vec2d (1.0, 0.0)
533 -- >>> let v2 = vec2d (0.0, 1.0)
534 -- >>> angle v1 v2 == pi/2.0
537 angle :: (Transcendental.C a,
549 theta = (recip norms) NP.* (v1 `dot` v2)
550 norms = (norm v1) NP.* (norm v2)
554 -- | Given a square @matrix@, return a new matrix of the same size
555 -- containing only the on-diagonal entries of @matrix@. The
556 -- off-diagonal entries are set to zero.
560 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
561 -- >>> diagonal_part m
562 -- ((1,0,0),(0,5,0),(0,0,9))
564 diagonal_part :: (Arity m, Ring.C a)
567 diagonal_part matrix =
570 lambda i j = if i == j then matrix !!! (i,j) else 0
573 -- | Given a square @matrix@, return a new matrix of the same size
574 -- containing only the on-diagonal and below-diagonal entries of
575 -- @matrix@. The above-diagonal entries are set to zero.
579 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
581 -- ((1,0,0),(4,5,0),(7,8,9))
583 lt_part :: (Arity m, Ring.C a)
589 lambda i j = if i >= j then matrix !!! (i,j) else 0
592 -- | Given a square @matrix@, return a new matrix of the same size
593 -- containing only the below-diagonal entries of @matrix@. The on-
594 -- and above-diagonal entries are set to zero.
598 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
599 -- >>> lt_part_strict m
600 -- ((0,0,0),(4,0,0),(7,8,0))
602 lt_part_strict :: (Arity m, Ring.C a)
605 lt_part_strict matrix =
608 lambda i j = if i > j then matrix !!! (i,j) else 0
611 -- | Given a square @matrix@, return a new matrix of the same size
612 -- containing only the on-diagonal and above-diagonal entries of
613 -- @matrix@. The below-diagonal entries are set to zero.
617 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
619 -- ((1,2,3),(0,5,6),(0,0,9))
621 ut_part :: (Arity m, Ring.C a)
624 ut_part = transpose . lt_part . transpose
627 -- | Given a square @matrix@, return a new matrix of the same size
628 -- containing only the above-diagonal entries of @matrix@. The on-
629 -- and below-diagonal entries are set to zero.
633 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
634 -- >>> ut_part_strict m
635 -- ((0,2,3),(0,0,6),(0,0,0))
637 ut_part_strict :: (Arity m, Ring.C a)
640 ut_part_strict = transpose . lt_part_strict . transpose