1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE NoMonomorphismRestriction #-}
6 {-# LANGUAGE ScopedTypeVariables #-}
7 {-# LANGUAGE TypeFamilies #-}
8 {-# LANGUAGE RebindableSyntax #-}
10 -- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
11 -- assume that the underlying representation is
12 -- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
13 -- generality and failed.
18 import Data.List (intercalate)
20 import Data.Vector.Fixed (
36 import qualified Data.Vector.Fixed as V (
48 import Data.Vector.Fixed.Boxed (Vec)
49 import Data.Vector.Fixed.Cont (Arity, arity)
53 import NumericPrelude hiding ((*), abs)
54 import qualified NumericPrelude as NP ((*))
55 import qualified Algebra.Algebraic as Algebraic
56 import Algebra.Algebraic (root)
57 import qualified Algebra.Additive as Additive
58 import qualified Algebra.Ring as Ring
59 import qualified Algebra.Module as Module
60 import qualified Algebra.RealRing as RealRing
61 import qualified Algebra.ToRational as ToRational
62 import qualified Algebra.Transcendental as Transcendental
63 import qualified Prelude as P
-- | An @m@-by-@n@ matrix with entries of type @a@: a boxed vector of
-- @m@ rows, each row a boxed vector of @n@ entries. The 'Arity'
-- constraints are stored in the constructor, so pattern matching on
-- 'Mat' brings them back into scope.
data Mat m n a =
  (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of small fixed sizes.
type Mat1 = Mat N1 N1
type Mat2 = Mat N2 N2
type Mat3 = Mat N3 N3
type Mat4 = Mat N4 N4
type Mat5 = Mat N5 N5
instance (Eq a) => Eq (Mat m n a) where
  -- | Two matrices are equal when every pair of corresponding rows is
  -- equal. Two rows are equal when every pair of corresponding entries
  -- is equal.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  Mat rows1 == Mat rows2 = V.and (V.zipWith sameRow rows1 rows2)
    where
      -- Entry-by-entry comparison of a single pair of rows.
      sameRow r1 r2 = V.and (V.zipWith (==) r1 r2)
instance (Show a) => Show (Mat m n a) where
  -- | Display matrices and vectors as ordinary tuples. Each row is a
  -- parenthesised, comma-separated list of its entries, and the rows
  -- are wrapped the same way. The output is not a valid Haskell
  -- expression for a 'Mat'. That is poor practice, but these values are
  -- mostly looked at interactively, where the compact form is
  -- convenient.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> m
  -- ((1,2),(3,4))
  --
  show (Mat rows) = parenthesize (V.toList (V.map showRow rows))
    where
      -- Join already-rendered pieces with commas, inside parentheses.
      parenthesize pieces = "(" ++ intercalate "," pieces ++ ")"

      -- Render one row using the Show instance of its entries.
      showRow r = parenthesize (P.map show (V.toList r))
-- | Convert a matrix to a list of its rows, each row being the list of
-- its entries from left to right.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Create a matrix from a list of rows. The dimensions come from the
-- result type. No length check is done here: the outer list and each
-- inner list are handed straight to 'V.fromList'.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList vs = Mat (V.fromList [ V.fromList v | v <- vs ])
-- | Unsafe indexing: the entry in row @i@, column @j@ (zero-based).
-- No bounds check is performed.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
m !!! (i, j) = let r = row m i in r ! j
-- | Safe indexing. Returns 'Nothing' when either index is negative or
-- past the end of its dimension. Otherwise returns @Just@ the entry in
-- row @i@, column @j@ (zero-based).
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices are 0 .. length-1, so an index equal to the length
  -- is already out of range.
  | i >= V.length rows = Nothing
  -- Check the column index against row i (the row we will read from).
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j
-- | The number of rows in the matrix. It is read off the type-level
-- dimension @m@, so the matrix argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const (arity (undefined :: m))
-- | The number of columns in the matrix. Like 'nrows', it is read off
-- the type-level dimension (here @n@) via 'arity', so the matrix
-- argument itself is never inspected. This is the same trick used by
-- Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const (arity (undefined :: n))
-- | The row of @m@ at (zero-based) index @i@. Unsafe: no bounds check
-- is done here.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) i = (!) rows i
-- | The column of @m@ at (zero-based) index @j@, built by picking entry
-- @j@ from every row with 'element'. Unsafe: no bounds check is done
-- here.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j = V.map pick rows
  where
    pick = element j
-- | Transpose @m@, switching its columns and its rows: row @i@ of the
-- result is column @i@ of @m@. The result is built with 'generate',
-- which covers exactly the indices 0 .. ncols m - 1. This avoids the
-- runtime length check of going through an intermediate list.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = Mat (generate (column m))
-- | Is @m@ symmetric? A square matrix is symmetric when it equals its
-- own transpose.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m = m == transpose m
-- | Build a matrix from a function @lambda@ of two arguments. The entry
-- in row @i@, column @j@ of the result is @lambda i j@. Both indices
-- are zero-based, as the example shows.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat (generate buildRow)
  where
    -- Row i, generated one column index at a time.
    buildRow :: Int -> Vec n a
    buildRow i = generate (\j -> lambda i j)
-- | The identity matrix whose size is taken from the result type: ones
-- on the diagonal, zeros everywhere else.
--
-- Examples:
--
-- >>> identity_matrix :: Mat3 Int
-- ((1,0,0),(0,1,0),(0,0,1))
-- >>> identity_matrix :: Mat3 Double
-- ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
--
identity_matrix :: (Arity m, Ring.C a) => Mat m m a
identity_matrix = construct delta
  where
    -- Kronecker delta.
    delta i j
      | i == j = fromInteger 1
      | otherwise = fromInteger 0
-- | Given a positive-definite matrix @m@, compute the upper-triangular
-- matrix @r@ with (transpose r)*r == m and all values on the diagonal
-- of @r@ positive.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct r
  where
    -- The (i,j) entry of the factor. It is defined recursively in terms
    -- of the entries above row i. The recursion is not memoised, so the
    -- work grows very quickly with the dimension; this is fine for the
    -- small fixed sizes used here.
    r :: Int -> Int -> a
    r i j
      -- Diagonal: r_ii = sqrt(m_ii - sum_{k<i} r_ki^2).
      | i == j = sqrt (m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
      -- Above the diagonal: r_ij = (m_ij - sum_{k<i} r_ki r_kj) / r_ii.
      | i < j =
          ((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]])
            / (r i i)
      -- Below the diagonal the factor is zero (it is upper-triangular).
      | otherwise = 0
-- | Returns True if the given matrix is upper-triangular, and False
-- otherwise. A matrix is upper-triangular when every entry strictly
-- below the main diagonal (row index greater than column index) is
-- zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ m !!! (i,j) == 0
      | i <- [0..(nrows m)-1]
      , j <- [0..(ncols m)-1]
      -- Only below-diagonal entries are constrained.
      , i > j ]
-- | Returns True if the given matrix is lower-triangular, and False
-- otherwise. A matrix is lower-triangular exactly when its transpose
-- is upper-triangular, and that is how it is checked.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a,
                        Ring.C a,
                        Arity m,
                        Arity n)
                    => Mat m n a
                    -> Bool
is_lower_triangular = is_upper_triangular . transpose
-- | Returns True if the given matrix is triangular (either upper- or
-- lower-triangular), and False otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a,
                  Ring.C a,
                  Arity m,
                  Arity n)
              => Mat m n a
              -> Bool
is_triangular m = is_upper_triangular m || is_lower_triangular m
328 -- | Return the (i,j)th minor of m.
-- That is the matrix obtained by deleting row i and column j of m.
-- This assumes `delete v k` drops the element at index k; `delete` is
-- not defined in this view.
--
-- NOTE(review): the doctests below are missing their expected outputs.
-- If delete drops index k they would be ((5,6),(8,9)) and
-- ((1,3),(7,9)) respectively. The type signature and the `where` that
-- should introduce rows' and m are also missing from this copy of the
-- file.
332 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
333 -- >>> minor m 0 0 :: Mat2 Int
335 -- >>> minor m 1 1 :: Mat2 Int
346 minor (Mat rows) i j = m
-- Remove row i.
348 rows' = delete rows i
-- Remove column j from each remaining row.
349 m = Mat $ V.map ((flip delete) j) rows'
-- | Things that have a determinant. @p@ is a type constructor (in this
-- module, a square 'Mat' of some fixed size) applied to the entry
-- type @a@.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given value.
  determinant :: p a -> a
-- | Base case: the determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
-- NOTE(review): the start of this instance head (the "instance (..."
-- line and its first constraints), the method head "determinant m",
-- and the "where" / m' definitions are missing from this copy of the
-- file.
361 Determined (Mat (S n) (S n)) a)
362 => Determined (Mat (S (S n)) (S (S n))) a where
363 -- | The recursive definition with a special-case for triangular matrices.
-- For a triangular matrix the determinant is the product of the
-- diagonal entries. Otherwise it is the cofactor expansion along row 0.
-- Each minor is one size smaller, which is why the context requires a
-- Determined instance for Mat (S n) (S n).
367 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- Expected: determinant m == -2, since 1*4 - 2*3 = -2.
372 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
373 | otherwise = determinant_recursive
-- Determinant of the (i,j) minor, via the smaller instance.
377 det_minor i j = determinant (minor m i j)
-- Sum over columns j of (-1)^j * m(0,j) * det(minor 0 j).
-- m' is not visible here; presumably m' i j = m !!! (i,j). Confirm.
379 determinant_recursive =
380 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
381 | j <- [0..(ncols m)-1] ]
-- | Matrix multiplication. Entry (i,j) of the product is the sum over k
-- of m1(i,k) * m2(k,j). The shared type parameter @n@ guarantees that
-- the number of columns of @m1@ matches the number of rows of @m2@.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
(*) m1 m2 = construct lambda
  where
    -- Dot product of row i of m1 with column j of m2.
    lambda i j =
      sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Matrices of the same shape are added and subtracted entry by entry.
-- 'zero' is the matrix whose entries are all zero.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  Mat xs + Mat ys = Mat (V.zipWith addRows xs ys)
    where
      addRows = V.zipWith (+)

  Mat xs - Mat ys = Mat (V.zipWith subRows xs ys)
    where
      subRows = V.zipWith (-)

  zero = Mat (V.replicate zeroRow)
    where
      zeroRow = V.replicate (fromInteger 0)
-- Square matrices (the m ~ n constraint) as a ring, using this module's
-- entry-wise addition and its matrix multiplication.
-- NOTE(review): the method bindings of this instance, and the end of the
-- comment below, are missing from this copy of the file. As it stands,
-- the instance defines no methods.
417 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
418 -- The first * is ring multiplication, the second is matrix
-- | Scalar multiplication: each entry of the matrix is multiplied by a
-- scalar of the entries' own type.
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  x *> Mat rows = Mat (V.map scaleRow rows)
    where
      -- Each entry e becomes e * x, with the scalar on the right.
      scaleRow = V.map (\e -> e NP.* x)
-- Normed instance for column vectors: n x 1 matrices with at least one
-- row, per the instance head Mat (S m) N1 a.
-- NOTE(review): parts of the instance context, the results of the doc
-- examples, and the where-clauses (including the definition of p') are
-- missing from this copy of the file.
429 instance (Algebraic.C a,
432 => Normed (Mat (S m) N1 a) where
433 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
434 -- all matrices as big vectors.
438 -- >>> let v1 = vec2d (3,4)
-- Converts every entry via toRational/fromRational', raises it to p',
-- sums the results, and takes the p'-th root. p' is not defined in the
-- visible lines; presumably it is derived from p. Confirm.
-- NOTE(review): entries are not passed through abs before being raised
-- to p'. For odd p', negative entries therefore decrease the sum instead
-- of adding |x|^p'. Confirm whether abs was intended; NumericPrelude's
-- abs is hidden by this module's imports.
444 norm_p p (Mat rows) =
445 (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
-- All entries of the matrix, flattened row by row.
448 xs = concat $ V.toList $ V.map V.toList rows
450 -- | The infinity norm.
454 -- >>> let v1 = vec3d (1,5,2)
-- Returns the largest entry, converted via toRational/fromRational'.
-- NOTE(review): this takes V.maximum of the raw entries, not of their
-- absolute values. For example, a vector with entries -5 and 1 yields 1
-- rather than 5, while the infinity norm is max |x_i|. This looks like a
-- bug; confirm and map abs over the entries first.
458 norm_infty (Mat rows) =
459 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
-- | Compute the Frobenius norm of a matrix: the square root of the sum
-- of the squares of all its entries. It treats the matrix as one long
-- vector of its entries, in any order.
--
-- Examples:
--
-- >>> let m = fromList [[1, 2, 3],[4,5,6],[7,8,9]] :: Mat3 Double
-- >>> frobenius_norm m == sqrt 285
-- True
--
-- >>> let m = fromList [[1, -1, 1],[-1,1,-1],[1,-1,1]] :: Mat3 Double
-- >>> frobenius_norm m == 3
-- True
--
frobenius_norm :: (Algebraic.C a, Ring.C a) => Mat m n a -> a
frobenius_norm (Mat rows) =
  sqrt (V.foldl addRow (fromInteger 0) rows)
  where
    -- Add the sum of squares of row r (accumulated left to right) onto
    -- acc. The folds are written by hand because fixed-vector's own sum
    -- requires a Prelude 'Num' instance, while this module works with
    -- the numeric-prelude classes.
    addRow acc r = acc + V.foldl (\s x -> s + x^2) (fromInteger 0) r
489 -- Vector helpers. We want it to be easy to create low-dimension
490 -- column vectors, which are nx1 matrices.
-- | Convenient constructor for 1D vectors: a 1x1 matrix holding @x@.
vec1d :: (a) -> Mat N1 N1 a
vec1d (x) = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D vectors, as 2x1 column matrices.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))
-- | Build a 3x1 column vector from a triple.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat $ mk3 (single x) (single y) (single z)
  where
    single = mk1
-- | Build a 4x1 column vector from a 4-tuple.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat $ mk4 (single w) (single x) (single y) (single z)
  where
    single = mk1
-- | Build a 5x1 column vector from a 5-tuple.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) =
  Mat $ mk5 (single v) (single w) (single x) (single y) (single z)
  where
    single = mk1
-- | Wrap a value as a 1x1 matrix. This module's @*@ is matrix
-- multiplication, so a plain value has to be wrapped like this before
-- it can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar x = Mat $ mk1 $ mk1 x
-- | The dot (inner) product of two non-empty column vectors. It is the
-- single entry of the 1x1 product @(transpose v1) * v2@.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
-- NOTE(review): most of this definition is missing from this copy of the
-- file: the rest of the type signature, the function head, and the
-- expression that turns theta into an angle. The doctest below has also
-- lost its expected output line. Only the doc comment, the first
-- signature line and two where-bindings survive.
535 -- | The angle between @v1@ and @v2@ in Euclidean space.
539 -- >>> let v1 = vec2d (1.0, 0.0)
540 -- >>> let v2 = vec2d (0.0, 1.0)
541 -- >>> angle v1 v2 == pi/2.0
544 angle :: (Transcendental.C a,
-- theta = (v1 . v2) / (|v1| * |v2|), the cosine of the angle between
-- the vectors. Presumably the missing body applies acos to it. Confirm.
556 theta = (recip norms) NP.* (v1 `dot` v2)
-- Product of the two vectors' norms. `norm` is presumably the Normed
-- class method; its definition is not visible here.
557 norms = (norm v1) NP.* (norm v2)
-- | Given a square @matrix@, return a new matrix of the same size
-- containing only the on-diagonal entries of @matrix@. The
-- off-diagonal entries are set to zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> diagonal_part m
-- ((1,0,0),(0,5,0),(0,0,9))
--
diagonal_part :: (Arity m, Ring.C a)
              => Mat m m a
              -> Mat m m a
diagonal_part matrix =
  construct lambda
  where
    lambda i j = if i == j then matrix !!! (i,j) else 0
-- | Given a square @matrix@, return a new matrix of the same size
-- containing only the on-diagonal and below-diagonal entries of
-- @matrix@. The above-diagonal entries are set to zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part m
-- ((1,0,0),(4,5,0),(7,8,9))
--
lt_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
lt_part matrix =
  construct lambda
  where
    lambda i j = if i >= j then matrix !!! (i,j) else 0
-- | Given a square @matrix@, return a new matrix of the same size
-- containing only the below-diagonal entries of @matrix@. The on-
-- and above-diagonal entries are set to zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> lt_part_strict m
-- ((0,0,0),(4,0,0),(7,8,0))
--
lt_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
lt_part_strict matrix =
  construct lambda
  where
    lambda i j = if i > j then matrix !!! (i,j) else 0
-- | Given a square @matrix@, return a new matrix of the same size
-- containing only the on-diagonal and above-diagonal entries of
-- @matrix@. The below-diagonal entries are set to zero.
--
-- The result is built directly, like 'lt_part', rather than by
-- transposing twice around 'lt_part'. The entries are the same.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part m
-- ((1,2,3),(0,5,6),(0,0,9))
--
ut_part :: (Arity m, Ring.C a)
        => Mat m m a
        -> Mat m m a
ut_part matrix =
  construct lambda
  where
    lambda i j = if i <= j then matrix !!! (i,j) else 0
-- | Given a square @matrix@, return a new matrix of the same size
-- containing only the above-diagonal entries of @matrix@. The on-
-- and below-diagonal entries are set to zero.
--
-- The result is built directly, like 'lt_part_strict', rather than by
-- transposing twice around 'lt_part_strict'. The entries are the same.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> ut_part_strict m
-- ((0,2,3),(0,0,6),(0,0,0))
--
ut_part_strict :: (Arity m, Ring.C a)
               => Mat m m a
               -> Mat m m a
ut_part_strict matrix =
  construct lambda
  where
    lambda i j = if i < j then matrix !!! (i,j) else 0