1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
12 -- generality and failed.
17 import Data.List (intercalate)
19 import Data.Vector.Fixed (
35 import qualified Data.Vector.Fixed as V (
46 import Data.Vector.Fixed.Boxed (Vec)
47 import Data.Vector.Fixed.Cont (Arity, arity)
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
-- | An @m@-by-@n@ matrix, stored as a boxed vector of @m@ rows, each
--   of which is a boxed vector of @n@ entries. The 'Arity'
--   constraints are packed into the constructor, so pattern matching
--   on 'Mat' makes them available again.
data Mat m n a
  = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Shorthand for square matrices of small, fixed dimension.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
instance (Eq a) => Eq (Mat m n a) where
  -- | Entry-wise equality, checked one row at a time.
  --
  --   Examples:
  --
  --   >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  --   >>> m1 == m2
  --   True
  --   >>> m1 == m3
  --   False
  --
  (Mat xss) == (Mat yss) = V.and (V.zipWith same_row xss yss)
    where
      -- Two rows agree when all of their paired-up entries agree.
      same_row xs ys = V.and (V.zipWith (==) xs ys)
90 instance (Show a) => Show (Mat m n a) where
91 -- | Display matrices and vectors as ordinary tuples. This is poor
92 -- practice, but these results are primarily displayed
93 -- interactively and convenience trumps correctness (said the guy
94 -- who insists his vector lengths be statically checked at
99 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
104 "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
106 row_strings = V.map show_vector rows
108 "(" ++ (intercalate "," element_strings) ++ ")"
111 element_strings = P.map show v1l
-- | Flatten a matrix into a list of rows, each row being the list of
--   its entries.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Build a matrix from a list of rows, each row given as a list of
--   entries. Both levels are converted with @V.fromList@, so a list
--   whose size does not match the matrix type is handled however
--   @V.fromList@ handles it.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
-- | Unsafe indexing: fetch the entry in row @i@, column @j@. No
--   bounds checks are made here; the indices are handed to 'row'
--   and '!' as they are.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
-- | Safe indexing. Returns @Just@ the entry in row @i@, column @j@
--   (both zero-based), or @Nothing@ if either index falls outside
--   the matrix.
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
    -- Valid row indices are 0 .. length-1, so the length itself is
    -- already out of range.
  | i >= V.length rows = Nothing
    -- The row index has been validated above, so it is safe to fetch
    -- row @i@ (not @j@) and check the column index against it.
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j
-- | The number of rows in the matrix. The answer is obtained from
--   the type-level row count @m@ via 'arity'; the matrix argument
--   itself is ignored.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows _ = arity row_count
  where
    -- A stand-in used only to tell 'arity' which type we mean.
    row_count :: m
    row_count = undefined
-- | The number of columns in the matrix. The answer is obtained
--   from the type-level column count @n@ via 'arity'; the matrix
--   argument itself is ignored. Implementation stolen from
--   Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols _ = arity column_count
  where
    -- A stand-in used only to tell 'arity' which type we mean.
    column_count :: n
    column_count = undefined
-- | Return the @i@th row of @m@. Unsafe: the index is passed
--   straight to '!' without any bounds check.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat vs) idx = (!) vs idx
-- | Return the @j@th column of @m@, built by taking entry @j@ of
--   every row. Unsafe: the index is passed straight to '!' without
--   any bounds check.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map (! j) rows
-- | Transpose @m@; switch its columns and its rows. This is a dirty
--   implementation: each column of @m@ is collected into a list and
--   the list is turned back into a vector. It would be a little
--   cleaner to use imap, but it doesn't seem to work.
--
--   TODO: Don't cheat with fromList.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = Mat (V.fromList (map (column m) [0 .. ncols m - 1]))
180 -- | Is @m@ symmetric?
184 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
188 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
192 symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
-- | Build a matrix entry by entry from the function @lambda@, which
--   receives the (zero-based) row index @i@ and column index @j@.
--   The i,j entry of the result is @lambda i j@.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate build_row)
  where
    -- Row @i@ is generated from @lambda@ with its first argument
    -- fixed to @i@.
    build_row :: Int -> Vec n a
    build_row = generate . lambda
-- | The identity matrix of whatever (square) size the result type
--   asks for: ones on the diagonal, zeros everywhere else.
--
--   Examples:
--
--   >>> identity_matrix :: Mat3 Int
--   ((1,0,0),(0,1,0),(0,0,1))
--   >>> identity_matrix :: Mat3 Double
--   ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct kronecker
  where
    -- One on the diagonal, zero off it.
    kronecker i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
229 -- | Given a positive-definite matrix @m@, computes the
230 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
231 -- values on the diagonal of @r@ positive.
235 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
237 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
238 -- >>> (transpose (cholesky m1)) * (cholesky m1)
239 -- ((20.000000000000004,-1.0),(-1.0,20.0))
241 cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
242 => (Mat m n a) -> (Mat m n a)
243 cholesky m = construct r
246 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
248 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
252 -- | Returns True if the given matrix is upper-triangular, and False
257 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
258 -- >>> is_upper_triangular m
261 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
262 -- >>> is_upper_triangular m
265 is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
267 is_upper_triangular m =
270 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
272 test :: Int -> Int -> Bool
275 | otherwise = m !!! (i,j) == 0
278 -- | Returns True if the given matrix is lower-triangular, and False
283 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
284 -- >>> is_lower_triangular m
287 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
288 -- >>> is_lower_triangular m
291 is_lower_triangular :: (Eq a,
297 is_lower_triangular = is_upper_triangular . transpose
300 -- | Returns True if the given matrix is triangular, and False
305 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
306 -- >>> is_triangular m
309 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
310 -- >>> is_triangular m
313 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
314 -- >>> is_triangular m
317 is_triangular :: (Eq a,
323 is_triangular m = is_upper_triangular m || is_lower_triangular m
326 -- | Return the (i,j)th minor of m.
330 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
331 -- >>> minor m 0 0 :: Mat2 Int
333 -- >>> minor m 1 1 :: Mat2 Int
344 minor (Mat rows) i j = m
346 rows' = delete rows i
347 m = Mat $ V.map ((flip delete) j) rows'
-- | Containers @p@ of ring elements @a@ for which a determinant can
--   be computed.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given @p a@.
  determinant :: p a -> a
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  -- The determinant of a 1x1 matrix is its only entry: the head of
  -- the first row.
  determinant (Mat rows) = V.head (V.head rows)
359 Determined (Mat (S n) (S n)) a)
360 => Determined (Mat (S (S n)) (S (S n))) a where
361 -- | The recursive definition with a special-case for triangular matrices.
365 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
370 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
371 | otherwise = determinant_recursive
375 det_minor i j = determinant (minor m i j)
377 determinant_recursive =
378 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
379 | j <- [0..(ncols m)-1] ]
383 -- | Matrix multiplication.
387 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
388 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
393 (*) :: (Ring.C a, Arity m, Arity n, Arity p)
397 (*) m1 m2 = construct lambda
400 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  -- The additive identity: every entry is @fromInteger 0@, which is
  -- why the entries need to form a ring here.
  zero = Mat (V.replicate (V.replicate (fromInteger 0)))

  -- Entry-wise sum.
  (Mat xss) + (Mat yss) = Mat (V.zipWith (V.zipWith (+)) xss yss)

  -- Entry-wise difference.
  (Mat xss) - (Mat yss) = Mat (V.zipWith (V.zipWith (-)) xss yss)
415 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
416 -- The first * is ring multiplication, the second is matrix
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- Scaling a matrix by a scalar of the same type as its elements:
  -- every entry is multiplied (on the right) by the scalar.
  x *> (Mat rows) = Mat (V.map (V.map (\entry -> entry NP.* x)) rows)
427 instance (Algebraic.C a,
430 => Normed (Mat (S m) N1 a) where
431 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
432 -- all matrices as big vectors.
436 -- >>> let v1 = vec2d (3,4)
442 norm_p p (Mat rows) =
443 (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
446 xs = concat $ V.toList $ V.map V.toList rows
448 -- | The infinity norm.
452 -- >>> let v1 = vec3d (1,5,2)
456 norm_infty (Mat rows) =
457 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
463 -- Vector helpers. We want it to be easy to create low-dimension
464 -- column vectors, which are nx1 matrices.
-- | Convenient constructor for 1D column vectors; 'vec2d' through
--   'vec5d' do the same for higher dimensions. The example below
--   uses 'vec2d'.
--
--   Examples:
--
--   >>> import Roots.Simple
--   >>> let fst m = m !!! (0,0)
--   >>> let snd m = m !!! (1,0)
--   >>> let h = 0.5 :: Double
--   >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
--   >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
--   >>> let g u = vec2d ((g1 u), (g2 u))
--   >>> let u0 = vec2d (1.0, 1.0)
--   >>> let eps = 1/(10^9)
--   >>> fixed_point g eps u0
--   ((1.0728549599342185),(1.0820591495686167))
--
vec1d :: (a) -> Mat N1 N1 a
vec1d x = Mat (V.map mk1 (mk1 x))
-- | Build a 2x1 column vector from a pair; each component becomes a
--   one-entry row.
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (V.map mk1 (mk2 x y))
-- | Build a 3x1 column vector from a triple; each component becomes
--   a one-entry row.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (V.map mk1 (mk3 x y z))
-- | Build a 4x1 column vector from a 4-tuple; each component becomes
--   a one-entry row.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (V.map mk1 (mk4 w x y z))
-- | Build a 5x1 column vector from a 5-tuple; each component becomes
--   a one-entry row.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (V.map mk1 (mk5 v w x y z))
-- | Wrap a single value in a 1x1 matrix. Since we commandeered
--   multiplication for matrices, a plain value has to be turned
--   into a 1x1 matrix before it can be multiplied with one.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
502 dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
506 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
509 -- | The angle between @v1@ and @v2@ in Euclidean space.
513 -- >>> let v1 = vec2d (1.0, 0.0)
514 -- >>> let v2 = vec2d (0.0, 1.0)
515 -- >>> angle v1 v2 == pi/2.0
518 angle :: (Transcendental.C a,
530 theta = (recip norms) NP.* (v1 `dot` v2)
531 norms = (norm v1) NP.* (norm v2)
535 -- | Given a square @matrix@, return a new matrix of the same size
536 -- containing only the on-diagonal entries of @matrix@. The
537 -- off-diagonal entries are set to zero.
541 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
542 -- >>> diagonal_part m
543 -- ((1,0,0),(0,5,0),(0,0,9))
545 diagonal_part :: (Arity m, Ring.C a)
548 diagonal_part matrix =
551 lambda i j = if i == j then matrix !!! (i,j) else 0
554 -- | Given a square @matrix@, return a new matrix of the same size
555 -- containing only the on-diagonal and below-diagonal entries of
556 -- @matrix@. The above-diagonal entries are set to zero.
560 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
562 -- ((1,0,0),(4,5,0),(7,8,9))
564 lt_part :: (Arity m, Ring.C a)
570 lambda i j = if i >= j then matrix !!! (i,j) else 0
573 -- | Given a square @matrix@, return a new matrix of the same size
574 -- containing only the below-diagonal entries of @matrix@. The on-
575 -- and above-diagonal entries are set to zero.
579 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
580 -- >>> lt_part_strict m
581 -- ((0,0,0),(4,0,0),(7,8,0))
583 lt_part_strict :: (Arity m, Ring.C a)
586 lt_part_strict matrix =
589 lambda i j = if i > j then matrix !!! (i,j) else 0
592 -- | Given a square @matrix@, return a new matrix of the same size
593 -- containing only the on-diagonal and above-diagonal entries of
594 -- @matrix@. The below-diagonal entries are set to zero.
598 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
600 -- ((1,2,3),(0,5,6),(0,0,9))
602 ut_part :: (Arity m, Ring.C a)
605 ut_part = transpose . lt_part . transpose
608 -- | Given a square @matrix@, return a new matrix of the same size
609 -- containing only the above-diagonal entries of @matrix@. The on-
610 -- and below-diagonal entries are set to zero.
614 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
615 -- >>> ut_part_strict m
616 -- ((0,2,3),(0,0,6),(0,0,0))
618 ut_part_strict :: (Arity m, Ring.C a)
621 ut_part_strict = transpose . lt_part_strict . transpose