1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
-- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
-- assume that the underlying representation is
-- 'Data.Vector.Fixed.Boxed.Vec' for simplicity; a version that was
-- generic in the vector representation was attempted but did not work.
17 import Data.List (intercalate)
19 import Data.Vector.Fixed (
34 import qualified Data.Vector.Fixed as V (
45 import Data.Vector.Fixed.Boxed (Vec)
46 import Data.Vector.Fixed.Internal.Arity (Arity, arity)
50 import NumericPrelude hiding ((*), abs)
51 import qualified NumericPrelude as NP ((*))
52 import qualified Algebra.Algebraic as Algebraic
53 import Algebra.Algebraic (root)
54 import qualified Algebra.Additive as Additive
55 import qualified Algebra.Ring as Ring
56 import qualified Algebra.Module as Module
57 import qualified Algebra.RealRing as RealRing
58 import qualified Algebra.ToRational as ToRational
59 import qualified Algebra.Transcendental as Transcendental
60 import qualified Prelude as P
-- | A matrix with @m@ rows and @n@ columns, stored as a boxed vector of
--   @m@ rows, each row being a boxed vector of @n@ entries. The 'Arity'
--   evidence for both dimensions is packed into the constructor, so
--   pattern matching on 'Mat' brings it back into scope.
data Mat m n a =
  (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- Square matrices of small, fixed sizes.
type Mat1 = Mat N1 N1
type Mat2 = Mat N2 N2
type Mat3 = Mat N3 N3
type Mat4 = Mat N4 N4
type Mat5 = Mat N5 N5
69 instance (Eq a) => Eq (Mat m n a) where
70 -- | Compare a row at a time.
74 -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
75 -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
76 -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
82 (Mat rows1) == (Mat rows2) =
83 V.and $ V.zipWith comp rows1 rows2
85 -- Compare a row, one column at a time.
86 comp row1 row2 = V.and (V.zipWith (==) row1 row2)
89 instance (Show a) => Show (Mat m n a) where
90 -- | Display matrices and vectors as ordinary tuples. This is poor
91 -- practice, but these results are primarily displayed
92 -- interactively and convenience trumps correctness (said the guy
93 -- who insists his vector lengths be statically checked at
98 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
103 "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
105 row_strings = V.map show_vector rows
107 "(" ++ (intercalate "," element_strings) ++ ")"
110 element_strings = P.map show v1l
-- | Flatten a matrix into a list of its rows; each row is the list of
--   its entries in column order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Build a matrix from a nested list: the outer list supplies the rows
--   and each inner list supplies the entries of one row. Both levels are
--   converted with 'V.fromList'.
--   NOTE(review): presumably this fails at runtime when a list length
--   does not match the type-level dimension; confirm against fixed-vector.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList xss = Mat $ V.fromList [ V.fromList xs | xs <- xss ]
-- | Unsafe indexing: the entry in row @i@, column @j@, both counted from
--   zero. No bounds check is done here; the indices go straight to 'row'
--   and '!'.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
mat !!! (i, j) = selected ! j
  where
    selected = row mat i
-- | Safe indexing: 'Just' the entry in row @i@, column @j@ (both
--   counted from zero), or 'Nothing' when either index is out of range.
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid row indices are 0 .. length - 1, so i == length is out of range.
  | i >= V.length rows = Nothing
  -- Check the column against the row we are about to read (row i, not j).
  | j >= V.length (row m i) = Nothing
  | otherwise = Just (row m i ! j)
-- | The number of rows in the matrix. The value comes from the
--   type-level dimension @m@; the matrix argument is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const rowCount
  where
    rowCount = arity (undefined :: m)
-- | The number of columns in the matrix, i.e. the type-level dimension
--   @n@ shared by every row. The matrix argument is never inspected; the
--   @arity (undefined :: n)@ trick is borrowed from Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const colCount
  where
    colCount = arity (undefined :: n)
-- | Return the @i@th row (counting from zero) of a matrix. Unsafe: the
--   index is handed to '!' without any bounds check.
row :: Mat m n a -> Int -> Vec n a
row (Mat rows) = (rows !)
151 -- | Return the @j@th column of @m@. Unsafe.
152 column :: Mat m n a -> Int -> (Vec m a)
153 column (Mat rows) j =
154 V.map (element j) rows
-- | Transpose @m@; switch its columns and its rows. This is a dirty
-- implementation; it would be a little cleaner to use imap, but that
-- doesn't seem to work.
165 -- TODO: Don't cheat with fromList.
169 -- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
173 transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
174 transpose m = Mat $ V.fromList column_list
176 column_list = [ column m i | i <- [0..(ncols m)-1] ]
179 -- | Is @m@ symmetric?
183 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
187 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
191 symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
196 -- | Construct a new matrix from a function @lambda@. The function
197 -- @lambda@ should take two parameters i,j corresponding to the
198 -- entries in the matrix. The i,j entry of the resulting matrix will
199 -- have the value returned by lambda i j.
201 -- TODO: Don't cheat with fromList.
205 -- >>> let lambda i j = i + j
206 -- >>> construct lambda :: Mat3 Int
207 -- ((0,1,2),(1,2,3),(2,3,4))
209 construct :: forall m n a. (Arity m, Arity n)
210 => (Int -> Int -> a) -> Mat m n a
211 construct lambda = Mat rows
213 -- The arity trick is used in Data.Vector.Fixed.length.
214 imax = (arity (undefined :: m)) - 1
215 jmax = (arity (undefined :: n)) - 1
216 row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
217 rows = V.fromList [ row' i | i <- [0..imax] ]
220 -- | Given a positive-definite matrix @m@, computes the
221 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
222 -- values on the diagonal of @r@ positive.
226 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
228 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
229 -- >>> (transpose (cholesky m1)) * (cholesky m1)
230 -- ((20.000000000000004,-1.0),(-1.0,20.0))
232 cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
233 => (Mat m n a) -> (Mat m n a)
234 cholesky m = construct r
237 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
239 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
243 -- | Returns True if the given matrix is upper-triangular, and False
248 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
249 -- >>> is_upper_triangular m
252 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
253 -- >>> is_upper_triangular m
256 is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
258 is_upper_triangular m =
261 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
263 test :: Int -> Int -> Bool
266 | otherwise = m !!! (i,j) == 0
269 -- | Returns True if the given matrix is lower-triangular, and False
274 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
275 -- >>> is_lower_triangular m
278 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
279 -- >>> is_lower_triangular m
282 is_lower_triangular :: (Eq a,
288 is_lower_triangular = is_upper_triangular . transpose
291 -- | Returns True if the given matrix is triangular, and False
296 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
297 -- >>> is_triangular m
300 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
301 -- >>> is_triangular m
304 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
305 -- >>> is_triangular m
308 is_triangular :: (Eq a,
314 is_triangular m = is_upper_triangular m || is_lower_triangular m
317 -- | Return the (i,j)th minor of m.
321 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
322 -- >>> minor m 0 0 :: Mat2 Int
324 -- >>> minor m 1 1 :: Mat2 Int
335 minor (Mat rows) i j = m
337 rows' = delete rows i
338 m = Mat $ V.map ((flip delete) j) rows'
-- | Containers @p a@ that have a determinant, whose value has the same
--   type @a@ as the entries.
class (Eq a, Ring.C a) => Determined p a where
  -- | The determinant of the given value.
  determinant :: p a -> a
-- | Base case: the determinant of a 1x1 matrix is its single entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = V.head (V.head rows)
350 Determined (Mat (S n) (S n)) a)
351 => Determined (Mat (S (S n)) (S (S n))) a where
352 -- | The recursive definition with a special-case for triangular matrices.
356 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
361 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
362 | otherwise = determinant_recursive
366 det_minor i j = determinant (minor m i j)
368 determinant_recursive =
369 sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
370 | j <- [0..(ncols m)-1] ]
374 -- | Matrix multiplication.
378 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
379 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
384 (*) :: (Ring.C a, Arity m, Arity n, Arity p)
388 (*) m1 m2 = construct lambda
391 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Matrices add and subtract entrywise. Only the additive structure of
--   the entries is used, so 'Additive.C' on @a@ is enough. This context
--   is weaker than the former @Ring.C a@, and every ring still qualifies.
instance (Additive.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  Mat rows1 + Mat rows2 = Mat (V.zipWith (V.zipWith (+)) rows1 rows2)

  Mat rows1 - Mat rows2 = Mat (V.zipWith (V.zipWith (-)) rows1 rows2)

  -- The zero matrix: every entry is the additive identity of @a@.
  zero = Mat (V.replicate (V.replicate Additive.zero))
406 instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
407 -- The first * is ring multiplication, the second is matrix
-- | Scalar multiplication by a scalar of the same type as the entries:
--   every entry @e@ becomes @x * e@. The scalar is applied on the left so
--   that @x *> (y *> v) == (x * y) *> v@ holds even when the ring @a@ is
--   not commutative.
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
418 instance (Algebraic.C a,
422 => Normed (Mat (S m) (S n) a) where
423 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
424 -- all matrices as big vectors.
428 -- >>> let v1 = vec2d (3,4)
434 norm_p p (Mat rows) =
435 (root p') $ sum [(fromRational' $ toRational x)^p' | x <- xs]
438 xs = concat $ V.toList $ V.map V.toList rows
440 -- | The infinity norm.
444 -- >>> let v1 = vec3d (1,5,2)
448 norm_infty (Mat rows) =
449 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
455 -- Vector helpers. We want it to be easy to create low-dimension
456 -- column vectors, which are nx1 matrices.
-- | Convenient constructors for small column vectors; the example below uses 'vec2d'.
462 -- >>> import Roots.Simple
463 -- >>> let fst m = m !!! (0,0)
464 -- >>> let snd m = m !!! (1,0)
465 -- >>> let h = 0.5 :: Double
466 -- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
467 -- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
468 -- >>> let g u = vec2d ((g1 u), (g2 u))
469 -- >>> let u0 = vec2d (1.0, 1.0)
470 -- >>> let eps = 1/(10^9)
471 -- >>> fixed_point g eps u0
472 -- ((1.0728549599342185),(1.0820591495686167))
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat $ mk1 (mk1 x) -- a one-entry column vector

-- | Build a 2x1 column vector; each component becomes a one-entry row.
vec2d :: (a, a) -> Mat N2 N1 a
vec2d (x, y) = Mat $ mk2 (mk1 x) (mk1 y)

-- | Build a 3x1 column vector; each component becomes a one-entry row.
vec3d :: (a, a, a) -> Mat N3 N1 a
vec3d (x, y, z) = Mat $ mk3 (mk1 x) (mk1 y) (mk1 z)

-- | Build a 4x1 column vector; each component becomes a one-entry row.
vec4d :: (a, a, a, a) -> Mat N4 N1 a
vec4d (w, x, y, z) = Mat $ mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z)

-- | Build a 5x1 column vector; each component becomes a one-entry row.
vec5d :: (a, a, a, a, a) -> Mat N5 N1 a
vec5d (v, w, x, y, z) = Mat $ mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z)
-- | Wrap a single value as a 1x1 matrix. This module takes over @(*)@
--   for matrix multiplication, so plain values must be lifted to 1x1
--   matrices before they can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
494 dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
498 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
501 -- | The angle between @v1@ and @v2@ in Euclidean space.
505 -- >>> let v1 = vec2d (1.0, 0.0)
506 -- >>> let v2 = vec2d (0.0, 1.0)
507 -- >>> angle v1 v2 == pi/2.0
510 angle :: (Transcendental.C a,
522 theta = (recip norms) NP.* (v1 `dot` v2)
523 norms = (norm v1) NP.* (norm v2)