X-Git-Url: http://gitweb.michael.orlitzky.com/?a=blobdiff_plain;f=src%2FMatrix.hs;h=ea44ae968878ab1b5d1b02a7dd981e77f5446337;hb=c00e5ae9829a358890922083267564ec13798061;hp=349b71af2ad3d72810b3c0f6bd3585c739aa69dc;hpb=55cf46834f181be01b17b4c5a02ecd772c4e3090;p=numerical-analysis.git diff --git a/src/Matrix.hs b/src/Matrix.hs index 349b71a..ea44ae9 100644 --- a/src/Matrix.hs +++ b/src/Matrix.hs @@ -1,144 +1,200 @@ {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} --- | A Matrix type using Data.Vector as the underlying type. In other --- words, the size is not fixed, but at least we have safe indexing if --- we want it. --- --- This should be replaced with a fixed-size implementation eventually! --- module Matrix where -import qualified Data.Vector as V - -type Rows a = V.Vector (V.Vector a) -type Columns a = V.Vector (V.Vector a) -data Matrix a = Matrix (Rows a) deriving Eq - --- | Unsafe indexing -(!) :: (Matrix a) -> (Int, Int) -> a -(Matrix rows) ! (i, j) = (rows V.! i) V.! j +import Vector +import Data.Vector.Fixed ( + Arity(..), + Dim, + Vector, + (!), + ) +import qualified Data.Vector.Fixed as V ( + fromList, + length, + map, + toList + ) +import Data.Vector.Fixed.Internal (arity) + +type Mat v w a = Vn v (Vn w a) +type Mat2 a = Mat Vec2D Vec2D a +type Mat3 a = Mat Vec3D Vec3D a +type Mat4 a = Mat Vec4D Vec4D a + +-- | Convert a matrix to a nested list. +toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]] +toList m = map V.toList (V.toList m) + +-- | Create a matrix from a nested list. +fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a +fromList vs = V.fromList $ map V.fromList vs + + +-- | Unsafe indexing. +(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a +(!!!) m (i, j) = (row m i) ! j + +-- | Safe indexing. +(!!?) :: (Vector v (Vn w a), Vector w a) => Mat v w a + -> (Int, Int) + -> Maybe a +(!!?) m (i, j) + | i < 0 || j < 0 = Nothing + | i > V.length m = Nothing + | otherwise = if j > V.length (row m j) + then Nothing + else Just $ (row m j) ! j --- | Safe indexing -(!?) :: (Matrix a) -> (Int, Int) -> Maybe a -(Matrix rows) !? (i, j) = do - row <- rows V.!? i - col <- row V.!? j - return col --- | Unsafe indexing without bounds checking -unsafeIndex :: (Matrix a) -> (Int, Int) -> a -(Matrix rows) `unsafeIndex` (i, j) = - (rows `V.unsafeIndex` i) `V.unsafeIndex` j +-- | The number of rows in the matrix. +nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int +nrows = V.length + +-- | The number of columns in the first row of the +-- matrix. Implementation stolen from Data.Vector.Fixed.length. +ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int +ncols _ = arity (undefined :: Dim w) + +-- | Return the @i@th row of @m@. Unsafe. +row :: (Vector v (Vn w a), Vector w a) => Mat v w a + -> Int + -> Vn w a +row m i = m ! i + + +-- | Return the @j@th column of @m@. Unsafe. +column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a + -> Int + -> Vn v a +column m j = + V.map (element j) m + where + element = flip (!) --- | Return the @i@th column of @m@. Unsafe! -column :: (Matrix a) -> Int -> (V.Vector a) -column (Matrix rows) i = - V.fromList [row V.! i | row <- V.toList rows] --- | The number of rows in the matrix. -nrows :: (Matrix a) -> Int -nrows (Matrix rows) = V.length rows - --- | The number of columns in the first row of the matrix. -ncols :: (Matrix a) -> Int -ncols (Matrix rows) - | V.length rows == 0 = 0 - | otherwise = V.length (rows V.! 0) - --- | Return the vector of @m@'s columns. -columns :: (Matrix a) -> (Columns a) -columns m = - V.fromList [column m i | i <- [0..(ncols m)-1]] - --- | Transose @m@; switch it's columns and its rows. -transpose :: (Matrix a) -> (Matrix a) -transpose m = - Matrix (columns m) - -instance Show a => Show (Matrix a) where - show (Matrix rows) = - concat $ V.toList $ V.map show_row rows - where show_row r = "[" ++ (show r) ++ "]\n" - -instance Functor Matrix where - f `fmap` (Matrix rows) = Matrix (V.map (fmap f) rows) - - --- | Vector addition. -vplus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a) -vplus xs ys = V.zipWith (+) xs ys - --- | Vector subtraction. -vminus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a) -vminus xs ys = V.zipWith (-) xs ys - --- | Add two vectors of rows. -rowsplus :: Num a => (Rows a) -> (Rows a) -> (Rows a) -rowsplus rs1 rs2 = - V.zipWith vplus rs1 rs2 - --- | Subtract two vectors of rows. -rowsminus :: Num a => (Rows a) -> (Rows a) -> (Rows a) -rowsminus rs1 rs2 = - V.zipWith vminus rs1 rs2 - --- | Matrix multiplication. -mtimes :: Num a => (Matrix a) -> (Matrix a) -> (Matrix a) -mtimes m1 m2 = - Matrix (V.fromList rows) +-- | Transpose @m@; switch it's columns and its rows. This is a dirty +-- implementation.. it would be a little cleaner to use imap, but it +-- doesn't seem to work. +-- +-- TODO: Don't cheat with fromList. +-- +-- Examples: +-- +-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int +-- >>> transpose m +-- ((1,3),(2,4)) +-- +transpose :: (Vector v (Vn w a), + Vector w (Vn v a), + Vector v a, + Vector w a) + => Mat v w a + -> Mat w v a +transpose m = V.fromList column_list where - row i = - V.fromList [ sum [ (m1 ! (i,k)) * (m2 ! (k,j)) | k <- [0..(ncols m1)-1] ] - | j <- [0..(ncols m2)-1] ] - rows = [row i | i <- [0..(nrows m1)-1]] + column_list = [ column m i | i <- [0..(ncols m)-1] ] -- | Is @m@ symmetric? -symmetric :: Eq a => (Matrix a) -> Bool +-- +-- Examples: +-- +-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int +-- >>> symmetric m1 +-- True +-- +-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int +-- >>> symmetric m2 +-- False +-- +symmetric :: (Vector v (Vn w a), + Vector w a, + v ~ w, + Vector w Bool, + Eq a) + => Mat v w a + -> Bool symmetric m = m == (transpose m) + -- | Construct a new matrix from a function @lambda@. The function -- @lambda@ should take two parameters i,j corresponding to the -- entries in the matrix. The i,j entry of the resulting matrix will -- have the value returned by lambda i j. -- --- The @imax@ and @jmax@ parameters determine the size of the matrix. +-- TODO: Don't cheat with fromList. +-- +-- Examples: -- -construct :: Int -> Int -> (Int -> Int -> a) -> (Matrix a) -construct imax jmax lambda = - Matrix rows +-- >>> let lambda i j = i + j +-- >>> construct lambda :: Mat3 Int +-- ((0,1,2),(1,2,3),(2,3,4)) +-- +construct :: forall v w a. + (Vector v (Vn w a), + Vector w a) + => (Int -> Int -> a) + -> Mat v w a +construct lambda = rows where - row i = V.fromList [ lambda i j | j <- [0..jmax] ] - rows = V.fromList [ row i | i <- [0..imax] ] + -- The arity trick is used in Data.Vector.Fixed.length. + imax = (arity (undefined :: Dim v)) - 1 + jmax = (arity (undefined :: Dim w)) - 1 + row' i = V.fromList [ lambda i j | j <- [0..jmax] ] + rows = V.fromList [ row' i | i <- [0..imax] ] -- | Given a positive-definite matrix @m@, computes the -- upper-triangular matrix @r@ with (transpose r)*r == m and all -- values on the diagonal of @r@ positive. -cholesky :: forall a. RealFloat a => (Matrix a) -> (Matrix a) -cholesky m = - construct (nrows m - 1) (ncols m - 1) r +-- +-- Examples: +-- +-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double +-- >>> cholesky m1 +-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459)) +-- >>> (transpose (cholesky m1)) `mult` (cholesky m1) +-- ((20.000000000000004,-1.0),(-1.0,20.0)) +-- +cholesky :: forall a v w. + (RealFloat a, + Vector v (Vn w a), + Vector w a) + => (Mat v w a) + -> (Mat v w a) +cholesky m = construct r where r :: Int -> Int -> a - r i j | i == j = sqrt(m ! (i,j) - sum [(r k i)**2 | k <- [0..i-1]]) + r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)**2 | k <- [0..i-1]]) | i < j = - (((m ! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i) + (((m !!! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i) | otherwise = 0 --- | It's not correct to use Num here, but I really don't want to have --- to define my own addition and subtraction. -instance Num a => Num (Matrix a) where - -- Standard componentwise addition. - (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2) - - -- Standard componentwise subtraction. - (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2) - - -- Matrix multiplication. - m1 * m2 = m1 `mtimes` m2 - - abs _ = error "absolute value of matrices is undefined" - - signum _ = error "signum of matrices is undefined" - - fromInteger _ = error "fromInteger of matrices is undefined" +-- | Matrix multiplication. Our 'Num' instance doesn't define one, and +-- we need additional restrictions on the result type anyway. +-- +-- Examples: +-- +-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int +-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int +-- >>> m1 `mult` m2 +-- ((22,28),(49,64)) +-- +mult :: (Num a, + Vector v (Vn w a), + Vector w a, + Vector w (Vn z a), + Vector z a, + Vector v (Vn z a)) + => Mat v w a + -> Mat w z a + -> Mat v z a +mult m1 m2 = construct lambda + where + lambda i j = + sum [(m1 !!! (i,k)) * (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]