1 {-# LANGUAGE ExistentialQuantification #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE ScopedTypeVariables #-}
6 {-# LANGUAGE TypeFamilies #-}
7 {-# LANGUAGE RebindableSyntax #-}
12 import Data.List (intercalate)
14 import Data.Vector.Fixed (
18 import qualified Data.Vector.Fixed as V (
32 import Data.Vector.Fixed.Internal (Arity, arity, S, Dim)
36 import NumericPrelude hiding ((*), abs)
37 import qualified NumericPrelude as NP ((*))
38 import qualified Algebra.Algebraic as Algebraic
39 import Algebra.Algebraic (root)
40 import qualified Algebra.Absolute as Absolute
41 import qualified Algebra.Additive as Additive
42 import qualified Algebra.Ring as Ring
43 import Algebra.Absolute (abs)
44 import qualified Algebra.Field as Field
45 import qualified Algebra.Module as Module
46 import qualified Algebra.RealField as RealField
47 import qualified Algebra.RealRing as RealRing
48 import qualified Algebra.ToRational as ToRational
49 import qualified Algebra.Transcendental as Transcendental
50 import qualified Prelude as P
-- | A matrix stored as a @v@-vector of rows, where every row is a
-- @w@-vector of entries of type @a@; the dimensions are fixed by the
-- types @v@ and @w@.
--
-- The constructor carries its own 'Vector' context, so pattern
-- matching on 'Mat' brings the @Vector v (w a)@ and @Vector w a@
-- instances into scope. Functions that destructure a matrix therefore
-- do not have to repeat those constraints in their signatures.
data Mat v w a = (Vector v (w a), Vector w a) => Mat (v (w a))

-- | Square matrices of side 1 through 4.
type Mat1 = Mat D1 D1
type Mat2 = Mat D2 D2
type Mat3 = Mat D3 D3
type Mat4 = Mat D4 D4
58 -- We can't just declare that all instances of Vector are instances of
59 -- Eq unfortunately. We wind up with an overlapping instance for
61 instance (Eq a, Vector v Bool, Vector w Bool) => Eq (Mat v w a) where
62 -- | Compare a row at a time.
66 -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
67 -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
68 -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
74 (Mat rows1) == (Mat rows2) =
75 V.and $ V.zipWith comp rows1 rows2
77 -- Compare a row, one column at a time.
78 comp row1 row2 = V.and (V.zipWith (==) row1 row2)
81 instance (Show a, Vector v String, Vector w String) => Show (Mat v w a) where
82 -- | Display matrices and vectors as ordinary tuples. This is poor
83 -- practice, but these results are primarily displayed
84 -- interactively and convenience trumps correctness (said the guy
85 -- who insists his vector lengths be statically checked at
90 -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
95 "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
97 row_strings = V.map show_vector rows
99 "(" ++ (intercalate "," element_strings) ++ ")"
102 element_strings = P.map show v1l
-- | Flatten a matrix into a list of its rows, each row being the list
-- of its entries in column order.
--
-- >>> toList (fromList [[1,2],[3,4]] :: Mat2 Int)
-- [[1,2],[3,4]]
toList :: Mat v w a -> [[a]]
toList (Mat rows) = [V.toList r | r <- V.toList rows]
-- | Build a matrix from a list of rows, each row given as a list of
-- entries.
--
-- Both levels go through 'V.fromList', so the lists must have the
-- lengths demanded by @v@ and @w@. NOTE(review): a length mismatch is
-- presumably reported at runtime by fixed-vector's 'V.fromList';
-- confirm against the fixed-vector version in use.
fromList :: (Vector v (w a), Vector w a, Vector v a) => [[a]] -> Mat v w a
fromList vs = Mat rows
  where
    rows = V.fromList [V.fromList r | r <- vs]
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@, column @j@
-- (both zero-based). No bounds check is performed here; the indices are
-- handed straight to '!'. See @(!!?)@ for the checked variant.
(!!!) :: (Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) = case m of
  Mat rows -> (rows ! i) ! j
-- | Safe indexing: @m !!? (i, j)@ is @Just@ the entry in row @i@,
-- column @j@ (both zero-based), or 'Nothing' when either index lies
-- outside the matrix.
--
-- The 'Vector' instances needed by 'V.length' and '!' come from
-- pattern matching on the 'Mat' constructor, which is why the
-- signature carries no constraints.
(!!?) :: Mat v w a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices are 0 .. length-1, hence >= rather than >.
  | i >= V.length rows = Nothing
  | j >= V.length r = Nothing
  | otherwise = Just (r ! j)
  where
    -- The row selected by @i@ (not @j@). It is only demanded after the
    -- row index has been validated above.
    r = row m i
-- | How many rows a matrix has, i.e. the length of its outer vector.
nrows :: Mat v w a -> Int
nrows m = case m of
  Mat rows -> V.length rows
-- | How many columns a matrix has.
--
-- The count is read off the type-level dimension of the row type @w@
-- with 'arity' (the same trick "Data.Vector.Fixed" uses for its
-- @length@). The matrix argument itself is never inspected, so no row
-- has to exist for this to produce an answer.
ncols :: forall v w a. (Vector w a) => Mat v w a -> Int
ncols _ = arity rowDim
  where
    -- A proxy value used only for its type.
    rowDim = undefined :: Dim w
-- | The row at (zero-based) index @i@ of a matrix.
--
-- Unsafe: the index is passed straight to '!' without a bounds check.
row :: Mat v w a -> Int -> w a
row m i = case m of
  Mat rows -> rows ! i
143 -- | Return the @j@th column of @m@. Unsafe.
144 column :: (Vector v a) => Mat v w a -> Int -> v a
145 column (Mat rows) j =
146 V.map (element j) rows
-- | Transpose @m@; swap its columns and its rows. This is a dirty
-- implementation; it would be a little cleaner to use imap, but it
-- doesn't seem to work.
155 -- TODO: Don't cheat with fromList.
159 -- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
163 transpose :: (Vector w (v a),
168 transpose m = Mat $ V.fromList column_list
170 column_list = [ column m i | i <- [0..(ncols m)-1] ]
173 -- | Is @m@ symmetric?
177 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
181 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
185 symmetric :: (Vector v (w a),
196 -- | Construct a new matrix from a function @lambda@. The function
197 -- @lambda@ should take two parameters i,j corresponding to the
198 -- entries in the matrix. The i,j entry of the resulting matrix will
199 -- have the value returned by lambda i j.
201 -- TODO: Don't cheat with fromList.
205 -- >>> let lambda i j = i + j
206 -- >>> construct lambda :: Mat3 Int
207 -- ((0,1,2),(1,2,3),(2,3,4))
209 construct :: forall v w a.
214 construct lambda = Mat rows
216 -- The arity trick is used in Data.Vector.Fixed.length.
217 imax = (arity (undefined :: Dim v)) - 1
218 jmax = (arity (undefined :: Dim w)) - 1
219 row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
220 rows = V.fromList [ row' i | i <- [0..imax] ]
222 -- | Given a positive-definite matrix @m@, computes the
223 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
224 -- values on the diagonal of @r@ positive.
228 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
230 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
231 -- >>> (transpose (cholesky m1)) * (cholesky m1)
232 -- ((20.000000000000004,-1.0),(-1.0,20.0))
234 cholesky :: forall a v w.
241 cholesky m = construct r
244 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
246 (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
250 -- | Matrix multiplication. Our 'Num' instance doesn't define one, and
251 -- we need additional restrictions on the result type anyway.
255 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat D2 D3 Int
256 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat D3 D2 Int
269 (*) m1 m2 = construct lambda
272 sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]
279 => Additive.C (Mat v w a) where
281 (Mat rows1) + (Mat rows2) =
282 Mat $ V.zipWith (V.zipWith (+)) rows1 rows2
284 (Mat rows1) - (Mat rows2) =
285 Mat $ V.zipWith (V.zipWith (-)) rows1 rows2
287 zero = Mat (V.replicate $ V.replicate (fromInteger 0))
294 => Ring.C (Mat v w a) where
295 -- The first * is ring multiplication, the second is matrix
303 => Module.C a (Mat v w a) where
304 -- We can multiply a matrix by a scalar of the same type as its
306 x *> (Mat rows) = Mat $ V.map (V.map (NP.* x)) rows
309 instance (Algebraic.C a,
315 => Normed (Mat v w a) where
316 -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
317 -- all matrices as big vectors.
321 -- >>> let v1 = vec2d (3,4)
327 norm_p p (Mat rows) =
328 (root p') $ sum [(fromRational' $ toRational x)^p' | x <- xs]
331 xs = concat $ V.toList $ V.map V.toList rows
333 -- | The infinity norm. We don't use V.maximum here because it
334 -- relies on a type constraint that the vector be non-empty and I
335 -- don't know how to pattern match it away.
339 -- >>> let v1 = vec3d (1,5,2)
343 norm_infty m@(Mat rows)
344 | nrows m == 0 || ncols m == 0 = 0
346 fromRational' $ toRational $
347 P.maximum $ V.toList $ V.map (P.maximum . V.toList) rows
353 -- Vector helpers. We want it to be easy to create low-dimension
354 -- column vectors, which are nx1 matrices.
356 -- | Convenient constructor for 2D vectors.
360 -- >>> import Roots.Simple
361 -- >>> let h = 0.5 :: Double
362 -- >>> let g1 (Mat (D2 (D1 x) (D1 y))) = 1.0 + h*exp(-(x^2))/(1.0 + y^2)
363 -- >>> let g2 (Mat (D2 (D1 x) (D1 y))) = 0.5 + h*atan(x^2 + y^2)
364 -- >>> let g u = vec2d ((g1 u), (g2 u))
365 -- >>> let u0 = vec2d (1.0, 1.0)
366 -- >>> let eps = 1/(10^9)
367 -- >>> fixed_point g eps u0
368 -- (1.0728549599342185,1.0820591495686167)
vec2d :: (a, a) -> Mat D2 D1 a
vec2d (x, y) = Mat entries
  where
    -- Each component becomes a one-element row of the 2x1 matrix.
    entries = D2 (D1 x) (D1 y)
-- | Convenient constructor for 3D column vectors (3x1 matrices).
vec3d :: (a, a, a) -> Mat D3 D1 a
vec3d (x, y, z) = Mat entries
  where
    -- Each component becomes a one-element row.
    entries = D3 (D1 x) (D1 y) (D1 z)
-- | Convenient constructor for 4D column vectors (4x1 matrices).
vec4d :: (a, a, a, a) -> Mat D4 D1 a
vec4d (w, x, y, z) = Mat entries
  where
    -- Each component becomes a one-element row.
    entries = D4 (D1 w) (D1 x) (D1 y) (D1 z)
-- | Wrap a single value as a 1x1 matrix.
--
-- In this module @*@ means matrix multiplication. A plain value
-- therefore has to be lifted into a 1x1 matrix before it can take part
-- in a product.
scalar :: a -> Mat D1 D1 a
scalar = Mat . D1 . D1
384 dot :: (RealRing.C a,
394 v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)
397 -- | The angle between @v1@ and @v2@ in Euclidean space.
401 -- >>> let v1 = vec2d (1.0, 0.0)
402 -- >>> let v2 = vec2d (0.0, 1.0)
403 -- >>> angle v1 v2 == pi/2.0
406 angle :: (Transcendental.C a,
423 theta = (recip norms) NP.* (v1 `dot` v2)
424 norms = (norm v1) NP.* (norm v2)