1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
9 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
10 -- assume that the underlying representation is
11 -- 'Data.Vector.Fixed.Boxed.Vec' for simplicity; a version that was
12 -- generic over the vector representation was attempted, but it did not work.
17 import Data.List (intercalate)
19 import Data.Vector.Fixed (
35 import qualified Data.Vector.Fixed as V (
46 import Data.Vector.Fixed.Boxed (Vec)
47 import Data.Vector.Fixed.Cont (Arity, arity)
51 import NumericPrelude hiding ((*), abs)
52 import qualified NumericPrelude as NP ((*))
53 import qualified Algebra.Algebraic as Algebraic
54 import Algebra.Algebraic (root)
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Ring as Ring
57 import qualified Algebra.Module as Module
58 import qualified Algebra.RealRing as RealRing
59 import qualified Algebra.ToRational as ToRational
60 import qualified Algebra.Transcendental as Transcendental
61 import qualified Prelude as P
-- | An @m@-by-@n@ matrix with entries of type @a@: a boxed vector of
--   @m@ rows, each of which is a boxed vector of @n@ entries. The
--   'Arity' evidence for both dimensions is stored in the constructor,
--   so pattern matching on 'Mat' brings it back into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | A 1x1 matrix.
type Mat1 a = Mat N1 N1 a
-- | A 2x2 matrix.
type Mat2 a = Mat N2 N2 a
-- | A 3x3 matrix.
type Mat3 a = Mat N3 N3 a
-- | A 4x4 matrix.
type Mat4 a = Mat N4 N4 a
-- | A 5x5 matrix.
type Mat5 a = Mat N5 N5 a
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when every pair of corresponding rows is
  --   equal; two rows are equal when every pair of corresponding
  --   entries is equal.
  --
  --   Examples:
  --
  --   >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  --   >>> m1 == m2
  --   True
  --   >>> m1 == m3
  --   False
  --
  (==) (Mat xss) (Mat yss) = V.and (V.zipWith rows_match xss yss)
    where
      -- Entrywise comparison of a single pair of rows.
      rows_match xs ys = V.and $ V.zipWith (==) xs ys
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples, so a 2x2 matrix prints
  --   as @((1,2),(3,4))@ and a column vector as @((1),(2))@. The
  --   output does not identify the value as a 'Mat'; it favours
  --   readability in interactive sessions over faithfulness.
  --
  --   Examples:
  --
  --   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> m
  --   ((1,2),(3,4))
  --
  show (Mat rows) = tupled (V.toList (V.map show_row rows))
    where
      -- Join already-rendered pieces with commas, inside parentheses.
      tupled pieces = "(" ++ intercalate "," pieces ++ ")"
      show_row r = tupled (P.map show (V.toList r))
-- | Flatten a matrix into a list of its rows, each row being the list
--   of its entries in column order.
--
--   Examples:
--
--   >>> toList (fromList [[1,2],[3,4]] :: Mat2 Int)
--   [[1,2],[3,4]]
--
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Build a matrix from a list of rows. The dimensions come from the
--   result type. The outer list should have @m@ elements and every
--   inner list @n@ elements.
--
--   NOTE(review): this relies on the partial 'V.fromList'; confirm its
--   behaviour for lists that are too short or too long.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
-- | Unsafe indexing: @m !!! (i,j)@ is the entry in row @i@, column @j@
--   of @m@, both zero-based. No bounds check is performed here.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
-- | Safe indexing: @m !!? (i,j)@ is 'Just' the entry in row @i@,
--   column @j@ (both zero-based), or 'Nothing' when either index is
--   negative or past the last row/column.
--
--   Examples:
--
--   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
--   >>> m !!? (1,0)
--   Just 3
--   >>> m !!? (2,0)
--   Nothing
--   >>> m !!? (0,2)
--   Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, so length itself is out of range.
  | i >= V.length rows = Nothing
  -- Check the column index against row i (the row we will read from).
  | j >= V.length (row m i) = Nothing
  | otherwise = Just (row m i ! j)
-- | The number of rows in the matrix. It is read off the type-level
--   dimension @m@, so the argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
-- | The number of columns in the matrix (every row has this length).
--   It is read off the type-level dimension @n@, in the same way
--   Data.Vector.Fixed.length does, so the argument is never inspected.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
-- | The @i@th (zero-based) row of a matrix. Unsafe: the index is not
--   bounds-checked here.
row :: Mat m n a -> Int -> (Vec n a)
row m i = case m of
  Mat rows -> rows ! i
152 -- | Return the @j@th column of @m@. Unsafe.
-- The column is built by mapping @element j@ over every row, so the
-- result holds one entry per row (a vector of length @m@).
-- NOTE(review): @element@ is not defined in the visible lines;
-- presumably it extracts entry @j@ of a row without a bounds check —
-- confirm where it is defined.
153 column :: Mat m n a -> Int -> (Vec m a)
154 column (Mat rows) j =
155 V.map (element j) rows
-- | Transpose @m@; switch its columns and its rows. Entry @(i,j)@ of
--   the result is entry @(j,i)@ of @m@.
--
--   The result is built directly with 'construct', so no intermediate
--   list (and no length-checking 'V.fromList') is involved.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
180 -- | Is @m@ symmetric?
-- A matrix is symmetric when it equals its own transpose. The
-- signature below only accepts square matrices (@Mat m m a@).
-- NOTE(review): the body is not visible in this excerpt; confirm it
-- compares @m@ against @transpose m@.
184 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
188 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
192 symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
-- | Build a matrix entry by entry from a function @lambda@. The entry
--   in row @i@, column @j@ (both zero-based) of the result is
--   @lambda i j@. The dimensions are taken from the result type.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate (\i -> generate (lambda i)))
-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   The entries of @r@ are
--
--   * @r(i,i) = sqrt (m(i,i) - sum_{k<i} r(k,i)^2)@,
--   * @r(i,j) = (m(i,j) - sum_{k<i} r(k,i)*r(k,j)) / r(i,i)@ for @i < j@,
--   * @r(i,j) = 0@ for @i > j@.
--
--   Each entry depends only on entries in earlier rows, or on the
--   diagonal entry of its own row. @r@ is therefore defined in terms of
--   its own lazily evaluated entries, so every entry is computed once
--   rather than being recomputed by a naive (exponential) recursion.
--   Only the diagonal and upper triangle of @m@ are read.
--
--   Examples:
--
--   >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
--   >>> cholesky m1
--   ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
--   >>> (transpose (cholesky m1)) * (cholesky m1)
--   ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = result
  where
    -- The factor itself, built lazily; 'entry' reads earlier entries
    -- back out of it, so each entry is evaluated at most once.
    result :: Mat m n a
    result = construct entry

    r :: Int -> Int -> a
    r i j = result !!! (i,j)

    entry :: Int -> Int -> a
    entry i j
      | i == j = sqrt (m !!! (i,j) - sum [ (r k i)^2 | k <- [0..i-1] ])
      | i < j = (m !!! (i,j) - sum [ (r k i) NP.* (r k j) | k <- [0..i-1] ])
                / (r i i)
      | otherwise = 0
240 -- | Returns True if the given matrix is upper-triangular, and False
-- Upper-triangular means every entry strictly below the main diagonal
-- is zero.
245 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
246 -- >>> is_upper_triangular m
249 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
250 -- >>> is_upper_triangular m
253 is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
255 is_upper_triangular m =
-- One Bool per entry: the outer list ranges over columns j, the inner
-- lists over rows i.
258 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
260 test :: Int -> Int -> Bool
-- NOTE(review): the earlier guards of test are not visible in this
-- excerpt; only this fallback, which requires entry (i,j) to be zero,
-- is shown. Confirm the hidden guards accept every i <= j.
263 | otherwise = m !!! (i,j) == 0
266 -- | Returns True if the given matrix is lower-triangular, and False
271 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
272 -- >>> is_lower_triangular m
275 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
276 -- >>> is_lower_triangular m
279 is_lower_triangular :: (Eq a,
-- A matrix is lower-triangular exactly when its transpose is
-- upper-triangular, so this reuses is_upper_triangular.
285 is_lower_triangular = is_upper_triangular . transpose
288 -- | Returns True if the given matrix is triangular, and False
-- "Triangular" here means upper- or lower-triangular; a diagonal
-- matrix satisfies both tests.
293 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
294 -- >>> is_triangular m
297 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
298 -- >>> is_triangular m
301 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
302 -- >>> is_triangular m
305 is_triangular :: (Eq a,
311 is_triangular m = is_upper_triangular m || is_lower_triangular m
314 -- | Return the (i,j)th minor of m.
-- Here the "minor" is the submatrix itself, not its determinant: row i
-- is removed from the list of rows, then entry j is removed from every
-- remaining row.
-- NOTE(review): the type signature and the definition of delete are
-- not visible in this excerpt; presumably delete v k drops index k
-- (zero-based) — confirm.
318 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
319 -- >>> minor m 0 0 :: Mat2 Int
321 -- >>> minor m 1 1 :: Mat2 Int
332 minor (Mat rows) i j = m
334 rows' = delete rows i
335 m = Mat $ V.map ((flip delete) j) rows'
-- | Containers @p@ of ring elements @a@ that have a determinant. The
--   class is indexed by the container type so that separate instances
--   can be given per matrix size.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given value.
  determinant :: p a -> a
-- | The determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
347 Determined (Mat (S n) (S n)) a)
348 => Determined (Mat (S (S n)) (S (S n))) a where
349 -- | The recursive definition with a special-case for triangular matrices.
353 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- Shortcut: the determinant of a triangular matrix is the product of
-- its diagonal entries.
358 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
359 | otherwise = determinant_recursive
-- Each minor is one size smaller, so this recurses into the
-- Determined (Mat (S n) (S n)) instance required by the context.
363 det_minor i j = determinant (minor m i j)
-- Cofactor (Laplace) expansion along row 0: the sum over columns j of
-- (-1)^j * entry (0,j) * det(minor 0 j). For non-triangular input,
-- the number of recursive calls grows roughly factorially with size.
-- NOTE(review): m' is defined on a line not visible here; presumably
-- m' i j = m !!! (i,j) — confirm.
365 determinant_recursive =
366 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
367 | j <- [0..(ncols m)-1] ]
371 -- | Matrix multiplication.
-- This replaces NumericPrelude's (*), which the module hides on import;
-- scalar multiplication is written NP.* throughout.
375 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
376 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- NOTE(review): the rest of the type signature is not visible;
-- presumably Mat m n a -> Mat n p a -> Mat m p a — confirm.
381 (*) :: (Ring.C a, Arity m, Arity n, Arity p)
385 (*) m1 m2 = construct lambda
-- Entry (i,j) of the product: the sum over k of m1(i,k) * m2(k,j).
388 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Matrices of a fixed size form an additive group under entrywise
--   addition and subtraction; 'zero' is the all-zeros matrix.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  Mat xss + Mat yss = Mat (V.zipWith (V.zipWith (+)) xss yss)

  Mat xss - Mat yss = Mat (V.zipWith (V.zipWith (-)) xss yss)

  zero = Mat $ V.replicate (V.replicate 0)
-- Square matrices only: the m ~ n constraint restricts this instance
-- to Mat m m a.
-- NOTE(review): the method definitions are not visible in this excerpt.
403 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
404 -- The first * is ring multiplication, the second is matrix
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- | Scale a matrix by a scalar of the same type as its entries.
  --   The scalar multiplies each entry from the left (@x * e@). The
  --   module law @(a * b) *> v == a *> (b *> v)@ requires this
  --   whenever multiplication of the entry type is not commutative.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
415 instance (Algebraic.C a,
418 => Normed (Mat (S m) N1 a) where
419 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
420 -- all matrices as big vectors.
-- NOTE(review): despite the comment above, this instance head only
-- covers column vectors (Mat (S m) N1 a).
424 -- >>> let v1 = vec2d (3,4)
-- Computes root p' of the sum of x^p' over every entry x, converting
-- each entry through toRational / fromRational'. p' is defined on a
-- line not visible here.
-- NOTE(review): no absolute value is taken in the visible lines, so
-- for odd p' negative entries shrink the sum; the p-norm uses |x|^p.
-- Consider applying abs to each entry before the power.
430 norm_p p (Mat rows) =
431 (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
434 xs = concat $ V.toList $ V.map V.toList rows
436 -- | The infinity norm.
440 -- >>> let v1 = vec3d (1,5,2)
-- Takes the largest entry of each row, then the largest of those.
-- NOTE(review): this is the largest entry, not the largest absolute
-- value; for a vector such as (-5,1,2) it yields 2 where the infinity
-- norm is 5. Consider taking abs of each entry first.
444 norm_infty (Mat rows) =
445 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
451 -- Vector helpers. We want it to be easy to create low-dimension
452 -- column vectors, which are nx1 matrices.
-- | Convenient constructor for 1D column vectors: @vec1d x@ is the 1x1
--   matrix whose only entry is @x@.
--
--   The @vecNd@ family ('vec1d' through 'vec5d') turns a tuple into
--   the corresponding column (nx1) vector. The second example below
--   uses 'vec2d' to run a two-dimensional fixed-point iteration.
--
--   Examples:
--
--   >>> vec1d (3 :: Int)
--   ((3))
--
--   >>> import Roots.Simple
--   >>> let fst m = m !!! (0,0)
--   >>> let snd m = m !!! (1,0)
--   >>> let h = 0.5 :: Double
--   >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
--   >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
--   >>> let g u = vec2d ((g1 u), (g2 u))
--   >>> let u0 = vec2d (1.0, 1.0)
--   >>> let eps = 1/(10^9)
--   >>> fixed_point g eps u0
--   ((1.0728549599342185),(1.0820591495686167))
--
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))
-- | Build a 2D column vector (a 2x1 matrix) from a pair.
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat $ mk2 (mk1 x) (mk1 y)

-- | Build a 3D column vector (a 3x1 matrix) from a triple.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat $ mk3 (mk1 x) (mk1 y) (mk1 z)

-- | Build a 4D column vector (a 4x1 matrix) from a 4-tuple.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat $ mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z)

-- | Build a 5D column vector (a 5x1 matrix) from a 5-tuple.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat $ mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z)
-- | Wrap a single value as a 1x1 matrix. This module's '*' is matrix
--   multiplication, so a plain value has to be turned into a 1x1
--   matrix before it can be multiplied. This builds exactly the same
--   matrix as 'vec1d'.
scalar :: a -> Mat N1 N1 a
scalar x = vec1d x
-- | Inner product of two column vectors, computed as the single entry
-- of the 1x1 product (transpose v1) * v2.
-- NOTE(review): the rest of the signature is not visible; the
-- constraints n ~ N1 and m ~ S t suggest non-empty column vectors
-- Mat m n a -> Mat m n a -> a — confirm.
490 dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
494 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
497 -- | The angle between @v1@ and @v2@ in Euclidean space.
501 -- >>> let v1 = vec2d (1.0, 0.0)
502 -- >>> let v2 = vec2d (0.0, 1.0)
503 -- >>> angle v1 v2 == pi/2.0
506 angle :: (Transcendental.C a,
-- theta is the cosine of the angle: (v1 `dot` v2) / (|v1| * |v2|).
-- NOTE(review): the line turning theta into an angle is not visible;
-- presumably it applies acos — confirm. If either vector has zero
-- norm, recip is taken of zero here.
518 theta = (recip norms) NP.* (v1 `dot` v2)
519 norms = (norm v1) NP.* (norm v2)
523 -- | Given a square @matrix@, return a new matrix of the same size
524 -- containing only the on-diagonal entries of @matrix@. The
525 -- off-diagonal entries are set to zero.
529 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
530 -- >>> diagonal_part m
531 -- ((1,0,0),(0,5,0),(0,0,9))
533 diagonal_part :: (Arity m, Ring.C a)
536 diagonal_part matrix =
-- Entry (i,j) of the result: the original entry when i == j, zero
-- otherwise.
-- NOTE(review): the line building the result from lambda is not
-- visible; presumably it is construct lambda — confirm.
539 lambda i j = if i == j then matrix !!! (i,j) else 0
542 -- | Given a square @matrix@, return a new matrix of the same size
543 -- containing only the on-diagonal and below-diagonal entries of
544 -- @matrix@. The above-diagonal entries are set to zero.
548 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
550 -- ((1,0,0),(4,5,0),(7,8,9))
552 lt_part :: (Arity m, Ring.C a)
-- Keep entries whose row index is at least the column index (on or
-- below the diagonal); zero the rest.
558 lambda i j = if i >= j then matrix !!! (i,j) else 0
561 -- | Given a square @matrix@, return a new matrix of the same size
562 -- containing only the below-diagonal entries of @matrix@. The on-
563 -- and above-diagonal entries are set to zero.
567 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
568 -- >>> lt_part_strict m
569 -- ((0,0,0),(4,0,0),(7,8,0))
571 lt_part_strict :: (Arity m, Ring.C a)
574 lt_part_strict matrix =
-- Keep entries strictly below the diagonal (row index greater than
-- column index); zero the rest.
577 lambda i j = if i > j then matrix !!! (i,j) else 0
580 -- | Given a square @matrix@, return a new matrix of the same size
581 -- containing only the on-diagonal and above-diagonal entries of
582 -- @matrix@. The below-diagonal entries are set to zero.
586 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
588 -- ((1,2,3),(0,5,6),(0,0,9))
590 ut_part :: (Arity m, Ring.C a)
-- The upper triangle of m is the transpose of the lower triangle of
-- (transpose m), so this reuses lt_part.
593 ut_part = transpose . lt_part . transpose
596 -- | Given a square @matrix@, return a new matrix of the same size
597 -- containing only the above-diagonal entries of @matrix@. The on-
598 -- and below-diagonal entries are set to zero.
602 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
603 -- >>> ut_part_strict m
604 -- ((0,2,3),(0,0,6),(0,0,0))
606 ut_part_strict :: (Arity m, Ring.C a)
-- The strict upper triangle of m is the transpose of the strict lower
-- triangle of (transpose m), so this reuses lt_part_strict.
609 ut_part_strict = transpose . lt_part_strict . transpose