]> gitweb.michael.orlitzky.com - numerical-analysis.git/blobdiff - src/Roots/Fast.hs
Clean up imports everywhere.
[numerical-analysis.git] / src / Roots / Fast.hs
index cda999ceab62a2fcc3b1f77e07560fb22403d6d3..e5321c9fa82142c55b487104b5df1ec5d1fb9b70 100644 (file)
@@ -1,17 +1,34 @@
+{-# LANGUAGE RebindableSyntax #-}
+
 -- | The Roots.Fast module contains faster implementations of the
 --   'Roots.Simple' algorithms. Generally, we will pass precomputed
 --   values to the next iteration of a function rather than passing
 --   the function and the points at which to (re)evaluate it.
 
-module Roots.Fast
+module Roots.Fast (
+  bisect,
+  fixed_point_iterations,
+  fixed_point_with_iterations,
+  has_root,
+  trisect )
 where
 
-import Data.List (find)
+import Data.List ( find )
+import Data.Maybe ( fromMaybe )
+
+import Normed ( Normed(..) )
 
-import Vector
+import NumericPrelude hiding ( abs )
+import qualified Algebra.Absolute as Absolute ( C )
+import qualified Algebra.Additive as Additive ( C )
+import qualified Algebra.Algebraic as Algebraic ( C )
+import qualified Algebra.RealRing as RealRing ( C )
+import qualified Algebra.RealField as RealField ( C )
 
 
-has_root :: (Fractional a, Ord a, Ord b, Num b)
+has_root :: (RealField.C a,
+             RealRing.C b,
+             Absolute.C b)
          => (a -> b) -- ^ The function @f@
          -> a       -- ^ The \"left\" endpoint, @a@
          -> a       -- ^ The \"right\" endpoint, @b@
@@ -20,39 +37,27 @@ has_root :: (Fractional a, Ord a, Ord b, Num b)
          -> Maybe b -- ^ Precoumpted f(a)
          -> Maybe b -- ^ Precoumpted f(b)
          -> Bool
-has_root f a b epsilon f_of_a f_of_b =
-  if not ((signum (f_of_a')) * (signum (f_of_b')) == 1) then
-    -- We don't care about epsilon here, there's definitely a root!
-    True
-  else
-    if (b - a) <= epsilon' then
-      -- Give up, return false.
-      False
-    else
-      -- If either [a,c] or [c,b] have roots, we do too.
+has_root f a b epsilon f_of_a f_of_b
+  | (signum (f_of_a')) * (signum (f_of_b')) /= 1 = True
+  | (b - a) <= epsilon' = False
+  | otherwise =
       (has_root f a c (Just epsilon') (Just f_of_a') Nothing) ||
         (has_root f c b (Just epsilon') Nothing (Just f_of_b'))
   where
     -- If the size of the smallest subinterval is not specified,
     -- assume we just want to check once on all of [a,b].
-    epsilon' = case epsilon of
-                 Nothing -> (b-a)
-                 Just eps -> eps
+    epsilon' = fromMaybe (b-a) epsilon
 
     -- Compute f(a) and f(b) only if needed.
-    f_of_a'  = case f_of_a of
-                 Nothing -> f a
-                 Just v -> v
-
-    f_of_b'  = case f_of_b of
-                 Nothing -> f b
-                 Just v -> v
+    f_of_a'  = fromMaybe (f a) f_of_a
+    f_of_b'  = fromMaybe (f b) f_of_b
 
     c = (a + b)/2
 
 
-
-bisect :: (Fractional a, Ord a, Num b, Ord b)
+bisect :: (RealField.C a,
+           RealRing.C b,
+           Absolute.C b)
        => (a -> b) -- ^ The function @f@ whose root we seek
        -> a       -- ^ The \"left\" endpoint of the interval, @a@
        -> a       -- ^ The \"right\" endpoint of the interval, @b@
@@ -76,25 +81,64 @@ bisect f a b epsilon f_of_a f_of_b
         else bisect f c b epsilon (Just f_of_c') (Just f_of_b')
   where
     -- Compute f(a) and f(b) only if needed.
-    f_of_a'  = case f_of_a of
-                 Nothing -> f a
-                 Just v -> v
-
-    f_of_b'  = case f_of_b of
-                 Nothing -> f b
-                 Just v -> v
+    f_of_a'  = fromMaybe (f a) f_of_a
+    f_of_b'  = fromMaybe (f b) f_of_b
 
     c = (a + b) / 2
 
 
 
+trisect :: (RealField.C a,
+            RealRing.C b,
+            Absolute.C b)
+        => (a -> b) -- ^ The function @f@ whose root we seek
+        -> a       -- ^ The \"left\" endpoint of the interval, @a@
+        -> a       -- ^ The \"right\" endpoint of the interval, @b@
+        -> a       -- ^ The tolerance, @epsilon@
+        -> Maybe b -- ^ Precomputed f(a)
+        -> Maybe b -- ^ Precomputed f(b)
+        -> Maybe a
+trisect f a b epsilon f_of_a f_of_b
+  -- We pass @epsilon@ to the 'has_root' function because if we want a
+  -- result within epsilon of the true root, we need to know that
+  -- there *is* a root within an interval of length epsilon.
+  | not (has_root f a b (Just epsilon) (Just f_of_a') (Just f_of_b')) = Nothing
+  | f_of_a' == 0 = Just a
+  | f_of_b' == 0 = Just b
+  | otherwise =
+      -- Use a 'prime' just for consistency.
+    let (a', b', fa', fb')
+          | has_root f d b (Just epsilon) (Just f_of_d') (Just f_of_b') =
+              (d, b, f_of_d', f_of_b')
+          | has_root f c d (Just epsilon) (Just f_of_c') (Just f_of_d') =
+              (c, d, f_of_c', f_of_d')
+          | otherwise =
+              (a, c, f_of_a', f_of_c')
+    in
+      if (b-a) < 2*epsilon
+      then Just ((b+a)/2)
+      else trisect f a' b' epsilon (Just fa') (Just fb')
+  where
+    -- Compute f(a) and f(b) only if needed.
+    f_of_a'  = fromMaybe (f a) f_of_a
+    f_of_b'  = fromMaybe (f b) f_of_b
+
+    c = (2*a + b) / 3
+
+    d = (a + 2*b) / 3
+
+    f_of_c' = f c
+    f_of_d' = f d
+
+
+
 -- | Iterate the function @f@ with the initial guess @x0@ in hopes of
 --   finding a fixed point.
 fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                        -> a       -- ^ The initial value @x0@.
                        -> [a]     -- ^ The resulting sequence of x_{n}.
-fixed_point_iterations f x0 =
-  iterate f x0
+fixed_point_iterations =
+  iterate
 
 
 -- | Find a fixed point of the function @f@ with the search starting
@@ -104,7 +148,10 @@ fixed_point_iterations f x0 =
 --
 --   We also return the number of iterations required.
 --
-fixed_point_with_iterations :: (Vector a, RealFrac b)
+fixed_point_with_iterations :: (Normed a,
+                                Additive.C a,
+                                RealField.C b,
+                                Algebraic.C b)
                             => (a -> a)  -- ^ The function @f@ to iterate.
                             -> b        -- ^ The tolerance, @epsilon@.
                             -> a        -- ^ The initial value @x0@.
@@ -115,7 +162,7 @@ fixed_point_with_iterations f epsilon x0 =
     xn = fixed_point_iterations f x0
     xn_plus_one = tail xn
 
-    abs_diff v w = norm_2 (v - w)
+    abs_diff v w = norm (v - w)
 
     -- The nth entry in this list is the absolute value of x_{n} -
     -- x_{n+1}.
@@ -133,4 +180,3 @@ fixed_point_with_iterations f epsilon x0 =
     -- "safe" since the list is infinite. We'll succeed or loop
     -- forever.
     Just winning_pair = find (\(_, diff) -> diff < epsilon) pairs
-