{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Boxed matrices; that is, boxed m-vectors of boxed n-vectors. We
-- assume that the underlying representation is
-- Data.Vector.Fixed.Boxed.Vec for simplicity. It was tried in
-- generality and failed.
--
module Linear.Matrix
where

import Data.List (intercalate)
import Data.Vector.Fixed
  ( (!), N1, N2, N3, N4, N5, S, Z, mk1, mk2, mk3, mk4, mk5 )
import qualified Data.Vector.Fixed as V
  ( and, fromList, head, length, map, maximum, replicate, toList, zipWith )
import Data.Vector.Fixed.Boxed (Vec)
import Data.Vector.Fixed.Internal.Arity (Arity, arity)
import Linear.Vector
import Normed

import NumericPrelude hiding ((*), abs)
import qualified NumericPrelude as NP ((*))
import qualified Algebra.Algebraic as Algebraic
import Algebra.Algebraic (root)
import qualified Algebra.Additive as Additive
import qualified Algebra.Ring as Ring
import qualified Algebra.Module as Module
import qualified Algebra.RealRing as RealRing
import qualified Algebra.ToRational as ToRational
import qualified Algebra.Transcendental as Transcendental
import qualified Prelude as P


-- | An @m@-by-@n@ matrix, stored as a boxed vector of @m@ rows, each of
-- which is a boxed vector of @n@ entries. The constructor carries the
-- 'Arity' dictionaries for both dimensions, so pattern matching on
-- 'Mat' brings them into scope.
data Mat m n a = (Arity m, Arity n) => Mat (Vec m (Vec n a))

type Mat1 a = Mat N1 N1 a
type Mat2 a = Mat N2 N2 a
type Mat3 a = Mat N3 N3 a
type Mat4 a = Mat N4 N4 a
type Mat5 a = Mat N5 N5 a

instance (Eq a) => Eq (Mat m n a) where
  -- | Compare a row at a time.
  --
  -- Examples:
  --
  -- >>> let m1 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m2 = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> let m3 = fromList [[5,6],[7,8]] :: Mat2 Int
  -- >>> m1 == m2
  -- True
  -- >>> m1 == m3
  -- False
  --
  (Mat rows1) == (Mat rows2) =
    V.and $ V.zipWith comp rows1 rows2
    where
      -- Compare a row, one column at a time.
      comp row1 row2 = V.and (V.zipWith (==) row1 row2)


instance (Show a) => Show (Mat m n a) where
  -- | Display matrices and vectors as ordinary tuples. This is poor
  -- practice, but these results are primarily displayed
  -- interactively and convenience trumps correctness (said the guy
  -- who insists his vector lengths be statically checked at
  -- compile-time).
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> show m
  -- ((1,2),(3,4))
  --
  show (Mat rows) =
    "(" ++ (intercalate "," (V.toList row_strings)) ++ ")"
    where
      row_strings = V.map show_vector rows
      show_vector v1 =
        "(" ++ (intercalate "," element_strings) ++ ")"
        where
          v1l = V.toList v1
          element_strings = P.map show v1l


-- | Convert a matrix to a nested list: one inner list per row, in
-- row order.
toList :: Mat m n a -> [[a]]
toList (Mat rows) = map V.toList (V.toList rows)

-- | Create a matrix from a nested list; the outer list supplies the
-- rows. NOTE(review): the list lengths must match @m@ and @n@; what
-- happens otherwise is up to 'V.fromList' in fixed-vector -- confirm.
fromList :: (Arity m, Arity n) => [[a]] -> Mat m n a
fromList vs = Mat (V.fromList $ map V.fromList vs)

-- | Unsafe indexing: the @(i,j)@ entry, with zero-based row @i@ and
-- column @j@. No bounds check is done here; out-of-range indices are
-- handed straight to fixed-vector's '!'.
(!!!) :: (Arity m, Arity n) => Mat m n a -> (Int, Int) -> a
(!!!) m (i, j) = (row m i) ! j

-- | Safe indexing. Returns @Nothing@ when either zero-based index is
-- negative or not less than the corresponding dimension, and
-- @Just@ the @(i,j)@ entry otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> m !!? (1,0)
-- Just 3
-- >>> m !!? (2,0)
-- Nothing
-- >>> m !!? (0,2)
-- Nothing
--
(!!?) :: Mat m n a -> (Int, Int) -> Maybe a
(!!?) m@(Mat rows) (i, j)
  | i < 0 || j < 0 = Nothing
  -- Valid indices run from 0 to length-1, hence >= rather than >.
  | i >= V.length rows = Nothing
  | j >= V.length (row m i) = Nothing
  | otherwise = Just $ (row m i) ! j


-- | The number of rows in the matrix, read off its type.
nrows :: forall m n a. (Arity m) => Mat m n a -> Int
nrows _ = arity (undefined :: m)

-- | The number of columns in the first row of the
-- matrix. Implementation stolen from Data.Vector.Fixed.length.
ncols :: forall m n a. (Arity n) => Mat m n a -> Int
ncols _ = arity (undefined :: n)


-- | Return the @i@th (zero-based) row of @m@. Unsafe.
row :: Mat m n a -> Int -> (Vec n a)
row (Mat rows) i = rows ! i


-- | Return the @j@th (zero-based) column of @m@. Unsafe.
column :: Mat m n a -> Int -> (Vec m a)
column (Mat rows) j =
  V.map (element j) rows
  where
    element = flip (!)


-- | Transpose @m@; switch its columns and its rows. This is a dirty
-- implementation, but I don't see a better way.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let m = fromList [[1,2], [3,4]] :: Mat2 Int
-- >>> transpose m
-- ((1,3),(2,4))
--
transpose :: (Arity m, Arity n) => Mat m n a -> Mat n m a
transpose m = Mat $ V.fromList column_list
  where
    column_list = [ column m i | i <- [0..(ncols m)-1] ]


-- | Is @m@ symmetric?
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2], [2,1]] :: Mat2 Int
-- >>> symmetric m1
-- True
--
-- >>> let m2 = fromList [[1,2], [3,1]] :: Mat2 Int
-- >>> symmetric m2
-- False
--
symmetric :: (Eq a, Arity m) => Mat m m a -> Bool
symmetric m =
  m == (transpose m)


-- | Construct a new matrix from a function @lambda@. The function
-- @lambda@ should take two parameters i,j corresponding to the
-- entries in the matrix. The i,j entry of the resulting matrix will
-- have the value returned by lambda i j.
--
-- TODO: Don't cheat with fromList.
--
-- Examples:
--
-- >>> let lambda i j = i + j
-- >>> construct lambda :: Mat3 Int
-- ((0,1,2),(1,2,3),(2,3,4))
--
construct :: forall m n a. (Arity m, Arity n)
          => (Int -> Int -> a)
          -> Mat m n a
construct lambda = Mat rows
  where
    -- The arity trick is used in Data.Vector.Fixed.length.
    imax = (arity (undefined :: m)) - 1
    jmax = (arity (undefined :: n)) - 1
    row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
    rows = V.fromList [ row' i | i <- [0..imax] ]


-- | Given a positive-definite matrix @m@, computes the
-- upper-triangular matrix @r@ with (transpose r)*r == m and all
-- values on the diagonal of @r@ positive.
--
-- Entries of @r@ are computed by direct recursion on earlier entries,
-- without memoization, so the same entries are recomputed many times;
-- this is fine for the small fixed sizes used here.
--
-- Examples:
--
-- >>> let m1 = fromList [[20,-1], [-1,20]] :: Mat2 Double
-- >>> cholesky m1
-- ((4.47213595499958,-0.22360679774997896),(0.0,4.466542286825459))
-- >>> (transpose (cholesky m1)) * (cholesky m1)
-- ((20.000000000000004,-1.0),(-1.0,20.0))
--
cholesky :: forall m n a. (Algebraic.C a, Arity m, Arity n)
         => (Mat m n a) -> (Mat m n a)
cholesky m = construct r
  where
    r :: Int -> Int -> a
    r i j | i == j = sqrt(m !!! (i,j) - sum [(r k i)^2 | k <- [0..i-1]])
          | i < j =
              (((m !!! (i,j)) - sum [(r k i) NP.* (r k j) | k <- [0..i-1]]))/(r i i)
          -- Entries below the diagonal of an upper-triangular factor.
          | otherwise = 0


-- | Returns True if the given matrix is upper-triangular, and False
-- otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_upper_triangular m
-- False
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_upper_triangular m
-- True
--
is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_upper_triangular m =
  and $ concat results
  where
    results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]

    -- Only entries strictly below the diagonal must vanish.
    test :: Int -> Int -> Bool
    test i j
      | i <= j = True
      | otherwise = m !!! (i,j) == 0


-- | Returns True if the given matrix is lower-triangular, and False
-- otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_lower_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_lower_triangular m
-- False
--
is_lower_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
                    => Mat m n a -> Bool
is_lower_triangular = is_upper_triangular . transpose


-- | Returns True if the given matrix is triangular, and False
-- otherwise.
--
-- Examples:
--
-- >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
-- >>> is_triangular m
-- True
--
-- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
-- >>> is_triangular m
-- False
--
is_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
              => Mat m n a -> Bool
is_triangular m = is_upper_triangular m || is_lower_triangular m


-- | Return the (i,j)th minor of m: the matrix left after deleting
-- row @i@ and column @j@.
--
-- Examples:
--
-- >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
-- >>> minor m 0 0 :: Mat2 Int
-- ((5,6),(8,9))
-- >>> minor m 1 1 :: Mat2 Int
-- ((1,3),(7,9))
--
minor :: (m ~ S r, n ~ S t, Arity r, Arity t)
      => Mat m n a
      -> Int
      -> Int
      -> Mat r t a
minor (Mat rows) i j = m
  where
    rows' = delete rows i
    m = Mat $ V.map ((flip delete) j) rows'


-- | Types (here, square matrices) that have a determinant.
class (Eq a, Ring.C a) => Determined p a where
  determinant :: (p a) -> a

-- | The determinant of a 1x1 matrix is its only entry.
instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
  determinant (Mat rows) = (V.head . V.head) rows

instance (Eq a,
          Ring.C a,
          Arity n,
          Determined (Mat (S n) (S n)) a)
         => Determined (Mat (S (S n)) (S (S n))) a where
  -- | The recursive definition with a special-case for triangular matrices.
  --
  -- Examples:
  --
  -- >>> let m = fromList [[1,2],[3,4]] :: Mat2 Int
  -- >>> determinant m
  -- -2
  --
  determinant m
    -- A triangular matrix's determinant is the product of its diagonal.
    | is_triangular m = product [ m !!! (i,i) | i <- [0..(nrows m)-1] ]
    | otherwise = determinant_recursive
    where
      m' i j = m !!! (i,j)

      det_minor i j = determinant (minor m i j)

      -- Cofactor expansion along the first row.
      determinant_recursive =
        sum [ (-1)^(toInteger j) NP.* (m' 0 j) NP.* (det_minor 0 j)
              | j <- [0..(ncols m)-1] ]


-- | Matrix multiplication.
--
-- Examples:
--
-- >>> let m1 = fromList [[1,2,3], [4,5,6]] :: Mat N2 N3 Int
-- >>> let m2 = fromList [[1,2],[3,4],[5,6]] :: Mat N3 N2 Int
-- >>> m1 * m2
-- ((22,28),(49,64))
--
infixl 7 *
(*) :: (Ring.C a, Arity m, Arity n, Arity p)
    => Mat m n a
    -> Mat n p a
    -> Mat m p a
(*) m1 m2 = construct lambda
  where
    lambda i j =
      sum [(m1 !!! (i,k)) NP.* (m2 !!! (k,j)) | k <- [0..(ncols m1)-1] ]


-- | Entrywise addition and subtraction; 'zero' is the all-zero matrix.
instance (Ring.C a, Arity m, Arity n) => Additive.C (Mat m n a) where

  (Mat rows1) + (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (+)) rows1 rows2

  (Mat rows1) - (Mat rows2) =
    Mat $ V.zipWith (V.zipWith (-)) rows1 rows2

  zero = Mat (V.replicate $ V.replicate (fromInteger 0))


instance (Ring.C a, Arity m, Arity n, m ~ n) => Ring.C (Mat m n a) where
  -- The first * is ring multiplication, the second is matrix
  -- multiplication.
  m1 * m2 = m1 * m2

  -- The multiplicative identity is the identity matrix. Ring.C needs
  -- 'one' or 'fromInteger' to be defined. Without either, the class
  -- defaults for 'one', 'fromInteger' and (^) have no base case at
  -- this type.
  one = construct identity_entry
    where
      identity_entry i j
        | i == j = Ring.one
        | otherwise = Additive.zero


instance (Ring.C a, Arity m, Arity n) => Module.C a (Mat m n a) where
  -- We can multiply a matrix by a scalar of the same type as its
  -- elements. Each entry e becomes e * x.
  x *> (Mat rows) = Mat $ V.map (V.map (NP.* x)) rows


instance (Algebraic.C a,
          ToRational.C a,
          Arity m,
          Arity n)
         => Normed (Mat (S m) (S n) a) where
  -- | Generic p-norms. The usual norm in R^n is (norm_p 2). We treat
  -- all matrices as big vectors.
  --
  -- Examples:
  --
  -- >>> let v1 = vec2d (3,4)
  -- >>> norm_p 1 v1
  -- 7.0
  -- >>> norm_p 2 v1
  -- 5.0
  -- >>> norm_p 1 (vec2d (-3,4))
  -- 7.0
  --
  norm_p p (Mat rows) =
    (root p') $ sum [(fromRational' $ magnitude x)^p' | x <- xs]
    where
      p' = toInteger p
      xs = concat $ V.toList $ V.map V.toList rows
      -- |x|, taken on the exact rational value. Raising the signed
      -- entry to an odd power would let negative entries cancel.
      magnitude e = let q = toRational e in P.max q (negate q)

  -- | The infinity norm: the largest absolute value of any entry.
  --
  -- Examples:
  --
  -- >>> let v1 = vec3d (1,5,2)
  -- >>> norm_infty v1
  -- 5
  -- >>> norm_infty (vec3d (1,-5,2))
  -- 5
  --
  norm_infty (Mat rows) =
    fromRational' $ V.maximum $ V.map (V.maximum . V.map magnitude) rows
    where
      -- |x| on the exact rational value, so that a large negative
      -- entry is not ignored by the maximum.
      magnitude e = let q = toRational e in P.max q (negate q)


-- Vector helpers. We want it to be easy to create low-dimension
-- column vectors, which are nx1 matrices.

-- | Convenient constructor for 2D vectors.
--
-- Examples:
--
-- >>> import Roots.Simple
-- >>> let fst m = m !!! (0,0)
-- >>> let snd m = m !!! (1,0)
-- >>> let h = 0.5 :: Double
-- >>> let g1 m = 1.0 + h NP.* exp(-((fst m)^2))/(1.0 + (snd m)^2)
-- >>> let g2 m = 0.5 + h NP.* atan((fst m)^2 + (snd m)^2)
-- >>> let g u = vec2d ((g1 u), (g2 u))
-- >>> let u0 = vec2d (1.0, 1.0)
-- >>> let eps = 1/(10^9)
-- >>> fixed_point g eps u0
-- ((1.0728549599342185),(1.0820591495686167))
--
vec1d :: (a) -> Mat N1 N1 a
vec1d (x) = Mat (mk1 (mk1 x))

vec2d :: (a,a) -> Mat N2 N1 a
vec2d (x,y) = Mat (mk2 (mk1 x) (mk1 y))

vec3d :: (a,a,a) -> Mat N3 N1 a
vec3d (x,y,z) = Mat (mk3 (mk1 x) (mk1 y) (mk1 z))

vec4d :: (a,a,a,a) -> Mat N4 N1 a
vec4d (w,x,y,z) = Mat (mk4 (mk1 w) (mk1 x) (mk1 y) (mk1 z))

vec5d :: (a,a,a,a,a) -> Mat N5 N1 a
vec5d (v,w,x,y,z) = Mat (mk5 (mk1 v) (mk1 w) (mk1 x) (mk1 y) (mk1 z))

-- Since we commandeered multiplication, we need to create 1x1
-- matrices in order to multiply things.
-- | Wrap a single value in a 1x1 matrix.
--
-- Examples:
--
-- >>> scalar 2 :: Mat1 Int
-- ((2))
--
scalar :: a -> Mat N1 N1 a
scalar x = Mat (mk1 (mk1 x))


-- | The standard inner product of two column vectors. It is computed
-- as the single entry of the 1x1 product (transpose v1)*v2.
--
-- Examples:
--
-- >>> let v1 = vec3d (1,2,3) :: Mat N3 N1 Int
-- >>> let v2 = vec3d (4,5,6) :: Mat N3 N1 Int
-- >>> v1 `dot` v2
-- 32
--
dot :: (RealRing.C a, n ~ N1, m ~ S t, Arity t)
     => Mat m n a
     -> Mat m n a
     -> a
v1 `dot` v2 = ((transpose v1) * v2) !!! (0, 0)


-- | The angle between @v1@ and @v2@ in Euclidean space, in radians.
--
-- Examples:
--
-- >>> let v1 = vec2d (1.0, 0.0)
-- >>> let v2 = vec2d (0.0, 1.0)
-- >>> angle v1 v2 == pi/2.0
-- True
--
angle :: (Transcendental.C a,
          RealRing.C a,
          n ~ N1,
          m ~ S t,
          Arity t,
          ToRational.C a)
      => Mat m n a
      -> Mat m n a
      -> a
angle v1 v2 =
  acos theta
  where
    -- Rounding in the norms can push the cosine slightly outside
    -- [-1, 1] (e.g. for parallel vectors), where acos yields NaN;
    -- clamp it back into acos's domain.
    theta = P.max (-1) (P.min 1 cosine)
    cosine = (recip norms) NP.* (v1 `dot` v2)
    norms = (norm v1) NP.* (norm v2)