--- /dev/null
+{-# LANGUAGE ScopedTypeVariables #-}
+
+-- | A Matrix type using Data.Vector as the underlying type. In other
+-- words, the size is not fixed, but at least we have safe indexing if
+-- we want it.
+--
+-- This should be replaced with a fixed-size implementation eventually!
+--
+module Matrix
+where
+
+import qualified Data.Vector as V
+
+type Rows a = V.Vector (V.Vector a)
+type Columns a = V.Vector (V.Vector a)
+data Matrix a = Matrix (Rows a) deriving Eq
+
+-- | Unsafe indexing
+(!) :: (Matrix a) -> (Int, Int) -> a
+(Matrix rows) ! (i, j) = (rows V.! i) V.! j
+
+-- | Safe indexing
+(!?) :: (Matrix a) -> (Int, Int) -> Maybe a
+(Matrix rows) !? (i, j) = do
+ row <- rows V.!? i
+ col <- row V.!? j
+ return col
+
+-- | Unsafe indexing without bounds checking
+unsafeIndex :: (Matrix a) -> (Int, Int) -> a
+(Matrix rows) `unsafeIndex` (i, j) =
+ (rows `V.unsafeIndex` i) `V.unsafeIndex` j
+
+-- | Return the @i@th column of @m@. Unsafe!
+column :: (Matrix a) -> Int -> (V.Vector a)
+column (Matrix rows) i =
+ V.fromList [row V.! i | row <- V.toList rows]
+
+-- | The number of rows in the matrix.
+nrows :: (Matrix a) -> Int
+nrows (Matrix rows) = V.length rows
+
+-- | The number of columns in the first row of the matrix.
+ncols :: (Matrix a) -> Int
+ncols (Matrix rows)
+ | V.length rows == 0 = 0
+ | otherwise = V.length (rows V.! 0)
+
+-- | Return the vector of @m@'s columns.
+columns :: (Matrix a) -> (Columns a)
+columns m =
+ V.fromList [column m i | i <- [0..(ncols m)-1]]
+
+-- | Transose @m@; switch it's columns and its rows.
+transpose :: (Matrix a) -> (Matrix a)
+transpose m =
+ Matrix (columns m)
+
+instance Show a => Show (Matrix a) where
+ show (Matrix rows) =
+ concat $ V.toList $ V.map show_row rows
+ where show_row r = "[" ++ (show r) ++ "]\n"
+
+instance Functor Matrix where
+ f `fmap` (Matrix rows) = Matrix (V.map (fmap f) rows)
+
+
+-- | Vector addition.
+vplus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
+vplus xs ys = V.zipWith (+) xs ys
+
+-- | Vector subtraction.
+vminus :: Num a => (V.Vector a) -> (V.Vector a) -> (V.Vector a)
+vminus xs ys = V.zipWith (-) xs ys
+
+-- | Add two vectors of rows.
+rowsplus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
+rowsplus rs1 rs2 =
+ V.zipWith vplus rs1 rs2
+
+-- | Subtract two vectors of rows.
+rowsminus :: Num a => (Rows a) -> (Rows a) -> (Rows a)
+rowsminus rs1 rs2 =
+ V.zipWith vminus rs1 rs2
+
+-- | Matrix multiplication.
+mtimes :: Num a => (Matrix a) -> (Matrix a) -> (Matrix a)
+mtimes m1@(Matrix rows1) m2@(Matrix rows2) =
+ Matrix (V.fromList rows)
+ where
+ row i = V.fromList [ sum [ (m1 ! (i,k)) * (m2 ! (k,j)) | k <- [0..(ncols m1)-1] ]
+ | j <- [0..(ncols m2)-1] ]
+ rows = [row i | i <- [0..(nrows m1)-1]]
+
+-- | Is @m@ symmetric?
+symmetric :: Eq a => (Matrix a) -> Bool
+symmetric m =
+ m == (transpose m)
+
+-- | Construct a new matrix from a function @lambda@. The function
+-- @lambda@ should take two parameters i,j corresponding to the
+-- entries in the matrix. The i,j entry of the resulting matrix will
+-- have the value returned by lambda i j.
+--
+-- The @imax@ and @jmax@ parameters determine the size of the matrix.
+--
+construct :: Int -> Int -> (Int -> Int -> a) -> (Matrix a)
+construct imax jmax lambda =
+ Matrix rows
+ where
+ row i = V.fromList [ lambda i j | j <- [0..jmax] ]
+ rows = V.fromList [ row i | i <- [0..imax] ]
+
+-- | Given a positive-definite matrix @m@, computes the
+-- upper-triangular matrix @r@ with (transpose r)*r == m and all
+-- values on the diagonal of @r@ positive.
+cholesky :: forall a. RealFloat a => (Matrix a) -> (Matrix a)
+cholesky m =
+ construct (nrows m - 1) (ncols m - 1) r
+ where
+ r :: Int -> Int -> a
+ r i j | i > j = 0
+ | i == j = sqrt(m ! (i,j) - sum [(r k i)**2 | k <- [0..i-1]])
+ | i < j = (((m ! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i)
+
+-- | It's not correct to use Num here, but I really don't want to have
+-- to define my own addition and subtraction.
+instance Num a => Num (Matrix a) where
+ -- Standard componentwise addition.
+ (Matrix rows1) + (Matrix rows2) = Matrix (rows1 `rowsplus` rows2)
+
+ -- Standard componentwise subtraction.
+ (Matrix rows1) - (Matrix rows2) = Matrix (rows1 `rowsminus` rows2)
+
+ -- Matrix multiplication.
+ m1 * m2 = m1 `mtimes` m2
+
+ abs _ = error "absolute value of matrices is undefined"
+
+ signum _ = error "signum of matrices is undefined"
+
+ fromInteger x = error "fromInteger of matrices is undefined"