]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Matrix.hs
a79619e3bc57281e91f4d4132e5e1abbfcb049d8
[numerical-analysis.git] / src / Matrix.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2
3 -- | A Matrix type using Data.Vector as the underlying type. In other
4 -- words, the size is not fixed, but at least we have safe indexing if
5 -- we want it.
6 --
7 -- This should be replaced with a fixed-size implementation eventually!
8 --
9 module Matrix
10 where
11
12 import qualified Data.Vector as V
13
-- | A matrix stored row by row: a vector whose elements are the row
--   vectors.
type Rows a = V.Vector (V.Vector a)

-- | A matrix stored column by column. This has exactly the same shape
--   as 'Rows'; the separate name only documents which orientation is
--   meant.
type Columns a = Rows a

-- | A dense matrix held as a vector of rows. The constructor performs no
--   validation, so nothing enforces that all rows have the same length.
data Matrix a = Matrix (Rows a)
  deriving (Eq)
17
-- | Unsafe (partial) indexing: the entry at row @i@, column @j@, both
--   zero-based. Each lookup is bounds-checked by "Data.Vector", so an
--   out-of-range index raises a runtime error.
(!) :: (Matrix a) -> (Int, Int) -> a
(!) (Matrix rows) (i, j) = theRow V.! j
  where
    theRow = rows V.! i
21
-- | Safe (total) indexing: @Just@ the entry at row @i@, column @j@
--   (both zero-based), or @Nothing@ when either index is out of range.
(!?) :: (Matrix a) -> (Int, Int) -> Maybe a
(Matrix rows) !? (i, j) = rows V.!? i >>= (V.!? j)
28
-- | Index at row @i@, column @j@ without any bounds checking. The
--   behaviour for an out-of-range index is unspecified, so only use this
--   when the indices are already known to be valid.
unsafeIndex :: (Matrix a) -> (Int, Int) -> a
unsafeIndex (Matrix rows) (i, j) =
  V.unsafeIndex (V.unsafeIndex rows i) j
33
-- | Return the @i@th (zero-based) column of a matrix as a vector.
--   Unsafe: each row is indexed with the bounds-checked 'V.!', so a row
--   with no entry at position @i@ raises an error when that element is
--   evaluated.
column :: (Matrix a) -> Int -> (V.Vector a)
column (Matrix rows) i =
  -- Map over the row vector directly instead of converting it to a
  -- list and back again.
  V.map (V.! i) rows
38
-- | How many rows a matrix has.
nrows :: (Matrix a) -> Int
nrows m =
  case m of
    Matrix rs -> V.length rs
42
-- | The number of columns, taken from the length of the first row, or 0
--   for a matrix with no rows. Later rows are not inspected, so a ragged
--   matrix is not detected here.
ncols :: (Matrix a) -> Int
ncols (Matrix rows) =
  -- Total lookup of the first row replaces the separate emptiness test
  -- followed by a partial index.
  maybe 0 V.length (rows V.!? 0)
48
-- | Every column of a matrix, in order, collected into one vector. The
--   number of columns is whatever 'ncols' reports.
columns :: (Matrix a) -> (Columns a)
columns m = V.generate (ncols m) (column m)
53
-- | Transpose a matrix: swap its rows and its columns, so that column
--   @k@ of the input becomes row @k@ of the result.
transpose :: (Matrix a) -> (Matrix a)
transpose = Matrix . columns
58
-- | Print one row per line. Each row is the vector's own 'show' output
--   wrapped in an extra pair of square brackets and followed by a
--   newline.
instance Show a => Show (Matrix a) where
  show (Matrix rows) = V.foldr render "" rows
    where
      render r rest = "[" ++ show r ++ "]\n" ++ rest
63
-- | Apply a function to every entry, leaving the shape unchanged.
instance Functor Matrix where
  fmap f (Matrix rows) = Matrix (fmap (V.map f) rows)
66
67
-- | Elementwise sum of two vectors. As with 'V.zipWith', the result is
--   as long as the shorter argument.
vplus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
vplus = V.zipWith (+)
71
-- | Elementwise difference (first minus second) of two vectors. As with
--   'V.zipWith', the result is as long as the shorter argument.
vminus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
vminus = V.zipWith (-)
75
-- | Entrywise sum of two row-vectors-of-rows. Rows are paired in order,
--   and any extra rows or entries in the longer argument are dropped.
rowsplus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
rowsplus = V.zipWith (V.zipWith (+))
80
-- | Entrywise difference (first minus second) of two row-vectors-of-rows.
--   Rows are paired in order, and any extra rows or entries in the longer
--   argument are dropped.
rowsminus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
rowsminus = V.zipWith (V.zipWith (-))
85
-- | Matrix multiplication. Entry (i, j) of the product is the dot
--   product of row @i@ of @m1@ with column @j@ of @m2@.
--
--   Raises an error when the inner dimensions disagree, i.e. when the
--   column count of @m1@ differs from the row count of @m2@. Previously
--   that case either produced a silently wrong product or failed with an
--   opaque index error. A matrix with no rows carries no column
--   information, so it is not checked.
mtimes :: Num a => (Matrix a) -> (Matrix a) -> (Matrix a)
mtimes m1 m2
  | nrows m1 > 0 && inner /= nrows m2 =
      error $ "mtimes: nonconformable matrices "
           ++ show (nrows m1, inner) ++ " and "
           ++ show (nrows m2, ncols m2)
  | otherwise =
      Matrix (V.generate (nrows m1) productRow)
  where
    -- The shared (inner) dimension of the product.
    inner = ncols m1
    productRow i = V.generate (ncols m2) (entry i)
    entry i j = sum [ (m1 ! (i, k)) * (m2 ! (k, j)) | k <- [0 .. inner - 1] ]
94
-- | Does a matrix equal its own transpose? For a rectangular, non-square
--   matrix this is False, because the row counts of the two differ.
symmetric :: Eq a => (Matrix a) -> Bool
symmetric m = m == flipped
  where
    flipped = transpose m
99
-- | Build a matrix from a function @lambda@ of the row index @i@ and the
--   column index @j@: the entry at (i, j) is @lambda i j@.
--
--   @imax@ and @jmax@ are the largest row and column /indices/
--   (inclusive), not the dimensions. The result has @imax + 1@ rows and
--   @jmax + 1@ columns, and is empty when @imax@ is negative.
--
construct :: Int -> Int -> (Int -> Int -> a) -> (Matrix a)
construct imax jmax lambda =
  Matrix (V.fromList (map buildRow [0 .. imax]))
  where
    buildRow i = V.fromList (map (lambda i) [0 .. jmax])
113
-- | Given a positive-definite matrix @m@, computes the upper-triangular
--   matrix @r@ with (transpose r)*r == m and all values on the diagonal
--   of @r@ positive.
--
--   Each entry of @r@ is defined in terms of entries from earlier rows
--   of @r@. Those entries are read back from the result table itself, so
--   each one is computed only once. The previous formulation recomputed
--   them recursively, which took exponential time.
--
--   No validation is done. For IEEE types such as Double, an input that
--   is not positive-definite yields NaN or infinite entries rather than
--   an error.
cholesky :: forall a. RealFloat a => (Matrix a) -> (Matrix a)
cholesky m = table
  where
    -- Boxed vectors do not force their elements, so the table may refer
    -- to its own entries. There is no cycle: entry (i, j) only needs
    -- entries from rows k < i and the diagonal entry (i, i), which in
    -- turn only needs entries from rows k < i.
    table :: Matrix a
    table = Matrix (V.generate (nrows m) (\i -> V.generate (ncols m) (r i)))

    -- An entry of the factor, shared rather than recomputed.
    u :: Int -> Int -> a
    u i j = table ! (i, j)

    r :: Int -> Int -> a
    r i j
      | i > j = 0
      | i == j = sqrt (m ! (i, j) - sum [ square (u k i) | k <- [0 .. i - 1] ])
      | otherwise = (m ! (i, j) - sum [ u k i * u k j | k <- [0 .. i - 1] ]) / u i i

    -- Plain multiplication instead of (**): the default (**) goes
    -- through log and is NaN for negative bases, and off-diagonal
    -- entries can be negative.
    square x = x * x
125
-- | Num is not a lawful fit for matrices (there is no sensible abs,
--   signum or fromInteger), but it provides the (+), (-) and (*)
--   operators without defining new ones. Only (+), (-), (*) and
--   negate are meaningful; the remaining methods raise errors.
instance Num a => Num (Matrix a) where
  -- Standard componentwise addition.
  (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2)

  -- Standard componentwise subtraction.
  (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2)

  -- Matrix multiplication.
  m1 * m2 = m1 `mtimes` m2

  -- Componentwise negation. Without this definition the class default
  -- (negate x = 0 - x) calls fromInteger, which errors below.
  negate (Matrix rows) = Matrix (V.map (V.map negate) rows)

  abs _ = error "absolute value of matrices is undefined"

  signum _ = error "signum of matrices is undefined"

  fromInteger _ = error "fromInteger of matrices is undefined"