]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Linear/Matrix.hs
Move the Vector and Matrix modules under Linear.
[numerical-analysis.git] / src / Linear / Matrix.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE TypeFamilies #-}
6
7 module Linear.Matrix
8 where
9
10 import Data.Vector.Fixed (
11 Dim,
12 Vector
13 )
14 import qualified Data.Vector.Fixed as V (
15 fromList,
16 length,
17 map,
18 toList
19 )
20 import Data.Vector.Fixed.Internal (arity)
21
22 import Linear.Vector
23
-- | A matrix stored as a fixed-length vector of rows: the outer
--   dimension @rows@ indexes rows and each row is a fixed-length
--   vector of @cols@ entries of type @e@.
type Mat rows cols e = Vn rows (Vn cols e)

-- | Square matrices whose row and column types are 'Vec2D', 'Vec3D'
--   and 'Vec4D' respectively.
type Mat2 e = Mat Vec2D Vec2D e
type Mat3 e = Mat Vec3D Vec3D e
type Mat4 e = Mat Vec4D Vec4D e
28
-- | Flatten a matrix into a list of its rows, where each row is itself
--   a list of that row's entries (outer list = rows).
toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
toList = map V.toList . V.toList
32
-- | Build a matrix from a list of rows; each inner list becomes one
--   row. Length handling is whatever 'V.fromList' does for the
--   fixed-size vector types involved.
fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
fromList = V.fromList . map V.fromList
36
37
-- | Unsafe indexing: @m !!! (i, j)@ is the entry in row @i@, column
--   @j@ (zero-based). No bounds check is done here beyond whatever
--   '!' itself does; use '!!?' for a checked lookup.
(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) = row m i ! j
41
-- | Safe indexing: @m !!? (i, j)@ is @Just@ the entry in row @i@,
--   column @j@ (zero-based), or 'Nothing' when either index is out of
--   range. Valid indices satisfy @0 <= i < nrows m@ and
--   @0 <= j < ncols m@.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
--
(!!?) :: (Vector v (Vn w a), Vector w a)
      => Mat v w a
      -> (Int, Int)
      -> Maybe a
(!!?) m (i, j)
  -- Upper bounds are exclusive: an index equal to the dimension is
  -- already out of range.
  | i < 0 || i >= nrows m = Nothing
  | j < 0 || j >= ncols m = Nothing
  -- Both indices are now known to be in range, so the unchecked
  -- lookup is safe.
  | otherwise = Just (m !!! (i, j))
52
53
-- | The number of rows in the matrix, i.e. the length of the outer
--   vector of rows.
nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
nrows m = V.length m
57
-- | The number of columns in the matrix. Every row has the fixed
--   dimension @w@, so the count is read off the type via 'arity' (the
--   same trick Data.Vector.Fixed.length uses). The matrix argument is
--   never inspected.
ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
ncols = const (arity (undefined :: Dim w))
62
-- | The @i@th row of @m@ (zero-based). Unsafe: no bounds checking
--   beyond whatever '!' does.
row :: (Vector v (Vn w a), Vector w a) => Mat v w a
    -> Int
    -> Vn w a
row = (!)
68
69
-- | The @j@th column of @m@ (zero-based), obtained by taking entry @j@
--   from every row. Unsafe: no bounds checking beyond whatever '!'
--   does.
column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a
       -> Int
       -> Vn v a
column m j = V.map (! j) m
78
79
-- | Transpose @m@, swapping its rows and columns: row @i@ of the
--   result is column @i@ of @m@. Building the result with imap would
--   be cleaner, but that did not seem to work, so the columns are
--   collected into a list instead.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Vector v (Vn w a),
              Vector w (Vn v a),
              Vector v a,
              Vector w a)
          => Mat v w a
          -> Mat w v a
transpose m = V.fromList (map (column m) [0 .. ncols m - 1])
101
-- | Test whether @m@ is symmetric, i.e. equal to its own transpose.
--   The @v ~ w@ constraint makes the row and column dimensions the
--   same type, which is what lets @m@ be compared with its transpose.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Vector v (Vn w a),
              Vector w a,
              v ~ w,
              Vector w Bool,
              Eq a)
          => Mat v w a
          -> Bool
symmetric m = transpose m == m
123
124
-- | Build a matrix from a function of the (zero-based) row and column
--   indices: the entry at row @i@, column @j@ of the result is
--   @lambda i j@. The dimensions come from the result type.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall v w a.
             (Vector v (Vn w a),
              Vector w a)
          => (Int -> Int -> a)
          -> Mat v w a
construct lambda =
  V.fromList [ V.fromList [ lambda i j | j <- [0 .. lastCol] ]
             | i <- [0 .. lastRow] ]
  where
    -- Read the dimensions off the types, as Data.Vector.Fixed.length
    -- does.
    lastRow = arity (undefined :: Dim v) - 1
    lastCol = arity (undefined :: Dim w) - 1
150
-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   The entries satisfy the recurrences
--
--   > r(i,i) = sqrt (m(i,i) - sum_{k<i} r(k,i)^2)
--   > r(i,j) = (m(i,j) - sum_{k<i} r(k,i)*r(k,j)) / r(i,i)    for i < j
--   > r(i,j) = 0                                             for i > j
--
--   Each entry is computed once and cached in a lazily built table.
--   Evaluating the recurrences by plain recursion would recompute
--   shared entries and take time exponential in the dimension.
--
--   No check is made that @m@ is actually positive-definite; if it is
--   not, a diagonal entry can be the square root of a negative number
--   (NaN for 'Double').
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) `mult` (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall a v w.
            (RealFloat a,
             Vector v (Vn w a),
             Vector w a)
         => (Mat v w a)
         -> (Mat v w a)
cholesky m = construct r
  where
    -- Memo table: table !! i !! j is entry (i, j) of r. It is built
    -- lazily, so each entry is evaluated at most once, and only when
    -- demanded.
    table :: [[a]]
    table = [ [ entry i j | j <- [0 .. ncols m - 1] ]
            | i <- [0 .. nrows m - 1] ]

    -- Cached lookup of entry (i, j) of r.
    r :: Int -> Int -> a
    r i j = table !! i !! j

    -- The recurrence itself; earlier entries are read through the
    -- cache 'r'. The arithmetic matches the direct formulation
    -- term for term, so results are unchanged.
    entry :: Int -> Int -> a
    entry i j
      | i == j = sqrt (m !!! (i, j) - sum [ (r k i) ** 2 | k <- [0 .. i - 1] ])
      | i < j =
          (m !!! (i, j) - sum [ (r k i) * (r k j) | k <- [0 .. i - 1] ]) / r i i
      | otherwise = 0
176
-- | Matrix multiplication: entry (i, j) of the product is the dot
--   product of row @i@ of @m1@ with column @j@ of @m2@. This is a
--   separate function because our 'Num' instance does not define
--   matrix multiplication, and the result type needs extra
--   constraints anyway.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
-- >>> m1 `mult` m2
-- ((22,28),(49,64))
--
mult :: (Num a,
         Vector v (Vn w a),
         Vector w a,
         Vector w (Vn z a),
         Vector z a,
         Vector v (Vn z a))
     => Mat v w a
     -> Mat w z a
     -> Mat v z a
mult m1 m2 = construct entry
  where
    -- Length of each dot product: the column count of m1, which the
    -- types force to equal the row count of m2.
    inner = ncols m1
    entry i j =
      sum (map (\k -> (m1 !!! (i, k)) * (m2 !!! (k, j))) [0 .. inner - 1])