1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
12 import Data.List (intercalate)
14 import Data.Vector.Fixed (
19 import qualified Data.Vector.Fixed as V (
28 import Data.Vector.Fixed.Internal (Arity, arity, S)
32 import NumericPrelude hiding ((*), abs)
33 import qualified NumericPrelude as NP ((*))
34 import qualified Algebra.Algebraic as Algebraic
35 import Algebra.Algebraic (root)
36 import qualified Algebra.Additive as Additive
37 import qualified Algebra.Ring as Ring
38 import qualified Algebra.Module as Module
39 import qualified Algebra.RealRing as RealRing
40 import qualified Algebra.ToRational as ToRational
41 import qualified Algebra.Transcendental as Transcendental
42 import qualified Prelude as P
-- | A matrix stored row by row: the outer fixed-length vector @v@ holds
--   the rows, and each row is a fixed-length vector @w@ of entries of
--   type @a@. The constructor carries the 'Vector' dictionaries for both
--   layers, so any function that pattern-matches on 'Mat' can use the
--   vector operations on the rows and entries without further constraints.
data Mat v w a =
  (Vector v (w a), Vector w a) => Mat (v (w a))

-- | Square matrices of sizes one through four.
type Mat1 a = Mat D1 D1 a
type Mat2 a = Mat D2 D2 a
type Mat3 a = Mat D3 D3 a
type Mat4 a = Mat D4 D4 a
50 -- We can't just declare that all instances of Vector are instances of
51 -- Eq unfortunately. We wind up with an overlapping instance for
53 instance (Eq a, Vector v Bool, Vector w Bool) => Eq (Mat v w a) where
54 -- | Compare a row at a time.
58 -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
59 -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
60 -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
66 (Mat rows1) == (Mat rows2) =
67 V.and $ V.zipWith comp rows1 rows2
69 -- Compare a row, one column at a time.
70 comp row1 row2 = V.and (V.zipWith (==) row1 row2)
73 instance (Show a, Vector v String, Vector w String) => Show (Mat v w a) where
74 -- | Display matrices and vectors as ordinary tuples. This is poor
75 -- practice, but these results are primarily displayed
76 -- interactively and convenience trumps correctness (said the guy
77 -- who insists his vector lengths be statically checked at
82 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
87 "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
89 row_strings = V.map show_vector rows
91 "(" ++ (intercalate "," element_strings) ++ ")"
94 element_strings = P.map show v1l
-- | Flatten a matrix into a nested list: one inner list per row, rows in
--   order from first to last, entries within a row in column order.
toList :: Mat v w a -> [[a]]
toList (Mat rows) = [ V.toList r | r <- V.toList rows ]
-- | Create a matrix from a nested list given row by row: the outer list
--   supplies the rows and each inner list the entries of one row.
--
--   The list shape is not checked against @v@ and @w@ at compile time;
--   a list of the wrong length is handled by 'Data.Vector.Fixed.fromList'.
--   NOTE(review): that function raises an error on a length mismatch in
--   the fixed-vector versions I know of; confirm for the pinned version.
--
--   Only the two 'Vector' constraints that the 'Mat' constructor needs
--   are required; the former extra @Vector v a@ constraint was never used.
fromList :: (Vector v (w a), Vector w a) => [[a]] -> Mat v w a
fromList vs = Mat (V.fromList $ map V.fromList vs)
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@, column @j@,
--   counting from zero. Neither index is bounds-checked before the
--   underlying vector lookups.
(!!!) :: (Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) =
  let r = row m i
  in r ! j
-- | Safe indexing: @m !!? (i, j)@ is 'Just' the entry in row @i@, column
--   @j@ (both counted from zero), or 'Nothing' when either index is
--   negative or past the end of the matrix.
(!!?) :: Mat v w a -> (Int, Int) -> Maybe a
(Mat rows) !!? (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length - 1, so the length itself is
  -- already out of range.
  | i >= V.length rows = Nothing
  | j >= V.length selected = Nothing
  | otherwise = Just (selected ! j)
  where
    -- The row is selected by @i@ (not @j@); it is only forced once
    -- @i@ has passed the checks above.
    selected = rows ! i
-- | How many rows the matrix has (the length of its outer vector).
nrows :: Mat v w a -> Int
nrows m =
  case m of
    Mat rows -> V.length rows
-- | How many columns the matrix has. Every row has type @w a@, so this is
--   the arity of @w@; it is read off the type, and the matrix argument
--   itself is never inspected.
ncols :: forall v w a. (Vector w a) => Mat v w a -> Int
ncols _ = columnCount
  where
    -- Same trick as Data.Vector.Fixed.length: 'arity' is handed
    -- 'undefined', so only the type @Dim w@ matters.
    columnCount = arity (undefined :: Dim w)
-- | The @i@th row of @m@, counting from zero. Unsafe: @i@ is passed
--   straight to the vector lookup without a bounds check.
row :: Mat v w a -> Int -> w a
row (Mat rows) i = picked
  where
    picked = rows ! i
135 -- | Return the @j@th column of @m@. Unsafe.
136 column :: (Vector v a) => Mat v w a -> Int -> v a
137 column (Mat rows) j =
138 V.map (element j) rows
143 -- | Transpose @m@; switch its columns and its rows. This is a dirty
144 -- implementation; it would be a little cleaner to use imap, but it
145 -- doesn't seem to work.
147 -- TODO: Don't cheat with fromList.
151 -- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
155 transpose :: (Vector w (v a),
160 transpose m = Mat $ V.fromList column_list
162 column_list = [ column m i | i <- [0..(ncols m)-1] ]
165 -- | Is @m@ symmetric?
169 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
173 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
177 symmetric :: (Vector v (w a),
188 -- | Construct a new matrix from a function @lambda@. The function
189 -- @lambda@ should take two parameters i,j corresponding to the
190 -- entries in the matrix. The i,j entry of the resulting matrix will
191 -- have the value returned by lambda i j.
193 -- TODO: Don't cheat with fromList.
197 -- >>> let lambda i j = i + j
198 -- >>> construct lambda :: Mat3 Int
199 -- ((0,1,2),(1,2,3),(2,3,4))
201 construct :: forall v w a.
206 construct lambda = Mat rows
208 -- The arity trick is used in Data.Vector.Fixed.length.
209 imax = (arity (undefined :: Dim v)) - 1
210 jmax = (arity (undefined :: Dim w)) - 1
211 row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
212 rows = V.fromList [ row' i | i <- [0..imax] ]
214 -- | Given a positive-definite matrix @m@, computes the
215 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
216 -- values on the diagonal of @r@ positive.
220 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
222 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
223 -- >>> (transpose (cholesky m1)) * (cholesky m1)
224 -- ((20.000000000000004,-1.0),(-1.0,20.0))
226 cholesky :: forall a v w.
233 cholesky m = construct r
236 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
238 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
242 -- | Returns True if the given matrix is upper-triangular, and False
247 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
248 -- >>> is_upper_triangular m
251 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
252 -- >>> is_upper_triangular m
255 is_upper_triangular :: (Eq a, Ring.C a, Vector w a) => Mat v w a -> Bool
256 is_upper_triangular m =
259 results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
261 test :: Int -> Int -> Bool
264 | otherwise = m !!! (i,j) == 0
267 -- | Returns True if the given matrix is lower-triangular, and False
272 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
273 -- >>> is_lower_triangular m
276 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
277 -- >>> is_lower_triangular m
280 is_lower_triangular :: (Eq a,
287 is_lower_triangular = is_upper_triangular . transpose
290 -- | Returns True if the given matrix is triangular, and False
295 -- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
296 -- >>> is_triangular m
299 -- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
300 -- >>> is_triangular m
303 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
304 -- >>> is_triangular m
307 is_triangular :: (Eq a,
314 is_triangular m = is_upper_triangular m || is_lower_triangular m
317 -- | Return the (i,j)th minor of m.
321 -- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
322 -- >>> minor m 0 0 :: Mat2 Int
324 -- >>> minor m 1 1 :: Mat2 Int
327 minor :: (Dim v ~ S (Dim u),
336 minor (Mat rows) i j = m
338 rows' = delete rows i
339 m = Mat $ V.map ((flip delete) j) rows'
342 determinant :: (Eq a,
352 | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
353 | otherwise = undefined --determinant_recursive m
356 determinant_recursive :: forall v w a r c.
362 determinant_recursive m
363 | (ncols m) == 0 || (nrows m) == 0 = error "don't do that"
364 | (ncols m) == 1 && (nrows m) == 1 = m !!! (0,0) -- Base case
366 sum [ (-1)^(1+(toInteger j)) NP.* (m' 1 j) NP.* (det_minor 1 j)
367 | j <- [0..(ncols m)-1] ]
371 det_minor :: Int -> Int -> a
372 det_minor i j = determinant (minor m i j)
375 -- | Matrix multiplication. Our 'Num' instance doesn't define one, and
376 -- we need additional restrictions on the result type anyway.
380 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat D2 D3 Int
381 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat D3 D2 Int
394 (*) m1 m2 = construct lambda
397 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
404 => Additive.C (Mat v w a) where
406 (Mat rows1) + (Mat rows2) =
407 Mat $ V.zipWith (V.zipWith (+)) rows1 rows2
409 (Mat rows1) - (Mat rows2) =
410 Mat $ V.zipWith (V.zipWith (-)) rows1 rows2
412 zero = Mat (V.replicate $ V.replicate (fromInteger 0))
419 => Ring.C (Mat v w a) where
420 -- The first * is ring multiplication, the second is matrix
428 => Module.C a (Mat v w a) where
429 -- We can multiply a matrix by a scalar of the same type as its
431 x *> (Mat rows) = Mat $ V.map (V.map (NP.* x)) rows
434 instance (Algebraic.C a,
440 => Normed (Mat v w a) where
441 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
442 -- all matrices as big vectors.
446 -- >>> let v1 = vec2d (3,4)
452 norm_p p (Mat rows) =
453 (root p') $ sum [(fromRational' $ toRational x)^p' | x <- xs]
456 xs = concat $ V.toList $ V.map V.toList rows
458 -- | The infinity norm. We don't use V.maximum here because it
459 -- relies on a type constraint that the vector be non-empty and I
460 -- don't know how to pattern match it away.
464 -- >>> let v1 = vec3d (1,5,2)
468 norm_infty m@(Mat rows)
469 | nrows m == 0 || ncols m == 0 = 0
471 fromRational' $ toRational $
472 P.maximum $ V.toList $ V.map (P.maximum . V.toList) rows
478 -- Vector helpers. We want it to be easy to create low-dimension
479 -- column vectors, which are nx1 matrices.
481 -- | Convenient constructor for 2D vectors.
485 -- >>> import Roots.Simple
486 -- >>> let h = 0.5 :: Double
487 -- >>> let g1 (Mat (D2 (D1 x) (D1 y))) = 1.0 + h NP.* exp(-(x^2))/(1.0 + y^2)
488 -- >>> let g2 (Mat (D2 (D1 x) (D1 y))) = 0.5 + h NP.* atan(x^2 + y^2)
489 -- >>> let g u = vec2d ((g1 u), (g2 u))
490 -- >>> let u0 = vec2d (1.0, 1.0)
491 -- >>> let eps = 1/(10^9)
492 -- >>> fixed_point g eps u0
493 -- ((1.0728549599342185),(1.0820591495686167))
vec2d :: (a, a) -> Mat D2 D1 a
vec2d (x, y) = Mat (D2 (D1 x) (D1 y))

-- vec2d is deliberately defined first in this group: the Haddock
-- comment (with its doctest) immediately preceding the group describes
-- 2D vectors, and Haddock attaches it to the next declaration.

-- | Build a 1-dimensional column vector, i.e. a 1x1 matrix.
vec1d :: a -> Mat D1 D1 a
vec1d x = Mat (D1 (D1 x))

-- | Build a 3-dimensional column vector (a 3x1 matrix) from a triple.
vec3d :: (a, a, a) -> Mat D3 D1 a
vec3d (x, y, z) = Mat (D3 (D1 x) (D1 y) (D1 z))

-- | Build a 4-dimensional column vector (a 4x1 matrix) from a 4-tuple.
vec4d :: (a, a, a, a) -> Mat D4 D1 a
vec4d (w, x, y, z) = Mat (D4 (D1 w) (D1 x) (D1 y) (D1 z))
-- | Wrap a single value as a 1x1 matrix. Multiplication on 'Mat' has
--   been taken over by matrix multiplication, so a plain value has to be
--   lifted this way before it can take part in a product.
scalar :: a -> Mat D1 D1 a
scalar = Mat . D1 . D1
512 dot :: (RealRing.C a,
522 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
525 -- | The angle between @v1@ and @v2@ in Euclidean space.
529 -- >>> let v1 = vec2d (1.0, 0.0)
530 -- >>> let v2 = vec2d (0.0, 1.0)
531 -- >>> angle v1 v2 == pi/2.0
534 angle :: (Transcendental.C a,
551 theta = (recip norms) NP.* (v1 `dot` v2)
552 norms = (norm v1) NP.* (norm v2)