]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Matrix.hs
Remove non-fixed Matrix module.
[numerical-analysis.git] / src / Matrix.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE TypeFamilies #-}
6
7 module Matrix
8 where
9
10 import Vector
11 import Data.Vector.Fixed (
12 Arity(..),
13 Dim,
14 Vector,
15 (!),
16 )
17 import qualified Data.Vector.Fixed as V (
18 fromList,
19 length,
20 map,
21 toList
22 )
23 import Data.Vector.Fixed.Internal (arity)
24
-- | A matrix represented as a fixed-length vector of rows.
--
--   * @rowv@ is the vector type holding the rows; its length is the
--     number of rows.
--
--   * @colv@ is the vector type of each row; its length is the number
--     of columns.
--
--   * @e@ is the element type.
type Mat rowv colv e = Vn rowv (Vn colv e)

-- | A square matrix whose rows and columns are 'Vec2D' vectors (2x2).
type Mat2 e = Mat Vec2D Vec2D e

-- | A square matrix whose rows and columns are 'Vec3D' vectors (3x3).
type Mat3 e = Mat Vec3D Vec3D e

-- | A square matrix whose rows and columns are 'Vec4D' vectors (4x4).
type Mat4 e = Mat Vec4D Vec4D e
29
-- | Flatten a matrix into a list of rows. Each row is itself turned
--   into a list of its entries, so the result is in row-major order.
--
-- Examples:
--
-- >>> toList (fromList [[1,2], [3,4]] :: Mat2 Int)
-- [[1,2],[3,4]]
--
toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
toList = map V.toList . V.toList
33
-- | Build a matrix from a list of rows. The outer list supplies the
--   rows and each inner list supplies the entries of one row.
--
--   Both lengths should agree with the dimensions of the result type.
--   What happens on a mismatch is decided by 'V.fromList' (presumably
--   an error; see the fixed-vector documentation).
fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
fromList = V.fromList . map V.fromList
37
38
-- | Unsafe indexing: fetch the entry in row @i@, column @j@ (both
--   zero-based). The indices are not validated before indexing; see
--   '!!?' for a variant that returns a 'Maybe'.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> m !!! (1,0)
-- 3
--
(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
m !!! (i, j) = (m ! i) ! j
42
-- | Safe indexing: fetch the entry in row @i@, column @j@ (both
--   zero-based). Returns 'Nothing' instead of failing when either
--   index is out of range.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
-- >>> m !!? (-1,0)
-- Nothing
--
(!!?) :: (Vector v (Vn w a), Vector w a)
      => Mat v w a
      -> (Int, Int)
      -> Maybe a
(!!?) m (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, so an index equal to the
  -- length is already out of range (hence >=, not >).
  | i >= V.length m = Nothing
  | j >= V.length r = Nothing
  | otherwise = Just (r ! j)
  where
    -- Row @i@ (not @j@) holds the entry. It is only demanded once the
    -- guard above has established that @i@ is in range.
    r = row m i
53
54
-- | How many rows @m@ has, i.e. the length of its outer vector.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
-- >>> nrows m
-- 2
--
nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
nrows m = V.length m
58
-- | How many columns @m@ has.
--
--   The row type @w@ has a fixed length, so the answer is read off the
--   type's arity alone and the argument is never evaluated. This is the
--   same trick 'Data.Vector.Fixed.length' uses.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
-- >>> ncols m
-- 3
--
ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
ncols = const (arity (undefined :: Dim w))
63
-- | Extract row @i@ (zero-based) of @m@. Unsafe: @i@ is not checked
--   against the number of rows.
row :: (Vector v (Vn w a), Vector w a)
    => Mat v w a
    -> Int
    -> Vn w a
row = (!)
69
70
-- | Extract column @j@ (zero-based) of @m@ by picking entry @j@ out of
--   every row. Unsafe: @j@ is not checked against the number of
--   columns.
column :: (Vector v a, Vector v (Vn w a), Vector w a)
       => Mat v w a
       -> Int
       -> Vn v a
column m j = V.map (! j) m
79
80
-- | Transpose @m@, swapping its rows and its columns: row @i@ of the
--   result is column @i@ of @m@.
--
--   The result is assembled with 'V.fromList' from the list of
--   columns. Using @imap@ would be cleaner, but that did not seem to
--   work.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: ( Vector v (Vn w a)
             , Vector w (Vn v a)
             , Vector v a
             , Vector w a
             )
          => Mat v w a
          -> Mat w v a
transpose m = V.fromList (map (column m) [0 .. ncols m - 1])
102
-- | Determine whether @m@ equals its own transpose.
--
--   The @v ~ w@ constraint restricts this to matrices whose row and
--   column types coincide. That is what makes @m@ and its transpose
--   the same type, so they can be compared.
--
--   NOTE(review): the @Vector w Bool@ constraint is presumably needed
--   by the 'Eq' instance for 'Vn'; verify in the Vector module.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: ( Vector v (Vn w a)
             , Vector w a
             , v ~ w
             , Vector w Bool
             , Eq a
             )
          => Mat v w a
          -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
124
125
-- | Build a matrix entry by entry from a function @lambda@. The entry
--   in row @i@, column @j@ (both zero-based) is @lambda i j@.
--
--   The dimensions come from the result type alone.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall v w a.
             ( Vector v (Vn w a)
             , Vector w a
             )
          => (Int -> Int -> a)
          -> Mat v w a
construct lambda = V.fromList (map mkRow [0 .. nr - 1])
  where
    -- Sizes are read from the types, as in Data.Vector.Fixed.length.
    nr = arity (undefined :: Dim v)
    nc = arity (undefined :: Dim w)

    -- Row i is the list of lambda i j for every column j.
    mkRow i = V.fromList (map (lambda i) [0 .. nc - 1])
151
-- | Given a positive-definite matrix @m@, compute the upper-triangular
--   matrix @r@ with @(transpose r) * r == m@ and all values on the
--   diagonal of @r@ positive.
--
--   Only the upper triangle of @m@ (the entries with @i <= j@) is
--   read. No check is made that @m@ really is positive-definite. If it
--   is not, a square root may receive a negative argument (NaN for
--   'Double').
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) `mult` (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall a v w.
            ( RealFloat a
            , Vector v (Vn w a)
            , Vector w a
            )
         => (Mat v w a)
         -> (Mat v w a)
cholesky m = construct entry
  where
    nr = arity (undefined :: Dim v)
    nc = arity (undefined :: Dim w)

    -- Lazily built table of every entry of r. Each entry is computed
    -- at most once. The old direct recursion recomputed earlier entries
    -- over and over and took exponential time in the matrix size.
    table :: [[a]]
    table = [ [ r i j | j <- [0 .. nc - 1] ] | i <- [0 .. nr - 1] ]

    -- Memoised lookup of entry (i, j) of r.
    entry :: Int -> Int -> a
    entry i j = (table !! i) !! j

    -- The defining recurrences of the decomposition. They only refer to
    -- entries in earlier rows (k < i), or to the diagonal entry of the
    -- same row when i < j, so the table has no cyclic dependencies.
    r :: Int -> Int -> a
    r i j
      | i == j =
          sqrt (m !!! (i,j) - sum [ (entry k i)**2 | k <- [0 .. i-1] ])
      | i < j =
          (m !!! (i,j) - sum [ entry k i * entry k j | k <- [0 .. i-1] ])
            / entry i i
      | otherwise = 0
177
-- | Matrix multiplication: entry (i, j) of the product is the sum over
--   k of @m1 (i,k) * m2 (k,j)@.
--
--   Per the original notes, the matrix 'Num' instance does not provide
--   this, and the result type needs additional constraints anyway.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
-- >>> m1 `mult` m2
-- ((22,28),(49,64))
--
mult :: ( Num a
        , Vector v (Vn w a)
        , Vector w a
        , Vector w (Vn z a)
        , Vector z a
        , Vector v (Vn z a)
        )
     => Mat v w a
     -> Mat w z a
     -> Mat v z a
mult m1 m2 = construct entry
  where
    -- Length of the shared dimension @w@: the columns of m1, which are
    -- also the rows of m2.
    inner = ncols m1

    entry i j =
      sum (map (\k -> (m1 !!! (i,k)) * (m2 !!! (k,j))) [0 .. inner - 1])