1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
-- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
-- assume that the underlying representation is
-- Data.Vector.Fixed.Boxed.Vec for simplicity. A version that was
-- generic over the underlying vector type was attempted, but failed.
17 import Data.List (intercalate)
19 import Data.Vector.Fixed (
33 import qualified Data.Vector.Fixed as V (
43 import Data.Vector.Fixed.Boxed (Vec)
44 import Data.Vector.Fixed.Internal (Arity, arity)
48 import NumericPrelude hiding ((*), abs)
49 import qualified NumericPrelude as NP ((*))
50 import qualified Algebra.Algebraic as Algebraic
51 import Algebra.Algebraic (root)
52 import qualified Algebra.Additive as Additive
53 import qualified Algebra.Ring as Ring
54 import qualified Algebra.Module as Module
55 import qualified Algebra.RealRing as RealRing
56 import qualified Algebra.ToRational as ToRational
57 import qualified Algebra.Transcendental as Transcendental
58 import qualified Prelude as P
-- | An @m@-by-@n@ matrix with entries of type @a@, stored as a boxed
--   vector of @m@ rows, each of which is a boxed vector of @n@ entries.
--   The constructor carries the 'Arity' constraints for both
--   dimensions, so pattern matching on 'Mat' makes them available
--   again to functions whose signatures do not mention them.
data Mat m n a =
  (Arity m, Arity n) => Mat (Vec m (Vec n a))

-- | Square matrices of sizes one through five.
type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a
instance (Eq a) => Eq (Mat m n a) where
  -- | Compare a row at a time.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  (Mat rows1) == (Mat rows2) =
    V.and $ V.zipWith comp rows1 rows2
    where
      -- Compare a row, one column at a time.
      comp row1 row2 = V.and (V.zipWith (==) row1 row2)
instance (Show a) => Show (Mat m n a) where
  -- | Display matrices and vectors as ordinary tuples. This is poor
  -- practice, but these results are primarily displayed
  -- interactively and convenience trumps correctness (said the guy
  -- who insists his vector lengths be statically checked at
  -- compile-time).
  --
  -- Each row is rendered as a parenthesized, comma-separated list of
  -- its entries, and the rendered rows are joined the same way.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> m
  -- ((1,2),(3,4))
  --
  show (Mat rows) =
    "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
    where
      row_strings = V.map show_vector rows

      -- Render one row as a tuple of its entries.
      show_vector v1 =
        "(" ++ (intercalate "," element_strings) ++ ")"
        where
          v1l = V.toList v1
          element_strings = P.map show v1l
-- | Flatten a matrix into a list of its rows, where each row is a
--   list of its entries.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Build a matrix from a list of rows. The number of rows and the
--   length of every row must match the dimensions in the result type;
--   'V.fromList' raises an error on a length mismatch.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList vs = Mat $ V.fromList [ V.fromList v | v <- vs ]
-- | Unsafe indexing: the entry in row @i@, column @j@. No bounds check
--   is done here; see '!!?' for a checked version.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
(Mat rows) !!! (i, j) = (rows ! i) ! j
-- | Safe indexing. Returns 'Nothing' when either index is negative or
--   not less than the corresponding dimension, and 'Just' the entry in
--   row @i@, column @j@ otherwise.
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) (Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length - 1, hence >= rather than >.
  | i >= V.length rows = Nothing
  | j >= V.length r = Nothing
  | otherwise = Just (r ! j)
  where
    -- The @i@th row; only forced after @i@ has been range-checked.
    r = rows ! i
-- | The number of rows in the matrix. It is read from the type-level
--   dimension @m@, so the matrix argument itself is never inspected.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows = const rowCount
  where
    rowCount = arity (undefined :: m)
-- | The number of columns in the matrix. It is read from the type-level
--   dimension @n@ (the same trick Data.Vector.Fixed.length uses), so
--   the matrix argument itself is never inspected.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols = const columnCount
  where
    columnCount = arity (undefined :: n)
-- | Return the @i@th row of @m@. Unsafe: no bounds check is performed
--   here.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) = (rows !)
-- | Return the @j@th column of @m@, i.e. the @j@th entry of every row.
--   Unsafe: no bounds check is performed here.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j =
  V.map (! j) rows
-- | Transpose @m@; switch its columns and its rows. Entry (i,j) of the
--   result is entry (j,i) of @m@.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = construct lambda
  where
    lambda i j = m !!! (j, i)
-- | Is @m@ symmetric? A square matrix is symmetric when it equals its
--   own transpose.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m =
  m == (transpose m)
-- | Construct a new matrix from a function @lambda@. The function
--   @lambda@ should take two parameters i,j corresponding to the
--   entries in the matrix. The i,j entry of the resulting matrix will
--   have the value returned by lambda i j. Indices start at zero.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a) -> Mat m n a
construct lambda = Mat rows
  where
    -- The arity trick is used in Data.Vector.Fixed.length.
    imax = (arity (undefined :: m)) - 1
    jmax = (arity (undefined :: n)) - 1
    row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
    rows = V.fromList [ row' i | i <- [0..imax] ]
-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   Diagonal entries take a square root, entries above the diagonal
--   are divided by the diagonal entry of their row, and entries below
--   the diagonal are zero. The recursion is not memoized, so earlier
--   entries are recomputed many times; this is only practical for
--   small matrices.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct r
  where
    r :: Int -> Int -> a
    r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
          | i < j =
              (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
          -- Below the diagonal, r is zero by construction.
          | otherwise = 0
-- | Returns True if the given matrix is upper-triangular, and False
--   otherwise. A matrix is upper-triangular when every entry strictly
--   below the main diagonal (row index greater than column index) is
--   zero.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and [ m !!! (i,j) == 0 | i <- [0..(nrows m)-1],
                           j <- [0..(ncols m)-1],
                           i > j ]
-- | Returns True if the given matrix is lower-triangular, and False
--   otherwise. A matrix is lower-triangular exactly when its
--   transpose is upper-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_lower_triangular = is_upper_triangular . transpose
-- | Returns True if the given matrix is triangular, and False
--   otherwise. A matrix is triangular when it is upper-triangular or
--   lower-triangular.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
              => Mat m n a -> Bool
is_triangular m = is_upper_triangular m || is_lower_triangular m
-- | Return the (i,j)th minor of m: the matrix left after deleting row
--   @i@ and column @j@ (assuming @delete v k@ drops index @k@ of @v@).
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
-- NOTE(review): no explicit type signature here; the result dimensions
-- are fixed by the type of @delete@, which is defined outside this
-- block. Add a signature once that type is confirmed.
minor (Mat rows) i j = m
  where
    -- Drop row i, then drop entry j from every remaining row.
    rows' = delete rows i
    m = Mat $ V.map ((flip delete) j) rows'
-- | Types @p a@ whose values have a determinant with entries in @a@;
-- the instances below cover square matrices.
339 class (Eq a, Ring.C a) => Determined p a where
340 determinant :: (p a) -> a
-- | The determinant of a 1x1 matrix is its single entry.
342 instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
343 determinant m = m !!! (0,0)
-- NOTE(review): this catch-all instance overlaps both the 1x1 instance
-- and the (S (S n)) instance, and its determinant is 'undefined' (a
-- runtime error if ever selected). No OVERLAPPING/OVERLAPPABLE pragma
-- is visible on these instances; confirm how resolution is meant to
-- choose among them.
345 instance (Eq a, Ring.C a, Arity m) => Determined (Mat m m) a where
346 determinant _ = undefined
-- | Matrices of size 2x2 and larger. A triangular matrix's determinant
-- is the product of its diagonal entries; otherwise the determinant is
-- expanded along row 0 using the minors of that row.
348 instance (Eq a, Ring.C a, Arity n)
349 => Determined (Mat (S (S n)) (S (S n))) a where
351 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
352 | otherwise = determinant_recursive
356 det_minor i j = determinant (minor m i j)
-- NOTE(review): with 0-based indices, the cofactor sign for term j of
-- a row-0 expansion is (-1)^(0+j). This uses (-1)^(1+(toInteger j)),
-- which flips every term unless @m'@ (not defined in this block)
-- compensates. Confirm against a 2x2 example such as [[1,2],[3,4]],
-- whose determinant is -2.
358 determinant_recursive =
359 sum [ (-1)^(1+(toInteger j)) NP.* (m' 0 j) NP.* (det_minor 0 j)
360 | j <- [0..(ncols m)-1] ]
-- | Matrix multiplication. NumericPrelude's ring multiplication is
--   hidden in this module so that @*@ means matrix multiplication. It
--   is a standalone function rather than a class method because the
--   result type differs from both argument types.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
(*) m1 m2 = construct lambda
  where
    -- Entry (i,j) pairs row i of m1 with column j of m2.
    lambda i j =
      sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
-- | Entrywise addition and subtraction. Only an additive structure on
--   the entries is needed, so this works for any 'Additive.C' entry
--   type (every 'Ring.C' type is one).
instance (Additive.C a, Arity m, Arity n) => Additive.C (Mat m n a) where
  (Mat rows1) + (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (+)) rows1 rows2

  (Mat rows1) - (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (-)) rows1 rows2

  -- Every entry of the zero matrix is the additive identity of @a@.
  zero = Mat (V.replicate $ V.replicate Additive.zero)
-- | Square matrices form a ring under matrix multiplication.
instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication, the second is matrix
  -- multiplication: NumericPrelude's (*) is hidden in this module, so
  -- the unqualified (*) on the right-hand side is the matrix product.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix. Defining it
  -- also gives working class defaults for fromInteger and (^).
  one = construct delta
    where
      delta i j
        | i == j = Ring.one
        | otherwise = Additive.zero
instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- We can multiply a matrix by a scalar of the same type as its
  -- elements. The scalar multiplies every entry from the left, so that
  -- (x * y) *> m == x *> (y *> m) holds even when @a@ is not
  -- commutative.
  x *> (Mat rows) = Mat $ V.map (V.map (x NP.*)) rows
-- | Norms of non-empty matrices, computed by treating all entries as
-- one long vector (see @xs@ below).
409 instance (Algebraic.C a,
413 => Normed (Mat (S m) (S n) a) where
414 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
415 -- all matrices as big vectors.
419 -- >>> let v1 = vec2d (3,4)
-- NOTE(review): @p'@ is not defined in the lines shown; presumably it
-- is @p@ converted for use with '^' and 'root'. Confirm.
-- NOTE(review): entries are raised to @p'@ without taking absolute
-- values first, so for odd @p@ negative entries reduce the sum instead
-- of adding to it.
425 norm_p p (Mat rows) =
426 (root p') $ sum [(fromRational' $ toRational x)^p' | x <- xs]
429 xs = concat $ V.toList $ V.map V.toList rows
431 -- | The infinity norm.
435 -- >>> let v1 = vec3d (1,5,2)
-- NOTE(review): this is the largest entry, not the largest absolute
-- value: for the column vector (-5, 1) it returns 1. The infinity norm
-- is normally max |x_i|. Confirm intent.
439 norm_infty (Mat rows) =
440 fromRational' $ toRational $ V.maximum $ V.map V.maximum rows
-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Convenient constructor for 1D vectors (1x1 matrices).
vec1d :: a -> Mat N1 N1 a
vec1d x = Mat (mk1 (mk1 x))

-- | Convenient constructor for 2D vectors.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

-- | Convenient constructor for 3D vectors.
vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 4D vectors.
vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- | Convenient constructor for 5D vectors.
vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))
-- | Wrap a single value as a 1x1 matrix. Because @*@ in this module is
--   matrix multiplication, a plain value must be lifted this way
--   before it can take part in a product.
scalar :: a -> Mat N1 N1 a
scalar = Mat . mk1 . mk1
-- | The dot product of two non-empty column vectors of the same length,
--   computed as the single entry of (transpose v1) * v2.
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
    => Mat m n a
    -> Mat m n a
    -> a
v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
492 -- | The angle between @v1@ and @v2@ in Euclidean space.
496 -- >>> let v1 = vec2d (1.0, 0.0)
497 -- >>> let v2 = vec2d (0.0, 1.0)
498 -- >>> angle v1 v2 == pi/2.0
-- True
501 angle :: (Transcendental.C a,
-- @theta@ is (v1 `dot` v2) divided by the product of the two norms,
-- i.e. the cosine of the angle when 'norm' is the Euclidean norm.
-- NOTE(review): the step that turns @theta@ into an angle is not shown
-- here; presumably it is 'acos'. Confirm.
513 theta = (recip norms) NP.* (v1 `dot` v2)
514 norms = (norm v1) NP.* (norm v2)