]> gitweb.michael.orlitzky.com - numerical-analysis.git/blobdiff - src/Matrix.hs
Remove non-fixed Matrix module.
[numerical-analysis.git] / src / Matrix.hs
index 349b71af2ad3d72810b3c0f6bd3585c739aa69dc..ea44ae968878ab1b5d1b02a7dd981e77f5446337 100644 (file)
 {-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE TypeFamilies #-}
 
--- | A Matrix type using Data.Vector as the underlying type. In other
---   words, the size is not fixed, but at least we have safe indexing if
---   we want it.
---
---   This should be replaced with a fixed-size implementation eventually!
---
 module Matrix
 where
 
-import qualified Data.Vector as V
-
-type Rows a = V.Vector (V.Vector a)
-type Columns a = V.Vector (V.Vector a)
-data Matrix a = Matrix (Rows a) deriving Eq
-
--- | Unsafe indexing
-(!) :: (Matrix a) -> (Int, Int) -> a
-(Matrix rows) ! (i, j) = (rows V.! i) V.! j
+import Vector
+import Data.Vector.Fixed (
+  Arity(..),
+  Dim,
+  Vector,
+  (!),
+  )
+import qualified Data.Vector.Fixed as V (
+  fromList,
+  length,
+  map,
+  toList
+  )
+import Data.Vector.Fixed.Internal (arity)
+
+type Mat v w a = Vn v (Vn w a)
+type Mat2 a = Mat Vec2D Vec2D a
+type Mat3 a = Mat Vec3D Vec3D a
+type Mat4 a = Mat Vec4D Vec4D a
+
+-- | Convert a matrix to a nested list.
+toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
+toList m = map V.toList (V.toList m)
+
+-- | Create a matrix from a nested list.
+fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
+fromList vs = V.fromList $ map V.fromList vs
+
+
+-- | Unsafe indexing.
+(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
+(!!!) m (i, j) = (row m i) ! j
+
+-- | Safe indexing.
+(!!?) :: (Vector v (Vn w a), Vector w a) => Mat v w a
+                                         -> (Int, Int)
+                                         -> Maybe a
+(!!?) m (i, j)
+  | i < 0 || j < 0 = Nothing
+  | i > V.length m = Nothing
+  | otherwise = if j > V.length (row m j)
+                then Nothing
+                else Just $ (row m j) ! j
 
--- | Safe indexing
-(!?) :: (Matrix a) -> (Int, Int) -> Maybe a
-(Matrix rows) !? (i, j) = do
-  row <- rows V.!? i
-  col <- row V.!? j
-  return col
 
--- | Unsafe indexing without bounds checking
-unsafeIndex :: (Matrix a) -> (Int, Int) -> a
-(Matrix rows) `unsafeIndex` (i, j) =
-  (rows `V.unsafeIndex` i) `V.unsafeIndex` j
+-- | The number of rows in the matrix.
+nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
+nrows = V.length
+
+-- | The number of columns in the first row of the
+--   matrix. Implementation stolen from Data.Vector.Fixed.length.
+ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
+ncols _ = arity (undefined :: Dim w)
+
+-- | Return the @i@th row of @m@. Unsafe.
+row :: (Vector v (Vn w a), Vector w a) => Mat v w a
+                                       -> Int
+                                       -> Vn w a
+row m i = m ! i
+
+
+-- | Return the @j@th column of @m@. Unsafe.
+column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a
+                                                      -> Int
+                                                      -> Vn v a
+column m j =
+  V.map (element j) m
+  where
+    element = flip (!)
 
--- | Return the @i@th column of @m@. Unsafe!
-column :: (Matrix a) -> Int -> (V.Vector a)
-column (Matrix rows) i =
-  V.fromList [row V.! i | row <- V.toList rows]
 
--- | The number of rows in the matrix.
-nrows :: (Matrix a) -> Int
-nrows (Matrix rows) = V.length rows
-
--- | The number of columns in the first row of the matrix.
-ncols :: (Matrix a) -> Int
-ncols (Matrix rows)
-  | V.length rows == 0 = 0
-  | otherwise = V.length (rows V.! 0)
-
--- | Return the vector of @m@'s columns.
-columns :: (Matrix a) -> (Columns a)
-columns m =
-  V.fromList [column m i | i <- [0..(ncols m)-1]]
-
--- | Transose @m@; switch it's columns and its rows.
-transpose :: (Matrix a) -> (Matrix a)
-transpose m =
-  Matrix (columns m)
-
-instance Show a => Show (Matrix a) where
-  show (Matrix rows) =
-    concat $ V.toList $ V.map show_row rows
-        where show_row r = "[" ++ (show r) ++ "]\n"
-
-instance Functor Matrix where
-  f `fmap` (Matrix rows) = Matrix (V.map (fmap f) rows)
-
-
--- | Vector addition.
-vplus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
-vplus xs ys = V.zipWith (+) xs ys
-
--- | Vector subtraction.
-vminus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
-vminus xs ys = V.zipWith (-) xs ys
-
--- | Add two vectors of rows.
-rowsplus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
-rowsplus rs1 rs2 =
-  V.zipWith vplus rs1 rs2
-
--- | Subtract two vectors of rows.
-rowsminus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
-rowsminus rs1 rs2 =
-  V.zipWith vminus rs1 rs2
-
--- | Matrix multiplication.
-mtimes :: Num a => (Matrix a) -> (Matrix a) -> (Matrix a)
-mtimes m1 m2 =
-  Matrix (V.fromList rows)
+-- | Transpose @m@; switch it's columns and its rows. This is a dirty
+--   implementation.. it would be a little cleaner to use imap, but it
+--   doesn't seem to work.
+--
+--   TODO: Don't cheat with fromList.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
+--   >>> transpose m
+--   ((1,3),(2,4))
+--
+transpose :: (Vector v (Vn w a),
+              Vector w (Vn v a),
+              Vector v a,
+              Vector w a)
+             => Mat v w a
+             -> Mat w v a
+transpose m = V.fromList column_list
   where
-    row i =
-      V.fromList [ sum [ (m1 ! (i,k)) * (m2 ! (k,j)) | k <- [0..(ncols m1)-1] ]
-                                                     | j <- [0..(ncols m2)-1] ]
-    rows = [row i | i <- [0..(nrows m1)-1]]
+    column_list = [ column m i | i <- [0..(ncols m)-1] ]
 
 -- | Is @m@ symmetric?
-symmetric :: Eq a => (Matrix a) -> Bool
+--
+--   Examples:
+--
+--   >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
+--   >>> symmetric m1
+--   True
+--
+--   >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
+--   >>> symmetric m2
+--   False
+--
+symmetric :: (Vector v (Vn w a),
+              Vector w a,
+              v ~ w,
+              Vector w Bool,
+              Eq a)
+             => Mat v w a
+             -> Bool
 symmetric m =
   m == (transpose m)
 
+
 -- | Construct a new matrix from a function @lambda@. The function
 --   @lambda@ should take two parameters i,j corresponding to the
 --   entries in the matrix. The i,j entry of the resulting matrix will
 --   have the value returned by lambda i j.
 --
---   The @imax@ and @jmax@ parameters determine the size of the matrix.
+--   TODO: Don't cheat with fromList.
+--
+--   Examples:
 --
-construct :: Int -> Int -> (Int -> Int -> a) -> (Matrix a)
-construct imax jmax lambda =
-  Matrix rows
+--   >>> let lambda i j = i + j
+--   >>> construct lambda :: Mat3 Int
+--   ((0,1,2),(1,2,3),(2,3,4))
+--
+construct :: forall v w a.
+             (Vector v (Vn w a),
+              Vector w a)
+             => (Int -> Int -> a)
+             -> Mat v w a
+construct lambda = rows
   where
-    row i = V.fromList [ lambda i j | j <- [0..jmax] ]
-    rows = V.fromList [ row i | i <- [0..imax] ]
+    -- The arity trick is used in Data.Vector.Fixed.length.
+    imax = (arity (undefined :: Dim v)) - 1
+    jmax = (arity (undefined :: Dim w)) - 1
+    row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
+    rows = V.fromList [ row' i | i <- [0..imax] ]
 
 -- | Given a positive-definite matrix @m@, computes the
 --   upper-triangular matrix @r@ with (transpose r)*r == m and all
 --   values on the diagonal of @r@ positive.
-cholesky :: forall a. RealFloat a => (Matrix a) -> (Matrix a)
-cholesky m =
-  construct (nrows m - 1) (ncols m - 1) r
+--
+--   Examples:
+--
+--   >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
+--   >>> cholesky m1
+--   ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
+--   >>> (transpose (cholesky m1)) `mult` (cholesky m1)
+--   ((20.000000000000004,-1.0),(-1.0,20.0))
+--
+cholesky :: forall a v w.
+            (RealFloat a,
+             Vector v (Vn w a),
+             Vector w a)
+            => (Mat v w a)
+            -> (Mat v w a)
+cholesky m = construct r
   where
     r :: Int -> Int -> a
-    r i j | i == j = sqrt(m ! (i,j) - sum [(r k i)**2 | k <- [0..i-1]])
+    r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)**2 | k <- [0..i-1]])
           | i < j =
-              (((m ! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i)
+              (((m !!! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i)
           | otherwise = 0
 
--- | It's not correct to use Num here, but I really don't want to have
---   to define my own addition and subtraction.
-instance Num a => Num (Matrix a) where
-  -- Standard componentwise addition.
-  (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2)
-
-  -- Standard componentwise subtraction.
-  (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2)
-
-  -- Matrix multiplication.
-  m1 * m2 = m1 `mtimes` m2
-
-  abs _ = error "absolute value of matrices is undefined"
-
-  signum _ = error "signum of matrices is undefined"
-
-  fromInteger _ = error "fromInteger of matrices is undefined"
+-- | Matrix multiplication. Our 'Num' instance doesn't define one, and
+--   we need additional restrictions on the result type anyway.
+--
+--   Examples:
+--
+--   >>> let m1 = fromList [[1,2,3], [4,5,6]]  :: Mat Vec2D Vec3D Int
+--   >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
+--   >>> m1 `mult` m2
+--   ((22,28),(49,64))
+--
+mult :: (Num a,
+         Vector v (Vn w a),
+         Vector w a,
+         Vector w (Vn z a),
+         Vector z a,
+         Vector v (Vn z a))
+        => Mat v w a
+        -> Mat w z a
+        -> Mat v z a
+mult m1 m2 = construct lambda
+  where
+    lambda i j =
+      sum [(m1 !!! (i,k)) * (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]