]> gitweb.michael.orlitzky.com - numerical-analysis.git/blobdiff - src/Linear/Matrix.hs
New function: Linear.Matrix.identity_matrix.
[numerical-analysis.git] / src / Linear / Matrix.hs
index 20769b4708950d103bd0ec1fd18737d9616cb0ca..c0f56b348f4d4d2fbdf272cbac05e7c799eddce7 100644 (file)
@@ -25,6 +25,7 @@ import Data.Vector.Fixed (
   N5,
   S,
   Z,
+  generate,
   mk1,
   mk2,
   mk3,
@@ -43,7 +44,7 @@ import qualified Data.Vector.Fixed as V (
   zipWith
   )
 import Data.Vector.Fixed.Boxed (Vec)
-import Data.Vector.Fixed.Internal.Arity (Arity, arity)
+import Data.Vector.Fixed.Cont (Arity, arity)
 import Linear.Vector
 import Normed
 
@@ -198,8 +199,6 @@ symmetric m =
 --   entries in the matrix. The i,j entry of the resulting matrix will
 --   have the value returned by lambda i j.
 --
---   TODO: Don't cheat with fromList.
---
 --   Examples:
 --
 --   >>> let lambda i j = i + j
@@ -208,15 +207,25 @@ symmetric m =
 --
 construct :: forall m n a. (Arity m, Arity n)
           => (Int -> Int -> a) -> Mat m n a
-construct lambda = Mat rows
+construct lambda = Mat $ generate make_row
   where
-    -- The arity trick is used in Data.Vector.Fixed.length.
-    imax = (arity (undefined :: m)) - 1
-    jmax = (arity (undefined :: n)) - 1
-    row' i = V.fromList [ lambda i j | j <- [0..jmax] ]
-    rows = V.fromList [ row' i | i <- [0..imax] ]
+    make_row :: Int -> Vec n a
+    make_row i = generate (lambda i)
 
 
+-- | Create an identity matrix with the right dimensions.
+--
+--   Examples:
+--
+--   >>> identity_matrix :: Mat3 Int
+--   ((1,0,0),(0,1,0),(0,0,1))
+--   >>> identity_matrix :: Mat3 Double
+--   ((1.0,0.0,0.0),(0.0,1.0,0.0),(0.0,0.0,1.0))
+--
+identity_matrix :: (Arity m, Ring.C a) => Mat m m a
+identity_matrix =
+  construct (\i j -> if i == j then (fromInteger 1) else (fromInteger 0))
+
 -- | Given a positive-definite matrix @m@, computes the
 --   upper-triangular matrix @r@ with (transpose r)*r == m and all
 --   values on the diagonal of @r@ positive.
@@ -431,7 +440,7 @@ instance (Algebraic.C a,
   --   5.0
   --
   norm_p p (Mat rows) =
-    (root p') $ sum [(fromRational' $ toRational x)^p' | x <- xs]
+    (root p') $ sum [fromRational' (toRational x)^p' | x <- xs]
     where
       p' = toInteger p
       xs = concat $ V.toList $ V.map V.toList rows
@@ -530,13 +539,83 @@ angle v1 v2 =
 --   Examples:
 --
 --   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
---   >>> diagonal m
+--   >>> diagonal_part m
 --   ((1,0,0),(0,5,0),(0,0,9))
 --
-diagonal :: (Arity m, Ring.C a)
+diagonal_part :: (Arity m, Ring.C a)
          => Mat m m a
          -> Mat m m a
-diagonal matrix =
+diagonal_part matrix =
   construct lambda
   where
     lambda i j = if i == j then matrix !!! (i,j) else 0
+
+
+-- | Given a square @matrix@, return a new matrix of the same size
+--   containing only the on-diagonal and below-diagonal entries of
+--   @matrix@. The above-diagonal entries are set to zero.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
+--   >>> lt_part m
+--   ((1,0,0),(4,5,0),(7,8,9))
+--
+lt_part :: (Arity m, Ring.C a)
+        => Mat m m a
+        -> Mat m m a
+lt_part matrix =
+  construct lambda
+  where
+    lambda i j = if i >= j then matrix !!! (i,j) else 0
+
+
+-- | Given a square @matrix@, return a new matrix of the same size
+--   containing only the below-diagonal entries of @matrix@. The on-
+--   and above-diagonal entries are set to zero.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
+--   >>> lt_part_strict m
+--   ((0,0,0),(4,0,0),(7,8,0))
+--
+lt_part_strict :: (Arity m, Ring.C a)
+        => Mat m m a
+        -> Mat m m a
+lt_part_strict matrix =
+  construct lambda
+  where
+    lambda i j = if i > j then matrix !!! (i,j) else 0
+
+
+-- | Given a square @matrix@, return a new matrix of the same size
+--   containing only the on-diagonal and above-diagonal entries of
+--   @matrix@. The below-diagonal entries are set to zero.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
+--   >>> ut_part m
+--   ((1,2,3),(0,5,6),(0,0,9))
+--
+ut_part :: (Arity m, Ring.C a)
+        => Mat m m a
+        -> Mat m m a
+ut_part = transpose . lt_part . transpose
+
+
+-- | Given a square @matrix@, return a new matrix of the same size
+--   containing only the above-diagonal entries of @matrix@. The on-
+--   and below-diagonal entries are set to zero.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Mat3 Int
+--   >>> ut_part_strict m
+--   ((0,2,3),(0,0,6),(0,0,0))
+--
+ut_part_strict :: (Arity m, Ring.C a)
+        => Mat m m a
+        -> Mat m m a
+ut_part_strict = transpose . lt_part_strict . transpose