]> gitweb.michael.orlitzky.com - numerical-analysis.git/commitdiff
Add tolerant versions of is_{upper,lower}_triangular.
authorMichael Orlitzky <michael@orlitzky.com>
Mon, 3 Feb 2014 01:21:33 +0000 (20:21 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Mon, 3 Feb 2014 01:21:33 +0000 (20:21 -0500)
Fix the QR factorization code.
Add tests for the QR behavior.

src/Linear/Matrix.hs
src/Linear/QR.hs

index 69a74b50eeacf7ea944c790d3858dd0cba12c266..c6f4a83c71af1c5231113ec555fa79f2dcfcc038 100644 (file)
@@ -48,11 +48,13 @@ import Data.Vector.Fixed.Cont (Arity, arity)
 import Linear.Vector
 import Normed
 
 import Linear.Vector
 import Normed
 
-import NumericPrelude hiding ((*), abs)
-import qualified NumericPrelude as NP ((*))
+import NumericPrelude hiding ( (*), abs )
+import qualified NumericPrelude as NP ( (*) )
+import qualified Algebra.Absolute as Absolute ( C )
+import Algebra.Absolute ( abs )
+import qualified Algebra.Additive as Additive
 import qualified Algebra.Algebraic as Algebraic
 import Algebra.Algebraic (root)
 import qualified Algebra.Algebraic as Algebraic
 import Algebra.Algebraic (root)
-import qualified Algebra.Additive as Additive
 import qualified Algebra.Ring as Ring
 import qualified Algebra.Module as Module
 import qualified Algebra.RealRing as RealRing
 import qualified Algebra.Ring as Ring
 import qualified Algebra.Module as Module
 import qualified Algebra.RealRing as RealRing
@@ -250,21 +252,26 @@ cholesky m = construct r
 
 
 -- | Returns True if the given matrix is upper-triangular, and False
 
 
 -- | Returns True if the given matrix is upper-triangular, and False
---   otherwise.
+--   otherwise. The parameter @epsilon@ lets the caller choose a
+--   tolerance.
 --
 --   Examples:
 --
 --
 --   Examples:
 --
---   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
+--   >>> let m = fromList [[1,1],[1e-12,1]] :: Mat2 Double
 --   >>> is_upper_triangular m
 --   False
 --   >>> is_upper_triangular m
 --   False
---
---   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
---   >>> is_upper_triangular m
+--   >>> is_upper_triangular' 1e-10 m
 --   True
 --
 --   True
 --
-is_upper_triangular :: (Eq a, Ring.C a, Arity m, Arity n)
-                    => Mat m n a -> Bool
-is_upper_triangular m =
+--   TODO:
+--
+--     1. Don't cheat with lists.
+--
+is_upper_triangular' :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
+                    => a -- ^ The tolerance @epsilon@.
+                    -> Mat m n a
+                    -> Bool
+is_upper_triangular' epsilon m =
   and $ concat results
   where
     results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
   and $ concat results
   where
     results = [[ test i j | i <- [0..(nrows m)-1]] | j <- [0..(ncols m)-1] ]
@@ -272,11 +279,36 @@ is_upper_triangular m =
     test :: Int -> Int -> Bool
     test i j
       | i <= j = True
     test :: Int -> Int -> Bool
     test i j
       | i <= j = True
-      | otherwise = m !!! (i,j) == 0
+      -- use "less than or equal to" so zero is a valid epsilon
+      | otherwise = abs (m !!! (i,j)) <= epsilon
+
+
+-- | Returns True if the given matrix is upper-triangular, and False
+--   otherwise. A specialized version of 'is_upper_triangular\'' with
+--   @epsilon = 0@.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,0],[1,1]] :: Mat2 Int
+--   >>> is_upper_triangular m
+--   False
+--
+--   >>> let m = fromList [[1,2],[0,3]] :: Mat2 Int
+--   >>> is_upper_triangular m
+--   True
+--
+--   TODO:
+--
+--     1. The Ord constraint is too strong here, Eq would suffice.
+--
+is_upper_triangular :: (Ord a, Ring.C a, Absolute.C a, Arity m, Arity n)
+                    => Mat m n a -> Bool
+is_upper_triangular = is_upper_triangular' 0
 
 
 -- | Returns True if the given matrix is lower-triangular, and False
 
 
 -- | Returns True if the given matrix is lower-triangular, and False
---   otherwise.
+--   otherwise. This is a specialized version of 'is_lower_triangular\''
+--   with @epsilon = 0@.
 --
 --   Examples:
 --
 --
 --   Examples:
 --
@@ -288,8 +320,9 @@ is_upper_triangular m =
 --   >>> is_lower_triangular m
 --   False
 --
 --   >>> is_lower_triangular m
 --   False
 --
-is_lower_triangular :: (Eq a,
+is_lower_triangular :: (Ord a,
                         Ring.C a,
                         Ring.C a,
+                        Absolute.C a,
                         Arity m,
                         Arity n)
                     => Mat m n a
                         Arity m,
                         Arity n)
                     => Mat m n a
@@ -297,6 +330,29 @@ is_lower_triangular :: (Eq a,
 is_lower_triangular = is_upper_triangular . transpose
 
 
 is_lower_triangular = is_upper_triangular . transpose
 
 
+-- | Returns True if the given matrix is lower-triangular, and False
+--   otherwise. The parameter @epsilon@ lets the caller choose a
+--   tolerance.
+--
+--   Examples:
+--
+--   >>> let m = fromList [[1,1e-12],[1,1]] :: Mat2 Double
+--   >>> is_lower_triangular m
+--   False
+--   >>> is_lower_triangular' 1e-12 m
+--   True
+--
+is_lower_triangular' :: (Ord a,
+                         Ring.C a,
+                         Absolute.C a,
+                         Arity m,
+                         Arity n)
+                    => a -- ^ The tolerance @epsilon@.
+                    -> Mat m n a
+                    -> Bool
+is_lower_triangular' epsilon = (is_upper_triangular' epsilon) . transpose
+
+
 -- | Returns True if the given matrix is triangular, and False
 --   otherwise.
 --
 -- | Returns True if the given matrix is triangular, and False
 --   otherwise.
 --
@@ -314,8 +370,9 @@ is_lower_triangular = is_upper_triangular . transpose
 --   >>> is_triangular m
 --   False
 --
 --   >>> is_triangular m
 --   False
 --
-is_triangular :: (Eq a,
+is_triangular :: (Ord a,
                   Ring.C a,
                   Ring.C a,
+                  Absolute.C a,
                   Arity m,
                   Arity n)
               => Mat m n a
                   Arity m,
                   Arity n)
               => Mat m n a
@@ -353,8 +410,9 @@ class (Eq a, Ring.C a) => Determined p a where
 instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
   determinant (Mat rows) = (V.head . V.head) rows
 
 instance (Eq a, Ring.C a) => Determined (Mat (S Z) (S Z)) a where
   determinant (Mat rows) = (V.head . V.head) rows
 
-instance (Eq a,
+instance (Ord a,
           Ring.C a,
           Ring.C a,
+          Absolute.C a,
           Arity n,
           Determined (Mat (S n) (S n)) a)
          => Determined (Mat (S (S n)) (S (S n))) a where
           Arity n,
           Determined (Mat (S n) (S n)) a)
          => Determined (Mat (S (S n)) (S (S n))) a where
index 58027bb0be7a8b574fe78bd437a683c1b30508e2..ea72958d74163c30679ce75231b5a48a8f0f45dd 100644 (file)
@@ -12,6 +12,7 @@ import qualified Algebra.Ring as Ring ( C )
 import qualified Algebra.Algebraic as Algebraic ( C )
 import Data.Vector.Fixed ( ifoldl )
 import Data.Vector.Fixed.Cont ( Arity )
 import qualified Algebra.Algebraic as Algebraic ( C )
 import Data.Vector.Fixed ( ifoldl )
 import Data.Vector.Fixed.Cont ( Arity )
+import Debug.Trace
 import NumericPrelude hiding ( (*) )
 
 import Linear.Matrix (
 import NumericPrelude hiding ( (*) )
 
 import Linear.Matrix (
@@ -30,20 +31,52 @@ import Linear.Matrix (
 --
 --   Examples (Watkins, p. 193):
 --
 --
 --   Examples (Watkins, p. 193):
 --
---   >>> import Linear.Matrix ( Mat2, fromList )
+--   >>> import Normed ( Normed(..) )
+--   >>> import Linear.Vector ( Vec2, Vec3 )
+--   >>> import Linear.Matrix ( Mat2, Mat3, fromList, frobenius_norm )
+--   >>> import qualified Data.Vector.Fixed as V ( map )
+--
 --   >>> let m  = givens_rotator 0 1 1 1 :: Mat2 Double
 --   >>> let m2 = fromList [[1, -1],[1, 1]] :: Mat2 Double
 --   >>> m == (1 / (sqrt 2) :: Double) *> m2
 --   True
 --
 --   >>> let m  = givens_rotator 0 1 1 1 :: Mat2 Double
 --   >>> let m2 = fromList [[1, -1],[1, 1]] :: Mat2 Double
 --   >>> m == (1 / (sqrt 2) :: Double) *> m2
 --   True
 --
-givens_rotator :: forall m a. (Ring.C a, Algebraic.C a, Arity m)
+--   >>> let m = fromList [[2,3],[5,7]] :: Mat2 Double
+--   >>> let rot =  givens_rotator 0 1 2.0 5.0 :: Mat2 Double
+--   >>> ((transpose rot) * m) !!! (1,0) < 1e-12
+--   True
+--   >>> let (Mat rows) = rot
+--   >>> let (Mat cols) = transpose rot
+--   >>> V.map norm rows :: Vec2 Double
+--   fromList [1.0,1.0]
+--   >>> V.map norm cols :: Vec2 Double
+--   fromList [1.0,1.0]
+--
+--   >>> let m = fromList [[12,-51,4],[6,167,-68],[-4,24,-41]] :: Mat3 Double
+--   >>> let rot = givens_rotator 1 2 6 (-4) :: Mat3 Double
+--   >>> let ex_rot_r1 = [1,0,0] :: [Double]
+--   >>> let ex_rot_r2 = [0,0.83205,-0.55470] :: [Double]
+--   >>> let ex_rot_r3 = [0, 0.55470, 0.83205] :: [Double]
+--   >>> let ex_rot = fromList [ex_rot_r1, ex_rot_r2, ex_rot_r3] :: Mat3 Double
+--   >>> frobenius_norm ((transpose rot) - ex_rot) < 1e-4
+--   True
+--   >>> ((transpose rot) * m) !!! (2,0) == 0
+--   True
+--   >>> let (Mat rows) = rot
+--   >>> let (Mat cols) = transpose rot
+--   >>> V.map norm rows :: Vec3 Double
+--   fromList [1.0,1.0,1.0]
+--   >>> V.map norm cols :: Vec3 Double
+--   fromList [1.0,1.0,1.0]
+--
+givens_rotator :: forall m a. (Eq a, Ring.C a, Algebraic.C a, Arity m)
                => Int -> Int -> a -> a -> Mat m m a
 givens_rotator i j xi xj =
   construct f
   where
     xnorm = sqrt $ xi^2 + xj^2
                => Int -> Int -> a -> a -> Mat m m a
 givens_rotator i j xi xj =
   construct f
   where
     xnorm = sqrt $ xi^2 + xj^2
-    c = xi / xnorm
-    s = xj / xnorm
+    c = if xnorm == (fromInteger 0) then (fromInteger 1) else xi / xnorm
+    s = if xnorm == (fromInteger 0) then (fromInteger 0) else xj / xnorm
 
     f :: Int -> Int -> a
     f y z
 
     f :: Int -> Int -> a
     f y z
@@ -65,7 +98,40 @@ givens_rotator i j xi xj =
 --   factorization. We keep the pair updated by multiplying @q@ and
 --   @r@ by the new rotator (or its transpose).
 --
 --   factorization. We keep the pair updated by multiplying @q@ and
 --   @r@ by the new rotator (or its transpose).
 --
-qr :: forall m n a. (Arity m, Arity n, Algebraic.C a, Ring.C a)
+--   Examples:
+--
+--   >>> import Linear.Matrix
+--
+--   >>> let m = fromList [[1,2],[1,3]] :: Mat2 Double
+--   >>> let (q,r) = qr m
+--   >>> let c = (1 / (sqrt 2 :: Double))
+--   >>> let ex_q = c *> (fromList [[1,-1],[1,1]] :: Mat2 Double)
+--   >>> let ex_r = c *> (fromList [[2,5],[0,1]] :: Mat2 Double)
+--   >>> frobenius_norm (q - ex_q) == 0
+--   True
+--   >>> frobenius_norm (r - ex_r) == 0
+--   True
+--   >>> let m' = q*r
+--   >>> frobenius_norm (m - m') < 1e-10
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+--   >>> let m = fromList [[2,3],[5,7]] :: Mat2 Double
+--   >>> let (q,r) = qr m
+--   >>> frobenius_norm (m - (q*r)) < 1e-12
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+--   >>> let m = fromList [[12,-51,4],[6,167,-68],[-4,24,-41]] :: Mat3 Double
+--   >>> let (q,r) = qr m
+--   >>> frobenius_norm (m - (q*r)) < 1e-12
+--   True
+--   >>> is_upper_triangular' 1e-10 r
+--   True
+--
+qr :: forall m n a. (Arity m, Arity n, Eq a, Algebraic.C a, Ring.C a, Show a)
    => Mat m n a -> (Mat m m a, Mat m n a)
 qr matrix =
   ifoldl col_function initial_qr columns
    => Mat m n a -> (Mat m m a, Mat m n a)
 qr matrix =
   ifoldl col_function initial_qr columns
@@ -83,10 +149,13 @@ qr matrix =
     -- | Process the entries in a column, doing basically the same
     --   thing as col_dunction does. It updates the QR factorization,
     --   maybe, and returns the current one.
     -- | Process the entries in a column, doing basically the same
     --   thing as col_dunction does. It updates the QR factorization,
     --   maybe, and returns the current one.
-    f col_idx (q,r) idx x
-      | idx <= col_idx = (q,r) -- leave it alone.
-      | otherwise =
-          (q*rotator, (transpose rotator)*r)
+    f col_idx (q,r) idx _ -- ignore the current element
+      | idx <= col_idx = (q,r)
+--          trace ("---------------\nidx: " ++ (show idx) ++ ";\ncol_idx: " ++ (show col_idx) ++ "; leaving it alone") (q,r) -- leave it alone.
+      | otherwise = (q*rotator, (transpose rotator)*r)
+--          trace ("---------------\nidx: " ++ (show idx) ++ ";\ncol_idx: " ++ (show col_idx) ++ ";\nupdating Q and R;\nq: " ++ (show q) ++ ";\nr " ++ (show r) ++ ";\nnew q: " ++ (show $ q*rotator) ++ ";\nnew r: " ++ (show $ (transpose rotator)*r) ++ ";\ny: " ++ (show y) ++ ";\nr[i,j]: " ++ (show (r !!! (col_idx, col_idx))))
+--          (q*rotator, (transpose rotator)*r)
           where
           where
+            y = r !!! (idx, col_idx)
             rotator :: Mat m m a
             rotator :: Mat m m a
-            rotator = givens_rotator col_idx idx (r !!! (idx, col_idx)) x
+            rotator = givens_rotator col_idx idx (r !!! (col_idx, col_idx)) y