]> gitweb.michael.orlitzky.com - numerical-analysis.git/blobdiff - src/Roots/Fast.hs
Remove assumptions on the Normed class.
[numerical-analysis.git] / src / Roots / Fast.hs
index 8e49750e650a12f727d2c63282686bc0bfa7dfc1..0deb1fd6237a5909ea7d15a252f09b524093d1bf 100644 (file)
@@ -1,3 +1,5 @@
+{-# LANGUAGE RebindableSyntax #-}
+
 -- | The Roots.Fast module contains faster implementations of the
 --   'Roots.Simple' algorithms. Generally, we will pass precomputed
 --   values to the next iteration of a function rather than passing
@@ -6,7 +8,21 @@
 module Roots.Fast
 where
 
-has_root :: (Fractional a, Ord a, Ord b, Num b)
+import Data.List (find)
+
+import Normed
+
+import NumericPrelude hiding (abs)
+import qualified Algebra.Absolute as Absolute
+import qualified Algebra.Additive as Additive
+import qualified Algebra.Algebraic as Algebraic
+import qualified Algebra.Field as Field
+import qualified Algebra.RealRing as RealRing
+import qualified Algebra.RealField as RealField
+
+has_root :: (RealField.C a,
+             RealRing.C b,
+             Absolute.C b)
          => (a -> b) -- ^ The function @f@
          -> a       -- ^ The \"left\" endpoint, @a@
          -> a       -- ^ The \"right\" endpoint, @b@
@@ -46,8 +62,9 @@ has_root f a b epsilon f_of_a f_of_b =
     c = (a + b)/2
 
 
-
-bisect :: (Fractional a, Ord a, Num b, Ord b)
+bisect :: (RealField.C a,
+           RealRing.C b,
+           Absolute.C b)
        => (a -> b) -- ^ The function @f@ whose root we seek
        -> a       -- ^ The \"left\" endpoint of the interval, @a@
        -> a       -- ^ The \"right\" endpoint of the interval, @b@
@@ -80,3 +97,55 @@ bisect f a b epsilon f_of_a f_of_b
                  Just v -> v
 
     c = (a + b) / 2
+
+
+
+
+-- | Iterate the function @f@ with the initial guess @x0@ in hopes of
+--   finding a fixed point.
+fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
+                       -> a       -- ^ The initial value @x0@.
+                       -> [a]     -- ^ The resulting sequence of x_{n}.
+fixed_point_iterations f x0 =
+  iterate f x0
+
+
+-- | Find a fixed point of the function @f@ with the search starting
+--   at x0. This will find the first element in the chain f(x0),
+--   f(f(x0)),... such that the magnitude of the difference between it
+--   and the next element is less than epsilon.
+--
+--   We also return the number of iterations required.
+--
+fixed_point_with_iterations :: (Normed a,
+                                Algebraic.C a,
+                                RealField.C b,
+                                Algebraic.C b)
+                            => (a -> a)  -- ^ The function @f@ to iterate.
+                            -> b        -- ^ The tolerance, @epsilon@.
+                            -> a        -- ^ The initial value @x0@.
+                            -> (Int, a) -- ^ The (iterations, fixed point) pair
+fixed_point_with_iterations f epsilon x0 =
+  (fst winning_pair)
+  where
+    xn = fixed_point_iterations f x0
+    xn_plus_one = tail xn
+
+    abs_diff v w = norm (v - w)
+
+    -- The nth entry in this list is the absolute value of x_{n} -
+    -- x_{n+1}.
+    differences = zipWith abs_diff xn xn_plus_one
+
+    -- This produces the list [(n, xn)] so that we can determine
+    -- the number of iterations required.
+    numbered_xn = zip [0..] xn
+
+    -- A list of pairs, (xn, |x_{n} - x_{n+1}|).
+    pairs = zip numbered_xn differences
+
+    -- The pair (xn, |x_{n} - x_{n+1}|) with
+    -- |x_{n} - x_{n+1}| < epsilon. The pattern match on 'Just' is
+    -- "safe" since the list is infinite. We'll succeed or loop
+    -- forever.
+    Just winning_pair = find (\(_, diff) -> diff < epsilon) pairs