X-Git-Url: http://gitweb.michael.orlitzky.com/?a=blobdiff_plain;f=src%2FFixedMatrix.hs;h=e3a1e2b1143aca00e0cf4d121ef5156605c4ac20;hb=55cf46834f181be01b17b4c5a02ecd772c4e3090;hp=7b34fee0b5ec2798e77b3671c960048f4261df0c;hpb=9bd0e10d2c4a18c21269d520190a5d6b65b6390f;p=numerical-analysis.git diff --git a/src/FixedMatrix.hs b/src/FixedMatrix.hs index 7b34fee..e3a1e2b 100644 --- a/src/FixedMatrix.hs +++ b/src/FixedMatrix.hs @@ -7,8 +7,19 @@ module FixedMatrix where -import FixedVector as FV -import qualified Data.Vector.Fixed as V +import FixedVector +import Data.Vector.Fixed ( + Arity(..), + Dim, + Vector, + (!), + ) +import qualified Data.Vector.Fixed as V ( + fromList, + length, + map, + toList + ) import Data.Vector.Fixed.Internal (arity) type Mat v w a = Vn v (Vn w a) @@ -17,63 +28,72 @@ type Mat3 a = Mat Vec3D Vec3D a type Mat4 a = Mat Vec4D Vec4D a -- | Convert a matrix to a nested list. -toList :: (V.Vector v (Vn w a), V.Vector w a) => Mat v w a -> [[a]] -toList m = Prelude.map V.toList (V.toList m) +toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]] +toList m = map V.toList (V.toList m) -- | Create a matrix from a nested list. -fromList :: (V.Vector v (Vn w a), V.Vector w a) => [[a]] -> Mat v w a -fromList vs = V.fromList $ Prelude.map V.fromList vs +fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a +fromList vs = V.fromList $ map V.fromList vs -- | Unsafe indexing. -(!) :: (V.Vector v (Vn w a), V.Vector w a) => Mat v w a -> (Int, Int) -> a -(!) m (i, j) = (row m i) V.! j +(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a +(!!!) m (i, j) = (row m i) ! j -- | Safe indexing. -(!?) :: (V.Vector v (Vn w a), V.Vector w a) => Mat v w a - -> (Int, Int) - -> Maybe a -(!?) m (i, j) +(!!?) :: (Vector v (Vn w a), Vector w a) => Mat v w a + -> (Int, Int) + -> Maybe a +(!!?) m (i, j) | i < 0 || j < 0 = Nothing | i > V.length m = Nothing | otherwise = if j > V.length (row m j) then Nothing - else Just $ (row m j) V.! j + else Just $ (row m j) ! j -- | The number of rows in the matrix. -nrows :: forall v w a. (V.Vector v (Vn w a), V.Vector w a) => Mat v w a -> Int +nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int nrows = V.length -- | The number of columns in the first row of the -- matrix. Implementation stolen from Data.Vector.Fixed.length. -ncols :: forall v w a. (V.Vector v (Vn w a), V.Vector w a) => Mat v w a -> Int -ncols _ = arity (undefined :: V.Dim w) +ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int +ncols _ = arity (undefined :: Dim w) -- | Return the @i@th row of @m@. Unsafe. -row :: (V.Vector v (Vn w a), V.Vector w a) => Mat v w a - -> Int - -> Vn w a -row m i = m V.! i +row :: (Vector v (Vn w a), Vector w a) => Mat v w a + -> Int + -> Vn w a +row m i = m ! i -- | Return the @j@th column of @m@. Unsafe. -column :: (V.Vector v a, V.Vector v (Vn w a), V.Vector w a) => Mat v w a - -> Int - -> Vn v a +column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a + -> Int + -> Vn v a column m j = V.map (element j) m where - element = flip (V.!) + element = flip (!) --- | Transose @m@; switch it's columns and its rows. This is a dirty +-- | Transpose @m@; switch it's columns and its rows. This is a dirty -- implementation.. it would be a little cleaner to use imap, but it -- doesn't seem to work. -transpose :: (V.Vector v (Vn w a), - V.Vector w (Vn v a), - V.Vector v a, - V.Vector w a) +-- +-- TODO: Don't cheat with fromList. +-- +-- Examples: +-- +-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int +-- >>> transpose m +-- ((1,3),(2,4)) +-- +transpose :: (Vector v (Vn w a), + Vector w (Vn v a), + Vector v a, + Vector w a) => Mat v w a -> Mat w v a transpose m = V.fromList column_list @@ -81,10 +101,21 @@ transpose m = V.fromList column_list column_list = [ column m i | i <- [0..(ncols m)-1] ] -- | Is @m@ symmetric? -symmetric :: (V.Vector v (Vn w a), - V.Vector w a, +-- +-- Examples: +-- +-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int +-- >>> symmetric m1 +-- True +-- +-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int +-- >>> symmetric m2 +-- False +-- +symmetric :: (Vector v (Vn w a), + Vector w a, v ~ w, - V.Vector w Bool, + Vector w Bool, Eq a) => Mat v w a -> Bool @@ -96,15 +127,74 @@ symmetric m = -- @lambda@ should take two parameters i,j corresponding to the -- entries in the matrix. The i,j entry of the resulting matrix will -- have the value returned by lambda i j. +-- +-- TODO: Don't cheat with fromList. +-- +-- Examples: +-- +-- >>> let lambda i j = i + j +-- >>> construct lambda :: Mat3 Int +-- ((0,1,2),(1,2,3),(2,3,4)) +-- construct :: forall v w a. - (V.Vector v (Vn w a), - V.Vector w a) + (Vector v (Vn w a), + Vector w a) => (Int -> Int -> a) -> Mat v w a construct lambda = rows where -- The arity trick is used in Data.Vector.Fixed.length. - imax = (arity (undefined :: V.Dim v)) - 1 - jmax = (arity (undefined :: V.Dim w)) - 1 + imax = (arity (undefined :: Dim v)) - 1 + jmax = (arity (undefined :: Dim w)) - 1 row' i = V.fromList [ lambda i j | j <- [0..jmax] ] rows = V.fromList [ row' i | i <- [0..imax] ] + +-- | Given a positive-definite matrix @m@, computes the +-- upper-triangular matrix @r@ with (transpose r)*r == m and all +-- values on the diagonal of @r@ positive. +-- +-- Examples: +-- +-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double +-- >>> cholesky m1 +-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459)) +-- >>> (transpose (cholesky m1)) `mult` (cholesky m1) +-- ((20.000000000000004,-1.0),(-1.0,20.0)) +-- +cholesky :: forall a v w. + (RealFloat a, + Vector v (Vn w a), + Vector w a) + => (Mat v w a) + -> (Mat v w a) +cholesky m = construct r + where + r :: Int -> Int -> a + r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)**2 | k <- [0..i-1]]) + | i < j = + (((m !!! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i) + | otherwise = 0 + +-- | Matrix multiplication. Our 'Num' instance doesn't define one, and +-- we need additional restrictions on the result type anyway. +-- +-- Examples: +-- +-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int +-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int +-- >>> m1 `mult` m2 +-- ((22,28),(49,64)) +-- +mult :: (Num a, + Vector v (Vn w a), + Vector w a, + Vector w (Vn z a), + Vector z a, + Vector v (Vn z a)) + => Mat v w a + -> Mat w z a + -> Mat v z a +mult m1 m2 = construct lambda + where + lambda i j = + sum [(m1 !!! (i,k)) * (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]