From: Michael Orlitzky Date: Sun, 3 Feb 2013 01:38:13 +0000 (-0500) Subject: Add a simple unsafe matrix class to test the Cholesky algorithm. X-Git-Url: https://gitweb.michael.orlitzky.com/?a=commitdiff_plain;h=ea770ce8d107e2736576fadd32cbd8aa38ddc319;p=numerical-analysis.git Add a simple unsafe matrix class to test the Cholesky algorithm. --- diff --git a/src/Matrix.hs b/src/Matrix.hs new file mode 100644 index 0000000..a79619e --- /dev/null +++ b/src/Matrix.hs @@ -0,0 +1,142 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +-- | A Matrix type using Data.Vector as the underlying type. In other +-- words, the size is not fixed, but at least we have safe indexing if +-- we want it. +-- +-- This should be replaced with a fixed-size implementation eventually! +-- +module Matrix +where + +import qualified Data.Vector as V + +type Rows a = V.Vector (V.Vector a) +type Columns a = V.Vector (V.Vector a) +data Matrix a = Matrix (Rows a) deriving Eq + +-- | Unsafe indexing +(!) :: (Matrix a) -> (Int, Int) -> a +(Matrix rows) ! (i, j) = (rows V.! i) V.! j + +-- | Safe indexing +(!?) :: (Matrix a) -> (Int, Int) -> Maybe a +(Matrix rows) !? (i, j) = do + row <- rows V.!? i + col <- row V.!? j + return col + +-- | Unsafe indexing without bounds checking +unsafeIndex :: (Matrix a) -> (Int, Int) -> a +(Matrix rows) `unsafeIndex` (i, j) = + (rows `V.unsafeIndex` i) `V.unsafeIndex` j + +-- | Return the @i@th column of @m@. Unsafe! +column :: (Matrix a) -> Int -> (V.Vector a) +column (Matrix rows) i = + V.fromList [row V.! i | row <- V.toList rows] + +-- | The number of rows in the matrix. +nrows :: (Matrix a) -> Int +nrows (Matrix rows) = V.length rows + +-- | The number of columns in the first row of the matrix. +ncols :: (Matrix a) -> Int +ncols (Matrix rows) + | V.length rows == 0 = 0 + | otherwise = V.length (rows V.! 0) + +-- | Return the vector of @m@'s columns. +columns :: (Matrix a) -> (Columns a) +columns m = + V.fromList [column m i | i <- [0..(ncols m)-1]] + +-- | Transose @m@; switch it's columns and its rows. +transpose :: (Matrix a) -> (Matrix a) +transpose m = + Matrix (columns m) + +instance Show a => Show (Matrix a) where + show (Matrix rows) = + concat $ V.toList $ V.map show_row rows + where show_row r = "[" ++ (show r) ++ "]\n" + +instance Functor Matrix where + f `fmap` (Matrix rows) = Matrix (V.map (fmap f) rows) + + +-- | Vector addition. +vplus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a) +vplus xs ys = V.zipWith (+) xs ys + +-- | Vector subtraction. +vminus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a) +vminus xs ys = V.zipWith (-) xs ys + +-- | Add two vectors of rows. +rowsplus :: Num a => (Rows a) -> (Rows a) -> (Rows a) +rowsplus rs1 rs2 = + V.zipWith vplus rs1 rs2 + +-- | Subtract two vectors of rows. +rowsminus :: Num a => (Rows a) -> (Rows a) -> (Rows a) +rowsminus rs1 rs2 = + V.zipWith vminus rs1 rs2 + +-- | Matrix multiplication. +mtimes :: Num a => (Matrix a) -> (Matrix a) -> (Matrix a) +mtimes m1@(Matrix rows1) m2@(Matrix rows2) = + Matrix (V.fromList rows) + where + row i = V.fromList [ sum [ (m1 ! (i,k)) * (m2 ! (k,j)) | k <- [0..(ncols m1)-1] ] + | j <- [0..(ncols m2)-1] ] + rows = [row i | i <- [0..(nrows m1)-1]] + +-- | Is @m@ symmetric? +symmetric :: Eq a => (Matrix a) -> Bool +symmetric m = + m == (transpose m) + +-- | Construct a new matrix from a function @lambda@. The function +-- @lambda@ should take two parameters i,j corresponding to the +-- entries in the matrix. The i,j entry of the resulting matrix will +-- have the value returned by lambda i j. +-- +-- The @imax@ and @jmax@ parameters determine the size of the matrix. +-- +construct :: Int -> Int -> (Int -> Int -> a) -> (Matrix a) +construct imax jmax lambda = + Matrix rows + where + row i = V.fromList [ lambda i j | j <- [0..jmax] ] + rows = V.fromList [ row i | i <- [0..imax] ] + +-- | Given a positive-definite matrix @m@, computes the +-- upper-triangular matrix @r@ with (transpose r)*r == m and all +-- values on the diagonal of @r@ positive. +cholesky :: forall a. RealFloat a => (Matrix a) -> (Matrix a) +cholesky m = + construct (nrows m - 1) (ncols m - 1) r + where + r :: Int -> Int -> a + r i j | i > j = 0 + | i == j = sqrt(m ! (i,j) - sum [(r k i)**2 | k <- [0..i-1]]) + | i < j = (((m ! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i) + +-- | It's not correct to use Num here, but I really don't want to have +-- to define my own addition and subtraction. +instance Num a => Num (Matrix a) where + -- Standard componentwise addition. + (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2) + + -- Standard componentwise subtraction. + (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2) + + -- Matrix multiplication. + m1 * m2 = m1 `mtimes` m2 + + abs _ = error "absolute value of matrices is undefined" + + signum _ = error "signum of matrices is undefined" + + fromInteger x = error "fromInteger of matrices is undefined"