1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6 {-# LANGUAGE ScopedTypeVariables #-}
7 {-# LANGUAGE TypeFamilies #-}
8 {-# LANGUAGE RebindableSyntax #-}
10 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
11 -- assume that the underlying representation is
12 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
13 -- generality and failed.
18 import Data.List (intercalate)
20 import Data.Vector.Fixed (
36 import qualified Data.Vector.Fixed as V (
47 import Data.Vector.Fixed.Cont (Arity, arity)
51 import NumericPrelude hiding ( (*), abs )
52 import qualified NumericPrelude as NP ( (*) )
53 import qualified Algebra.Absolute as Absolute ( C )
54 import Algebra.Absolute ( abs )
55 import qualified Algebra.Additive as Additive
56 import qualified Algebra.Algebraic as Algebraic
57 import Algebra.Algebraic (root)
58 import qualified Algebra.Ring as Ring
59 import qualified Algebra.Module as Module
60 import qualified Algebra.RealRing as RealRing
61 import qualified Algebra.ToRational as ToRational
62 import qualified Algebra.Transcendental as Transcendental
63 import qualified Prelude as P
-- | A matrix with @m@ rows and @n@ columns whose entries have type @a@,
--   stored as a boxed vector of @m@ rows, each row being a boxed vector
--   of @n@ entries. The constructor packs the 'Arity' evidence for both
--   dimensions, so pattern matching on 'Mat' brings it back into scope.
data Mat m n a =
  (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | A 1x1 matrix.
type Mat1 a = Mat N1 N1 a

-- | A 2x2 matrix.
type Mat2 a = Mat N2 N2 a

-- | A 3x3 matrix.
type Mat3 a = Mat N3 N3 a

-- | A 4x4 matrix.
type Mat4 a = Mat N4 N4 a

-- | A 5x5 matrix.
type Mat5 a = Mat N5 N5 a
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when each pair of corresponding rows is
  --   equal. Rows are compared entry by entry.
  --
  --   Examples:
  --
  --   >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  --   >>> m1 == m2
  --   True
  --   >>> m1 == m3
  --   False
  --
  Mat xs == Mat ys = V.and (V.zipWith sameRow xs ys)
    where
      -- Compare one pair of rows, one column at a time.
      sameRow r s = V.and (V.zipWith (==) r s)
instance (Show a) => Show (Mat m n a) where
  -- | Render a matrix as a tuple of row tuples, e.g. @((1,2),(3,4))@.
  --   Column vectors (nx1 matrices) come out as @((x),(y))@. This is not
  --   output that 'read' could parse back into a 'Mat'. The results are
  --   mostly viewed interactively, so convenience wins over strict
  --   'Show' conventions.
  --
  --   Examples:
  --
  --   >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  --   >>> m
  --   ((1,2),(3,4))
  --
  show (Mat rows) = parenthesize (V.toList (V.map showRow rows))
    where
      -- One row, as a parenthesized comma-separated list of entries.
      showRow r = parenthesize (P.map show (V.toList r))

      -- Join the pieces with commas and wrap them in parentheses.
      parenthesize parts = "(" ++ intercalate "," parts ++ ")"
-- | Convert a matrix to a nested list: a list of rows, each row being
--   a list of its entries in column order.
--
--   Examples:
--
--   >>> toList (fromList [[1,2],[3,4]] :: Mat2 Int)
--   [[1,2],[3,4]]
--
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Create a matrix from a nested list of rows. The outer list is
--   expected to hold exactly @m@ rows and every inner list exactly @n@
--   entries. NOTE(review): length mismatches are left to
--   'Data.Vector.Fixed.fromList' (presumably a runtime error); confirm.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList = Mat . V.fromList . map V.fromList
-- | Unsafe indexing. @m !!! (i, j)@ is the entry in row @i@, column
--   @j@ (both zero-based). This function does no bounds checking of its
--   own.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
(Mat rows) !!! (i, j) = (rows ! i) ! j
-- | Safe indexing. @m !!? (i, j)@ returns @Just@ the entry in row @i@,
--   column @j@ (both zero-based). It returns 'Nothing' when either index
--   is negative or not smaller than the corresponding dimension.
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) (Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length - 1, hence '>=' and not '>'.
  | i >= V.length rows = Nothing
  | j >= V.length r = Nothing
  | otherwise = Just (r ! j)
  where
    -- The row being indexed: row i, not row j. Evaluated only after
    -- the row-index checks above have passed.
    r = rows ! i
-- | The number of rows in the matrix. The value is read off the
--   type-level dimension @m@; the argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
-- | The number of columns in the matrix. The value is read off the
--   type-level dimension @n@; the argument itself is never inspected.
--   The technique is borrowed from Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
-- | Return row @i@ (zero-based) of a matrix. Unsafe: this function
--   does no bounds checking of its own.
row :: Mat m n a -> Int -> (Vec n a)
row m i = case m of
  Mat rows -> rows ! i
154 -- | Return the @j@th column of @m@. Unsafe.
155 column :: Mat m n a -> Int -> (Vec m a)
156 column (Mat rows) j =
157 V.map (element j) rows
-- | Transpose @m@ by switching its rows and columns: entry @(i, j)@ of
--   the result is entry @(j, i)@ of @m@. The result is built directly
--   with 'construct', so no intermediate list and no length-checking
--   'V.fromList' are involved.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct (\i j -> m !!! (j, i))
182 -- | Is @m@ symmetric?
186 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
190 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
194 symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
-- | Build a matrix from a function of its indices. Entry @(i, j)@ of
--   the result (zero-based) is @lambda i j@.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate (\i -> generate (lambda i)))
-- | The identity matrix of whatever square size the context demands:
--   ones on the diagonal and zeros elsewhere.
--
--   Examples:
--
--   >>> identity_matrix :: Mat3 Int
--   ((1,0,0),(0,1,0),(0,0,1))
--   >>> identity_matrix :: Mat3 Double
--   ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct delta
  where
    -- Kronecker delta: one on the diagonal, zero off it.
    delta i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
231 -- | Given a positive-definite matrix @m@, computes the
232 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
233 -- values on the diagonal of @r@ positive.
237 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
239 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
240 -- >>> (transpose (cholesky m1)) * (cholesky m1)
241 -- ((20.000000000000004,-1.0),(-1.0,20.0))
243 cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
244 => (Mat m n a) -> (Mat m n a)
245 cholesky m = construct r
248 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
250 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
254 -- | Returns True if the given matrix is upper-triangular, and False
255 -- otherwise. The parameter @epsilon@ lets the caller choose a
260 -- >>> let m = fromList [[1,1],[1e-12,1]] :: Mat2 Double
261 -- >>> is_upper_triangular m
263 -- >>> is_upper_triangular' 1e-10 m
268 -- 1. Don't cheat with lists.
270 is_upper_triangular' :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
271 => a -- ^ The tolerance @epsilon@.
274 is_upper_triangular' epsilon m =
277 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
279 test :: Int -> Int -> Bool
282 -- use "less than or equal to" so zero is a valid epsilon
283 | otherwise = abs (m !!! (i,j)) <= epsilon
286 -- | Returns True if the given matrix is upper-triangular, and False
287 -- otherwise. A specialized version of 'is_upper_triangular\'' with
292 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
293 -- >>> is_upper_triangular m
296 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
297 -- >>> is_upper_triangular m
302 -- 1. The Ord constraint is too strong here, Eq would suffice.
304 is_upper_triangular :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
306 is_upper_triangular = is_upper_triangular' 0
309 -- | Returns True if the given matrix is lower-triangular, and False
310 -- otherwise. This is a specialized version of 'is_lower_triangular\''
311 -- with @epsilon = 0@.
315 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
316 -- >>> is_lower_triangular m
319 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
320 -- >>> is_lower_triangular m
323 is_lower_triangular :: (Ord a,
330 is_lower_triangular = is_upper_triangular . transpose
333 -- | Returns True if the given matrix is lower-triangular, and False
334 -- otherwise. The parameter @epsilon@ lets the caller choose a
339 -- >>> let m = fromList [[1,1e-12],[1,1]] :: Mat2 Double
340 -- >>> is_lower_triangular m
342 -- >>> is_lower_triangular' 1e-12 m
345 is_lower_triangular' :: (Ord a,
350 => a -- ^ The tolerance @epsilon@.
353 is_lower_triangular' epsilon = (is_upper_triangular' epsilon) . transpose
356 -- | Returns True if the given matrix is triangular, and False
361 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
362 -- >>> is_triangular m
365 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
366 -- >>> is_triangular m
369 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
370 -- >>> is_triangular m
373 is_triangular :: (Ord a,
380 is_triangular m = is_upper_triangular m || is_lower_triangular m
383 -- | Return the (i,j)th minor of m.
387 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
388 -- >>> minor m 0 0 :: Mat2 Int
390 -- >>> minor m 1 1 :: Mat2 Int
401 minor (Mat rows) i j = m
403 rows' = delete rows i
404 m = Mat $ V.map ((flip delete) j) rows'
-- | Things with a determinant. @p a@ is a container with entries of
--   type @a@; the instances in this module are square 'Mat's.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given value.
  determinant :: p a -> a
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
417 Determined (Mat (S n) (S n)) a)
418 => Determined (Mat (S (S n)) (S (S n))) a where
419 -- | The recursive definition with a special-case for triangular matrices.
423 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
428 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
429 | otherwise = determinant_recursive
433 det_minor i j = determinant (minor m i j)
435 determinant_recursive =
436 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
437 | j <- [0..(ncols m)-1] ]
441 -- | Matrix multiplication.
445 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
446 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
451 (*) :: (Ring.C a, Arity m, Arity n, Arity p)
455 (*) m1 m2 = construct lambda
458 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Matrices of a fixed shape add and subtract entrywise. 'zero' is the
--   matrix whose entries are all @fromInteger 0@.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  Mat xs + Mat ys = Mat (V.zipWith (V.zipWith (+)) xs ys)

  Mat xs - Mat ys = Mat (V.zipWith (V.zipWith (-)) xs ys)

  zero = Mat (V.replicate (V.replicate (fromInteger 0)))
473 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
474 -- The first * is ring multiplication, the second is matrix
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- A matrix can be scaled by a scalar of the same type as its
  -- entries. Each entry is multiplied by the scalar on the right
  -- (entry * x), which matters for non-commutative rings.
  x *> Mat rows = Mat (V.map (V.map (\e -> e NP.* x)) rows)
485 instance (Algebraic.C a,
488 => Normed (Mat (S m) N1 a) where
489 -- | Generic p-norms for vectors in R^n that are represented as nx1
494 -- >>> let v1 = vec2d (3,4)
500 norm_p p (Mat rows) =
501 (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
504 xs = concat $ V.toList $ V.map V.toList rows
506 -- | The infinity norm.
510 -- >>> let v1 = vec3d (1,5,2)
514 norm_infty (Mat rows) =
515 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
-- | Compute the Frobenius norm of a matrix. This treats the matrix as
--   one long vector containing all of its entries (the order does not
--   matter) and takes the square root of the sum of their squares.
--
--   Examples:
--
--   >>> let m = fromList [[1, 2, 3],[4,5,6],[7,8,9]] :: Mat3 Double
--   >>> frobenius_norm m == sqrt 285
--   True
--
--   >>> let m = fromList [[1, -1, 1],[-1,1,-1],[1,-1,1]] :: Mat3 Double
--   >>> frobenius_norm m == 3
--   True
--
frobenius_norm :: (Algebraic.C a, Ring.C a) => Mat m n a -> a
frobenius_norm (Mat rows) =
  sqrt (element_sum (V.map sumOfSquares rows))
  where
    -- Square the entries of one row and add them up.
    sumOfSquares r = element_sum (V.map (^2) r)
540 -- Vector helpers. We want it to be easy to create low-dimension
541 -- column vectors, which are nx1 matrices.
-- | Build a 1-dimensional column vector, i.e. a 1x1 matrix.
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D column vectors (2x1 matrices).
--
--   Examples:
--
--   >>> vec2d (1,2) :: Mat N2 N1 Int
--   ((1),(2))
--
--   A larger example, using @fixed_point@ from "Roots.Simple":
--
--   >>> import Roots.Simple
--   >>> let fst m = m !!! (0,0)
--   >>> let snd m = m !!! (1,0)
--   >>> let h = 0.5 :: Double
--   >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
--   >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
--   >>> let g u = vec2d ((g1 u), (g2 u))
--   >>> let u0 = vec2d (1.0, 1.0)
--   >>> let eps = 1/(10^9)
--   >>> fixed_point g eps u0
--   ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D column vectors (3x1 matrices).
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D column vectors (4x1 matrices).
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D column vectors (5x1 matrices).
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
-- | Wrap a single value in a 1x1 matrix. '*' is taken over by matrix
--   multiplication in this module, so a plain value has to be wrapped
--   this way before it can take part in a product. This is the same
--   construction as 'vec1d'.
scalar :: a -> Mat N1 N1 a
scalar = vec1d
579 dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
583 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
586 -- | The angle between @v1@ and @v2@ in Euclidean space.
590 -- >>> let v1 = vec2d (1.0, 0.0)
591 -- >>> let v2 = vec2d (0.0, 1.0)
592 -- >>> angle v1 v2 == pi/2.0
595 angle :: (Transcendental.C a,
607 theta = (recip norms) NP.* (v1 `dot` v2)
608 norms = (norm v1) NP.* (norm v2)
612 -- | Given a square @matrix@, return a new matrix of the same size
613 -- containing only the on-diagonal entries of @matrix@. The
614 -- off-diagonal entries are set to zero.
618 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
619 -- >>> diagonal_part m
620 -- ((1,0,0),(0,5,0),(0,0,9))
622 diagonal_part :: (Arity m, Ring.C a)
625 diagonal_part matrix =
628 lambda i j = if i == j then matrix !!! (i,j) else 0
631 -- | Given a square @matrix@, return a new matrix of the same size
632 -- containing only the on-diagonal and below-diagonal entries of
633 -- @matrix@. The above-diagonal entries are set to zero.
637 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
639 -- ((1,0,0),(4,5,0),(7,8,9))
641 lt_part :: (Arity m, Ring.C a)
647 lambda i j = if i >= j then matrix !!! (i,j) else 0
650 -- | Given a square @matrix@, return a new matrix of the same size
651 -- containing only the below-diagonal entries of @matrix@. The on-
652 -- and above-diagonal entries are set to zero.
656 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
657 -- >>> lt_part_strict m
658 -- ((0,0,0),(4,0,0),(7,8,0))
660 lt_part_strict :: (Arity m, Ring.C a)
663 lt_part_strict matrix =
666 lambda i j = if i > j then matrix !!! (i,j) else 0
669 -- | Given a square @matrix@, return a new matrix of the same size
670 -- containing only the on-diagonal and above-diagonal entries of
671 -- @matrix@. The below-diagonal entries are set to zero.
675 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
677 -- ((1,2,3),(0,5,6),(0,0,9))
679 ut_part :: (Arity m, Ring.C a)
682 ut_part = transpose . lt_part . transpose
685 -- | Given a square @matrix@, return a new matrix of the same size
686 -- containing only the above-diagonal entries of @matrix@. The on-
687 -- and below-diagonal entries are set to zero.
691 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
692 -- >>> ut_part_strict m
693 -- ((0,2,3),(0,0,6),(0,0,0))
695 ut_part_strict :: (Arity m, Ring.C a)
698 ut_part_strict = transpose . lt_part_strict . transpose