]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Matrix.hs
Bump to fixed-vector-0.2.*.
[numerical-analysis.git] / src / Matrix.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE TypeFamilies #-}
6
7 module Matrix
8 where
9
10 import Vector
11 import Data.Vector.Fixed (
12 Dim,
13 Vector
14 )
15 import qualified Data.Vector.Fixed as V (
16 fromList,
17 length,
18 map,
19 toList
20 )
21 import Data.Vector.Fixed.Internal (arity)
22
-- | A matrix stored row by row: an outer vector of @Dim r@ rows, each of
--   which is a vector of @Dim c@ entries of type @e@.
type Mat r c e = Vn r (Vn c e)

-- | Square matrices whose rows and columns are both 'Vec2D'.
type Mat2 e = Mat Vec2D Vec2D e

-- | Square matrices whose rows and columns are both 'Vec3D'.
type Mat3 e = Mat Vec3D Vec3D e

-- | Square matrices whose rows and columns are both 'Vec4D'.
type Mat4 e = Mat Vec4D Vec4D e
27
-- | Flatten a matrix into a list of its rows, each row itself being a
--   list of entries.
toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
toList = map V.toList . V.toList
31
-- | Build a matrix from a list of rows. No length checking is done
--   here; each row and the list of rows are handed directly to
--   'V.fromList'.
fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
fromList = V.fromList . map V.fromList
35
36
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@ and column
--   @j@ (both zero-based). Neither index is range-checked here.
(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
40
-- | Safe indexing: @m !!? (i, j)@ is @Just@ the entry in row @i@ and
--   column @j@ (both zero-based), or 'Nothing' when either index lies
--   outside the matrix.
(!!?) :: (Vector v (Vn w a), Vector w a)
      => Mat v w a
      -> (Int, Int)
      -> Maybe a
(!!?) m (i, j)
  | i < 0 || j < 0          = Nothing
  -- Valid row indices are 0 .. length - 1, so i == length is out of range.
  | i >= V.length m         = Nothing
  -- Row i is known to exist at this point; check j against *its* length.
  | j >= V.length (row m i) = Nothing
  | otherwise               = Just $ (row m i) ! j
51
52
-- | How many rows @m@ has, i.e. the length of its outer vector.
nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
nrows m = V.length m
56
-- | The number of columns in the matrix. Every row has type @Vn w a@,
--   so the count is read off the type-level dimension @Dim w@ and the
--   argument itself is never inspected. This is the same trick that
--   Data.Vector.Fixed.length uses.
ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
ncols = const (arity (undefined :: Dim w))
61
-- | Row @i@ (zero-based) of @m@. Unsafe: @i@ is not range-checked here;
--   this simply indexes the outer vector with '!'.
row :: (Vector v (Vn w a), Vector w a)
    => Mat v w a
    -> Int
    -> Vn w a
row = (!)
67
68
-- | Column @j@ (zero-based) of @m@: entry @j@ of every row, in row
--   order. Unsafe: @j@ is not range-checked here.
column :: (Vector v a, Vector v (Vn w a), Vector w a)
       => Mat v w a
       -> Int
       -> Vn v a
column m j = V.map (! j) m
77
78
-- | Transpose @m@, switching its columns and its rows: row @i@ of the
-- result is column @i@ of @m@. This is a dirty implementation; it
-- would be a little cleaner to use imap, but that doesn't seem to
-- work.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Vector v (Vn w a),
              Vector w (Vn v a),
              Vector v a,
              Vector w a)
          => Mat v w a
          -> Mat w v a
transpose m =
  V.fromList (map (column m) [0 .. ncols m - 1])
100
-- | Is @m@ symmetric, i.e. equal to its own transpose? The @v ~ w@
-- constraint restricts this to square matrices.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Vector v (Vn w a),
              Vector w a,
              v ~ w,
              Vector w Bool,
              Eq a)
          => Mat v w a
          -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
122
123
-- | Build a matrix entry-by-entry from a function @lambda@: the entry in
-- row @i@, column @j@ (both zero-based) of the result is @lambda i j@.
-- The matrix dimensions come entirely from the result type.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall v w a.
             (Vector v (Vn w a),
              Vector w a)
          => (Int -> Int -> a)
          -> Mat v w a
construct lambda =
  V.fromList [ V.fromList [ lambda i j | j <- colIndices ]
             | i <- rowIndices ]
  where
    -- Dim v and Dim w are type-level sizes; 'arity' turns them into Ints
    -- without needing an actual vector (the same trick used in
    -- Data.Vector.Fixed.length).
    rowIndices = [0 .. arity (undefined :: Dim v) - 1]
    colIndices = [0 .. arity (undefined :: Dim w) - 1]
149
-- | Given a symmetric positive-definite matrix @m@, computes the
-- upper-triangular matrix @r@ with (transpose r)*r == m and all
-- values on the diagonal of @r@ positive.
--
-- Only the entries of @m@ on or above the diagonal are read. No check
-- is made that @m@ is actually positive-definite; for floating-point
-- types a bad input will typically show up as NaN entries from 'sqrt'.
--
-- Each entry of @r@ is computed at most once (see @table@ below), so
-- the recursive definition does not re-derive earlier entries over
-- and over.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) `mult` (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall a v w.
            (RealFloat a,
             Vector v (Vn w a),
             Vector w a)
         => (Mat v w a)
         -> (Mat v w a)
cholesky m = construct r
  where
    -- Lazily-built table holding every entry of the result. Forcing an
    -- entry evaluates 'entry' for it exactly once; later lookups reuse
    -- the value instead of recursing again.
    table :: [[a]]
    table = [ [ entry i j | j <- [0 .. ncols m - 1] ]
            | i <- [0 .. nrows m - 1] ]

    -- The (i,j) entry of the result. The indices are always in range
    -- for a square matrix, since they come from 'construct' or from
    -- smaller indices in 'entry'.
    r :: Int -> Int -> a
    r i j = (table !! i) !! j

    -- The Cholesky recurrence for a single entry, in terms of entries
    -- in earlier rows (k < i) via 'r'.
    entry :: Int -> Int -> a
    entry i j
      | i == j    = sqrt (m !!! (i,j) - sum [ (r k i)**2 | k <- [0..i-1] ])
      | i < j     =
          ((m !!! (i,j)) - sum [ (r k i)*(r k j) | k <- [0..i-1] ]) / (r i i)
      | otherwise = 0
175
-- | Matrix multiplication. Our 'Num' instance doesn't define one, and
-- we need additional restrictions on the result type anyway. Entry
-- (i,j) of the product is the sum over k of m1(i,k) * m2(k,j).
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
-- >>> m1 `mult` m2
-- ((22,28),(49,64))
--
mult :: (Num a,
         Vector v (Vn w a),
         Vector w a,
         Vector w (Vn z a),
         Vector z a,
         Vector v (Vn z a))
     => Mat v w a
     -> Mat w z a
     -> Mat v z a
mult m1 m2 = construct entry
  where
    -- The shared dimension: columns of m1, which equals rows of m2.
    inner = [0 .. ncols m1 - 1]
    entry i j = sum (map (\k -> (m1 !!! (i,k)) * (m2 !!! (k,j))) inner)