1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE TypeFamilies #-}
11 import Data.Vector.Fixed (
17 import qualified Data.Vector.Fixed as V (
23 import Data.Vector.Fixed.Internal (arity)
-- | A matrix stored as an outer fixed-size vector (@Vn r@) of rows,
-- where every row is an inner fixed-size vector (@Vn c@) of entries
-- of type @e@.
type Mat r c e = Vn r (Vn c e)
-- | Square matrix using 'Vec2D' for both rows and columns.
type Mat2 e = Mat Vec2D Vec2D e
-- | Square matrix using 'Vec3D' for both rows and columns.
type Mat3 e = Mat Vec3D Vec3D e
-- | Square matrix using 'Vec4D' for both rows and columns.
type Mat4 e = Mat Vec4D Vec4D e
-- | Flatten a matrix into nested lists: the outer list has one element
-- per row of the matrix, and each inner list holds that row's entries.
toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
toList = map V.toList . V.toList
-- | Build a matrix from nested lists: one inner list per row, holding
-- that row's entries. Both the outer and the inner lists are converted
-- with 'V.fromList', so lists whose length does not match the matrix
-- dimensions behave however that function does (NOTE(review):
-- fixed-vector documents an error for too-short lists; confirm for the
-- version in use).
fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
fromList = V.fromList . map V.fromList
(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
-- Entry at row i, column j: pick the row with 'row', then index into it
-- with '!'. No bounds checks are added here, and 'row' is itself
-- documented as unsafe.
m !!! (i, j) = row m i ! j
44 (!!?) :: (Vector v (Vn w a), Vector w a) => Mat v w a
48 | i < 0 || j < 0 = Nothing
49 | i > V.length m = Nothing
50 | otherwise = if j > V.length (row m j)
52 else Just $ (row m j) ! j
55 -- | The number of rows in the matrix.
56 nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
-- | The number of columns in the matrix. Every row of a @Mat v w a@ has
-- type @Vn w a@, so the count comes from the type parameter @w@ alone
-- and the matrix argument is never inspected. The arity trick is
-- borrowed from Data.Vector.Fixed.length.
ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
ncols = const columnCount
  where
    -- 'undefined' is only used as a type witness and never evaluated
    -- by the call below.
    columnCount = arity (undefined :: Dim w)
64 -- | Return the @i@th row of @m@. Unsafe.
65 row :: (Vector v (Vn w a), Vector w a) => Mat v w a
71 -- | Return the @j@th column of @m@. Unsafe.
72 column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a
-- | Transpose @m@; switch its columns and its rows. This is a dirty
-- implementation; it would be a little cleaner to use imap, but it
-- doesn't seem to work.
85 -- TODO: Don't cheat with fromList.
89 -- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
93 transpose :: (Vector v (Vn w a),
99 transpose m = V.fromList column_list
101 column_list = [ column m i | i <- [0..(ncols m)-1] ]
103 -- | Is @m@ symmetric?
107 -- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
111 -- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
115 symmetric :: (Vector v (Vn w a),
-- | Construct a new matrix from a function @lambda@. The function
-- @lambda@ takes two zero-based indices i,j: the row and the column of
-- an entry in the matrix. The i,j entry of the resulting matrix will
-- have the value returned by lambda i j.
131 -- TODO: Don't cheat with fromList.
135 -- >>> let lambda i j = i + j
136 -- >>> construct lambda :: Mat3 Int
137 -- ((0,1,2),(1,2,3),(2,3,4))
139 construct :: forall v w a.
144 construct lambda = rows
146 -- The arity trick is used in Data.Vector.Fixed.length.
147 imax = (arity (undefined :: Dim v)) - 1
148 jmax = (arity (undefined :: Dim w)) - 1
149 row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
150 rows = V.fromList [ row' i | i <- [0..imax] ]
152 -- | Given a positive-definite matrix @m@, computes the
153 -- upper-triangular matrix @r@ with (transpose r)*r == m and all
154 -- values on the diagonal of @r@ positive.
158 -- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
160 -- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
161 -- >>> (transpose (cholesky m1)) `mult` (cholesky m1)
162 -- ((20.000000000000004,-1.0),(-1.0,20.0))
164 cholesky :: forall a v w.
170 cholesky m = construct r
173 r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)**2 | k <- [0..i-1]])
175 (((m !!! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i)
178 -- | Matrix multiplication. Our 'Num' instance doesn't define one, and
179 -- we need additional restrictions on the result type anyway.
183 -- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
184 -- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
197 mult m1 m2 = construct lambda
200 sum [(m1 !!! (i,k)) * (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]