]> gitweb.michael.orlitzky.com - numerical-analysis.git/commitdiff
Add tolerant versions of is_{upper,lower}_triangular.
authorMichael Orlitzky <michael@orlitzky.com>
Mon, 3 Feb 2014 01:21:33 +0000 (20:21 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Mon, 3 Feb 2014 01:21:33 +0000 (20:21 -0500)
Fix the QR factorization code.
Add tests for the QR behavior.

src/Linear/Matrix.hs
src/Linear/QR.hs

index 69a74b50eeacf7ea944c790d3858dd0cba12c266..c6f4a83c71af1c5231113ec555fa79f2dcfcc038 100644 (file)
@@ -48,11 +48,13 @@ import Data.Vector.Fixed.Cont (Arity, arity)
 import Linear.Vector
 import Normed
 
-import NumericPrelude hiding ((*), abs)
-import qualified NumericPrelude as NP ((*))
+import NumericPrelude hiding ( (*), abs )
+import qualified NumericPrelude as NP ( (*) )
+import qualified Algebra.Absolute as Absolute ( C )
+import Algebra.Absolute ( abs )
+import qualified Algebra.Additive as Additive
 import qualified Algebra.Algebraic as Algebraic
 import Algebra.Algebraic (root)
-import qualified Algebra.Additive as Additive
 import qualified Algebra.Ring as Ring
 import qualified Algebra.Module as Module
 import qualified Algebra.RealRing as RealRing
@@ -250,21 +252,26 @@ cholesky m = construct r
 
 
 -- | Returns True if the given matrix is upper-triangular, and False
---   otherwise.
+--   otherwise. The parameter @epsilon@ lets the caller choose a
+--   tolerance.
 --
 --   Examples:
 --
---   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
+--   >>> let m = fromList [[1,1],[1e-12,1]] :: Mat2 Double
 --   >>> is_upper_triangular m
 --   False
---
---   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
---   >>> is_upper_triangular m
+--   >>> is_upper_triangular' 1e-10 m
 --   True
 --
-is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
-                    => Mat m n a -> Bool
-is_upper_triangular m =
+--   TODO:
+--
+--     1. Don't cheat with lists.
+--
+is_upper_triangular' :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
+                    => a -- ^ The tolerance @epsilon@.
+                    -> Mat m n a
+                    -> Bool
+is_upper_triangular' epsilon m =
   and $ concat results
   where
     results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
@@ -272,11 +279,36 @@ is_upper_triangular m =
     test :: Int -> Int -> Bool
     test i j
       | i <= j = True
-      | otherwise = m !!! (i,j) == 0
+      -- use "less than or equal to" so zero is a valid epsilon
+      | otherwise = abs (m !!! (i,j)) <= epsilon
+
+
+-- | Returns True if the given matrix is upper-triangular, and False
+--   otherwise. A specialized version of 'is_upper_triangular\'' with
+--   @epsilon = 0@.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
+--   >>> is_upper_triangular m
+--   False
+--
+--   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
+--   >>> is_upper_triangular m
+--   True
+--
+--   TODO:
+--
+--     1. The Ord constraint is too strong here, Eq would suffice.
+--
+is_upper_triangular :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
+                    => Mat m n a -> Bool
+is_upper_triangular = is_upper_triangular' 0
 
 
 -- | Returns True if the given matrix is lower-triangular, and False
---   otherwise.
+--   otherwise. This is a specialized version of 'is_lower_triangular\''
+--   with @epsilon = 0@.
 --
 --   Examples:
 --
@@ -288,8 +320,9 @@ is_upper_triangular m =
 --   >>> is_lower_triangular m
 --   False
 --
-is_lower_triangular :: (Eq a,
+is_lower_triangular :: (Ord a,
                         Ring.C a,
+                        Absolute.C a,
                         Arity m,
                         Arity n)
                     => Mat m n a
@@ -297,6 +330,29 @@ is_lower_triangular :: (Eq a,
 is_lower_triangular = is_upper_triangular . transpose
 
 
+-- | Returns True if the given matrix is lower-triangular, and False
+--   otherwise. The parameter @epsilon@ lets the caller choose a
+--   tolerance.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,1e-12],[1,1]] :: Mat2 Double
+--   >>> is_lower_triangular m
+--   False
+--   >>> is_lower_triangular' 1e-12 m
+--   True
+--
+is_lower_triangular' :: (Ord a,
+                         Ring.C a,
+                         Absolute.C a,
+                         Arity m,
+                         Arity n)
+                    => a -- ^ The tolerance @epsilon@.
+                    -> Mat m n a
+                    -> Bool
+is_lower_triangular' epsilon = (is_upper_triangular' epsilon) . transpose
+
+
 -- | Returns True if the given matrix is triangular, and False
 --   otherwise.
 --
@@ -314,8 +370,9 @@ is_lower_triangular = is_upper_triangular . transpose
 --   >>> is_triangular m
 --   False
 --
-is_triangular :: (Eq a,
+is_triangular :: (Ord a,
                   Ring.C a,
+                  Absolute.C a,
                   Arity m,
                   Arity n)
               => Mat m n a
@@ -353,8 +410,9 @@ class (Eq a, Ring.C a) => Determined p a where
 instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
   determinant (Mat rows) = (V.head . V.head) rows
 
-instance (Eq a,
+instance (Ord a,
           Ring.C a,
+          Absolute.C a,
           Arity n,
           Determined (Mat (S n) (S n)) a)
          => Determined (Mat (S (S n)) (S (S n))) a where
index 58027bb0be7a8b574fe78bd437a683c1b30508e2..ea72958d74163c30679ce75231b5a48a8f0f45dd 100644 (file)
@@ -12,6 +12,7 @@ import qualified Algebra.Ring as Ring ( C )
 import qualified Algebra.Algebraic as Algebraic ( C )
 import Data.Vector.Fixed ( ifoldl )
 import Data.Vector.Fixed.Cont ( Arity )
+import Debug.Trace
 import NumericPrelude hiding ( (*) )
 
 import Linear.Matrix (
@@ -30,20 +31,52 @@ import Linear.Matrix (
 --
 --   Examples (Watkins, p. 193):
 --
---   >>> import Linear.Matrix ( Mat2, fromList )
+--   >>> import Normed ( Normed(..) )
+--   >>> import Linear.Vector ( Vec2, Vec3 )
+--   >>> import Linear.Matrix ( Mat2, Mat3, fromList, frobenius_norm )
+--   >>> import qualified Data.Vector.Fixed as V ( map )
+--
 --   >>> let m  = givens_rotator 0 1 1 1 :: Mat2 Double
 --   >>> let m2 = fromList [[1, -1],[1, 1]] :: Mat2 Double
 --   >>> m == (1 / (sqrt 2) :: Double) *> m2
 --   True
 --
-givens_rotator :: forall m a. (Ring.C a, Algebraic.C a, Arity m)
+--   >>> let m = fromList [[2,3],[5,7]] :: Mat2 Double
+--   >>> let rot =  givens_rotator 0 1 2.0 5.0 :: Mat2 Double
+--   >>> ((transpose rot) * m) !!! (1,0) < 1e-12
+--   True
+--   >>> let (Mat rows) = rot
+--   >>> let (Mat cols) = transpose rot
+--   >>> V.map norm rows :: Vec2 Double
+--   fromList [1.0,1.0]
+--   >>> V.map norm cols :: Vec2 Double
+--   fromList [1.0,1.0]
+--
+--   >>> let m = fromList [[12,-51,4],[6,167,-68],[-4,24,-41]] :: Mat3 Double
+--   >>> let rot = givens_rotator 1 2 6 (-4) :: Mat3 Double
+--   >>> let ex_rot_r1 = [1,0,0] :: [Double]
+--   >>> let ex_rot_r2 = [0,0.83205,-0.55470] :: [Double]
+--   >>> let ex_rot_r3 = [0, 0.55470, 0.83205] :: [Double]
+--   >>> let ex_rot = fromList [ex_rot_r1, ex_rot_r2, ex_rot_r3] :: Mat3 Double
+--   >>> frobenius_norm ((transpose rot) - ex_rot) < 1e-4
+--   True
+--   >>> ((transpose rot) * m) !!! (2,0) == 0
+--   True
+--   >>> let (Mat rows) = rot
+--   >>> let (Mat cols) = transpose rot
+--   >>> V.map norm rows :: Vec3 Double
+--   fromList [1.0,1.0,1.0]
+--   >>> V.map norm cols :: Vec3 Double
+--   fromList [1.0,1.0,1.0]
+--
+givens_rotator :: forall m a. (Eq a, Ring.C a, Algebraic.C a, Arity m)
                => Int -> Int -> a -> a -> Mat m m a
 givens_rotator i j xi xj =
   construct f
   where
     xnorm = sqrt $ xi^2 + xj^2
-    c = xi / xnorm
-    s = xj / xnorm
+    c = if xnorm == (fromInteger 0) then (fromInteger 1) else xi / xnorm
+    s = if xnorm == (fromInteger 0) then (fromInteger 0) else xj / xnorm
 
     f :: Int -> Int -> a
     f y z
@@ -65,7 +98,40 @@ givens_rotator i j xi xj =
 --   factorization. We keep the pair updated by multiplying @q@ and
 --   @r@ by the new rotator (or its transpose).
 --
-qr :: forall m n a. (Arity m, Arity n, Algebraic.C a, Ring.C a)
+--   Examples:
+--
+--   >>> import Linear.Matrix
+--
+--   >>> let m = fromList [[1,2],[1,3]] :: Mat2 Double
+--   >>> let (q,r) = qr m
+--   >>> let c = (1 / (sqrt 2 :: Double))
+--   >>> let ex_q = c *> (fromList [[1,-1],[1,1]] :: Mat2 Double)
+--   >>> let ex_r = c *> (fromList [[2,5],[0,1]] :: Mat2 Double)
+--   >>> frobenius_norm (q - ex_q) == 0
+--   True
+--   >>> frobenius_norm (r - ex_r) == 0
+--   True
+--   >>> let m' = q*r
+--   >>> frobenius_norm (m - m') < 1e-10
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+--   >>> let m = fromList [[2,3],[5,7]] :: Mat2 Double
+--   >>> let (q,r) = qr m
+--   >>> frobenius_norm (m - (q*r)) < 1e-12
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+--   >>> let m = fromList [[12,-51,4],[6,167,-68],[-4,24,-41]] :: Mat3 Double
+--   >>> let (q,r) = qr m
+--   >>> frobenius_norm (m - (q*r)) < 1e-12
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+qr :: forall m n a. (Arity m, Arity n, Eq a, Algebraic.C a, Ring.C a, Show a)
    => Mat m n a -> (Mat m m a, Mat m n a)
 qr matrix =
   ifoldl col_function initial_qr columns
@@ -83,10 +149,13 @@ qr matrix =
     -- | Process the entries in a column, doing basically the same
     --   thing as col_dunction does. It updates the QR factorization,
     --   maybe, and returns the current one.
-    f col_idx (q,r) idx x
-      | idx <= col_idx = (q,r) -- leave it alone.
-      | otherwise =
-          (q*rotator, (transpose rotator)*r)
+    f col_idx (q,r) idx _ -- ignore the current element
+      | idx <= col_idx = (q,r)
+--          trace ("---------------\nidx: " ++ (show idx) ++ ";\ncol_idx: " ++ (show col_idx) ++ "; leaving it alone") (q,r) -- leave it alone.
+      | otherwise = (q*rotator, (transpose rotator)*r)
+--          trace ("---------------\nidx: " ++ (show idx) ++ ";\ncol_idx: " ++ (show col_idx) ++ ";\nupdating Q and R;\nq: " ++ (show q) ++ ";\nr " ++ (show r) ++ ";\nnew q: " ++ (show $ q*rotator) ++ ";\nnew r: " ++ (show $ (transpose rotator)*r) ++ ";\ny: " ++ (show y) ++ ";\nr[i,j]: " ++ (show (r !!! (col_idx, col_idx))))
+--          (q*rotator, (transpose rotator)*r)
           where
+            y = r !!! (idx, col_idx)
             rotator :: Mat m m a
-            rotator = givens_rotator col_idx idx (r !!! (idx, col_idx)) x
+            rotator = givens_rotator col_idx idx (r !!! (col_idx, col_idx)) y