{-# LANGUAGE ScopedTypeVariables #-}

-- | A Matrix type using "Data.Vector" as the underlying type. In other
--   words, the size is not fixed, but at least we have safe indexing if
--   we want it.
--
--   This should be replaced with a fixed-size implementation eventually!
--
module Matrix
where

import qualified Data.Vector as V

-- | A matrix stored as a vector of rows.
type Rows a = V.Vector (V.Vector a)

-- | A matrix stored as a vector of columns.
type Columns a = V.Vector (V.Vector a)

-- | A matrix, stored row by row. The constructor does not check that
--   all rows have the same length; functions such as 'ncols' simply
--   look at the first row.
data Matrix a = Matrix (Rows a)
  deriving Eq

-- | Partial indexing: @m ! (i, j)@ is the entry in row @i@, column @j@.
--   Calls 'error' if either index is out of range.
(!) :: Matrix a -> (Int, Int) -> a
(Matrix rows) ! (i, j) = (rows V.! i) V.! j

-- | Safe indexing: returns 'Nothing' when either index is out of range.
(!?) :: Matrix a -> (Int, Int) -> Maybe a
(Matrix rows) !? (i, j) = rows V.!? i >>= (V.!? j)

-- | Indexing without any bounds checking. Out-of-range indices give an
--   undefined result, so only use this when the indices are known to be
--   valid.
unsafeIndex :: Matrix a -> (Int, Int) -> a
(Matrix rows) `unsafeIndex` (i, j) =
  (rows `V.unsafeIndex` i) `V.unsafeIndex` j

-- | Return the @i@th column of @m@. Unsafe: fails with an error if @i@
--   is out of range for some row.
column :: Matrix a -> Int -> V.Vector a
column (Matrix rows) i = V.map (V.! i) rows

-- | The number of rows in the matrix.
nrows :: Matrix a -> Int
nrows (Matrix rows) = V.length rows

-- | The number of columns in the first row of the matrix (0 for a
--   matrix with no rows).
ncols :: Matrix a -> Int
ncols (Matrix rows)
  | V.null rows = 0
  | otherwise   = V.length (rows V.! 0)

-- | Return the vector of @m@'s columns, one for each column index of
--   the first row (see 'ncols').
columns :: Matrix a -> Columns a
columns m = V.generate (ncols m) (column m)

-- | Transpose @m@; switch its columns and its rows.
transpose :: Matrix a -> Matrix a
transpose m = Matrix (columns m)

-- | Shows one row per line: each row's own 'show' output, wrapped in
--   square brackets and followed by a newline.
instance Show a => Show (Matrix a) where
  show (Matrix rows) = concatMap showRow (V.toList rows)
    where
      showRow r = "[" ++ show r ++ "]\n"

-- | Applies the function to every entry of the matrix.
instance Functor Matrix where
  fmap f (Matrix rows) = Matrix (V.map (fmap f) rows)

-- | Vector addition, element by element. The result has the length of
--   the shorter argument.
vplus :: Num a => V.Vector a -> V.Vector a -> V.Vector a
vplus = V.zipWith (+)

-- | Vector subtraction.
--   The result has the length of the shorter argument.
vminus :: Num a => V.Vector a -> V.Vector a -> V.Vector a
vminus = V.zipWith (-)

-- | Add two vectors of rows, row by row. Like 'vplus', mismatched
--   lengths are truncated to the shorter one.
rowsplus :: Num a => Rows a -> Rows a -> Rows a
rowsplus = V.zipWith vplus

-- | Subtract two vectors of rows, row by row. Mismatched lengths are
--   truncated to the shorter one.
rowsminus :: Num a => Rows a -> Rows a -> Rows a
rowsminus = V.zipWith vminus

-- | Matrix multiplication. The result has @nrows m1@ rows and
--   @ncols m2@ columns.
--
--   NOTE(review): the dimensions are not checked. If @ncols m1@ exceeds
--   @nrows m2@ the indexing fails with an error; if it is smaller, the
--   extra rows of @m2@ are silently ignored.
mtimes :: Num a => Matrix a -> Matrix a -> Matrix a
mtimes m1 m2 =
  Matrix $ V.generate (nrows m1) $ \i ->
    V.generate (ncols m2) $ \j ->
      sum [ (m1 ! (i, k)) * (m2 ! (k, j)) | k <- [0 .. ncols m1 - 1] ]

-- | Is @m@ symmetric, i.e. equal to its own transpose?
symmetric :: Eq a => Matrix a -> Bool
symmetric m = m == transpose m

-- | Construct a new matrix from a function @lambda@. The function
--   @lambda@ should take two parameters i,j corresponding to the
--   entries in the matrix. The i,j entry of the resulting matrix will
--   have the value returned by lambda i j.
--
--   The @imax@ and @jmax@ parameters are the largest row and column
--   indices (inclusive), so the result has @imax + 1@ rows and
--   @jmax + 1@ columns; negative values give empty dimensions.
--
--   Entries are stored lazily, so @lambda i j@ is only evaluated when
--   that entry is demanded ('cholesky' relies on this).
construct :: Int -> Int -> (Int -> Int -> a) -> Matrix a
construct imax jmax lambda = Matrix rows
  where
    row i = V.fromList [ lambda i j | j <- [0 .. jmax] ]
    rows  = V.fromList [ row i | i <- [0 .. imax] ]

-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   The entries follow the usual Cholesky recurrences. Each recurrence
--   reads earlier entries from the lazily built result matrix itself,
--   so every entry is computed once. Recursing on the defining function
--   directly would recompute shared entries and take exponential time.
--   If @m@ is not positive-definite the result is meaningless (e.g. a
--   diagonal entry may be the square root of a negative number, which
--   is NaN for 'Double').
cholesky :: forall a. RealFloat a => Matrix a -> Matrix a
cholesky m = result
  where
    result :: Matrix a
    result = construct (nrows m - 1) (ncols m - 1) entry

    -- An already-defined entry of the result, shared through laziness.
    r :: Int -> Int -> a
    r i j = result ! (i, j)

    -- The defining recurrence for entry (i, j) of the result. It only
    -- refers to entries in earlier rows, or to the diagonal entry of
    -- the same row, so the lazy lookups never loop.
    entry :: Int -> Int -> a
    entry i j
      | i == j    =
          sqrt (m ! (i, j) - sum [ square (r k i) | k <- [0 .. i - 1] ])
      | i < j     =
          (m ! (i, j) - sum [ r k i * r k j | k <- [0 .. i - 1] ]) / r i i
      | otherwise = 0

    -- Plain multiplication instead of (** 2): the default Floating (**)
    -- is exp (log x * y), which yields NaN for negative x.
    square :: a -> a
    square x = x * x

-- | It's not correct to use Num here, but I really don't want to have
--   to define my own addition and subtraction. 'abs', 'signum' and
--   'fromInteger' are not meaningful for matrices and call 'error'.
instance Num a => Num (Matrix a) where
  -- Standard componentwise addition.
  (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2)
  -- Standard componentwise subtraction.
  (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2)
  -- Componentwise negation. Without this, the class default
  -- (negate x = 0 - x) would call fromInteger and fail.
  negate (Matrix rows) = Matrix (V.map (V.map negate) rows)
  -- Matrix multiplication.
  m1 * m2 = m1 `mtimes` m2
  abs _ = error "absolute value of matrices is undefined"
  signum _ = error "signum of matrices is undefined"
  fromInteger _ = error "fromInteger of matrices is undefined"