{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

-- | Fixed-size matrices represented as vectors of row vectors, along
--   with a handful of basic operations on them (indexing, transpose,
--   multiplication, Cholesky factorization).
module Linear.Matrix
where

import Data.Vector.Fixed (
  Dim,
  Vector
  )
import qualified Data.Vector.Fixed as V (
  fromList,
  length,
  map,
  toList
  )
import Data.Vector.Fixed.Internal (arity)

import Linear.Vector

-- | A matrix is an outer vector (of shape @v@) whose entries are the
--   rows, each of which is a vector of shape @w@.
type Mat v w a = Vn v (Vn w a)

-- | Square 2x2 matrices.
type Mat2 a = Mat Vec2D Vec2D a

-- | Square 3x3 matrices.
type Mat3 a = Mat Vec3D Vec3D a

-- | Square 4x4 matrices.
type Mat4 a = Mat Vec4D Vec4D a


-- | Convert a matrix to a nested list. Each inner list is one row.
toList :: (Vector v (Vn w a), Vector w a) => Mat v w a -> [[a]]
toList m = map V.toList (V.toList m)


-- | Create a matrix from a nested list. Each inner list becomes one
--   row; the conversion itself is delegated to 'V.fromList'.
fromList :: (Vector v (Vn w a), Vector w a) => [[a]] -> Mat v w a
fromList vs = V.fromList $ map V.fromList vs


-- | Unsafe indexing. Returns the entry in row @i@, column @j@ without
--   performing any bounds checks of its own.
(!!!) :: (Vector v (Vn w a), Vector w a) => Mat v w a -> (Int, Int) -> a
(!!!) m (i, j) = (row m i) ! j


-- | Safe indexing. Returns 'Nothing' when either index is negative,
--   when @i@ is not a valid row index, or when @j@ is not a valid
--   column index. Otherwise returns 'Just' the entry in row @i@,
--   column @j@.
--
--   Indices are zero-based, so the valid ranges are
--   @0 <= i < nrows m@ and @0 <= j < ncols m@.
(!!?) :: (Vector v (Vn w a), Vector w a) => Mat v w a
                                         -> (Int, Int)
                                         -> Maybe a
(!!?) m (i, j)
  | i < 0 || j < 0 = Nothing
  -- The largest valid index is one less than the dimension, so the
  -- comparisons below must be (>=) and not (>).
  | i >= nrows m = Nothing
  | j >= ncols m = Nothing
  -- Both indices are now known to be in range, so the unsafe lookup
  -- of row i, column j cannot go out of bounds.
  | otherwise = Just $ m !!! (i, j)


-- | The number of rows in the matrix.
nrows :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
nrows = V.length


-- | The number of columns in the first row of the
--   matrix. Implementation stolen from Data.Vector.Fixed.length.
--
--   The argument is never inspected; the answer is computed from the
--   row type @w@ alone (which is why the explicit @forall@ is needed).
ncols :: forall v w a. (Vector v (Vn w a), Vector w a) => Mat v w a -> Int
ncols _ = arity (undefined :: Dim w)


-- | Return the @i@th row of @m@. Unsafe.
row :: (Vector v (Vn w a), Vector w a) => Mat v w a
                                       -> Int
                                       -> Vn w a
row m i = m ! i


-- | Return the @j@th column of @m@. Unsafe.
--
--   The column is built by taking the @j@th entry of every row.
column :: (Vector v a, Vector v (Vn w a), Vector w a) => Mat v w a
                                                       -> Int
                                                       -> Vn v a
column m j =
  V.map (element j) m
  where
    -- Index with the arguments swapped so that the column number can
    -- be fixed and the result mapped over the rows.
    element = flip (!)


-- | Transpose @m@; switch its columns and its rows. This is a dirty
--   implementation... it would be a little cleaner to use imap, but it
--   doesn't seem to work.
--
--   TODO: Don't cheat with fromList.
--
--   Examples:
--
--   >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
--   >>> transpose m
--   ((1,3),(2,4))
--
transpose :: (Vector v (Vn w a),
              Vector w (Vn v a),
              Vector v a,
              Vector w a)
             => Mat v w a
             -> Mat w v a
transpose m = V.fromList column_list
  where
    -- The columns of m, in order, become the rows of the result.
    column_list = [ column m i | i <- [0..(ncols m)-1] ]


-- | Is @m@ symmetric? That is, is @m@ equal to its own transpose?
--
--   Examples:
--
--   >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
--   >>> symmetric m1
--   True
--
--   >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
--   >>> symmetric m2
--   False
--
symmetric :: (Vector v (Vn w a),
              Vector w a,
              v ~ w,
              Vector w Bool,
              Eq a)
             => Mat v w a
             -> Bool
symmetric m =
  m == (transpose m)


-- | Construct a new matrix from a function @lambda@. The function
--   @lambda@ should take two parameters i,j corresponding to the
--   entries in the matrix. The i,j entry of the resulting matrix will
--   have the value returned by lambda i j. Both indices start at zero.
--
--   TODO: Don't cheat with fromList.
--
--   Examples:
--
--   >>> let lambda i j = i + j
--   >>> construct lambda :: Mat3 Int
--   ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall v w a.
             (Vector v (Vn w a),
              Vector w a)
             => (Int -> Int -> a)
             -> Mat v w a
construct lambda = rows
  where
    -- The arity trick is used in Data.Vector.Fixed.length.
    imax = (arity (undefined :: Dim v)) - 1
    jmax = (arity (undefined :: Dim w)) - 1
    row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
    rows = V.fromList [ row' i | i <- [0..imax] ]


-- | Given a positive-definite matrix @m@, computes the
--   upper-triangular matrix @r@ with (transpose r)*r == m and all
--   values on the diagonal of @r@ positive.
--
--   Examples:
--
--   >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
--   >>> cholesky m1
--   ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
--   >>> (transpose (cholesky m1)) `mult` (cholesky m1)
--   ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall a v w.
            (RealFloat a,
             Vector v (Vn w a),
             Vector w a)
            => (Mat v w a)
            -> (Mat v w a)
cholesky m = construct r
  where
    -- The (i,j) entry of the factor. Entries are defined recursively
    -- in terms of entries in earlier rows (k < i); nothing is cached,
    -- so each call recomputes whatever it depends on.
    r :: Int -> Int -> a
    r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)**2 | k <- [0..i-1]])
          | i < j =
              (((m !!! (i,j)) - sum [(r k i)*(r k j) | k <- [0..i-1]]))/(r i i)
          -- Below the diagonal: the factor is upper-triangular.
          | otherwise = 0


-- | Matrix multiplication. Our 'Num' instance doesn't define one, and
--   we need additional restrictions on the result type anyway.
--
--   The types force the number of columns of @m1@ to agree with the
--   number of rows of @m2@.
--
--   Examples:
--
--   >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat Vec2D Vec3D Int
--   >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat Vec3D Vec2D Int
--   >>> m1 `mult` m2
--   ((22,28),(49,64))
--
mult :: (Num a,
         Vector v (Vn w a),
         Vector w a,
         Vector w (Vn z a),
         Vector z a,
         Vector v (Vn z a))
        => Mat v w a
        -> Mat w z a
        -> Mat v z a
mult m1 m2 = construct lambda
  where
    -- Entry (i,j) is the dot product of row i of m1 with column j of m2.
    lambda i j =
      sum [(m1 !!! (i,k)) * (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]