]> gitweb.michael.orlitzky.com - numerical-analysis.git/blobdiff - src/Roots/Simple.hs
Change the type signature of fixed_point to work on Vectors.
[numerical-analysis.git] / src / Roots / Simple.hs
index 1ab9034038d996bf5906a985d78ada0eec8e362f..5aed7a1f4af210ea393033dd8bcf2f8b9c0f1923 100644 (file)
@@ -11,6 +11,8 @@ where
 
 import Data.List (find)
 
+import Vector
+
 import qualified Roots.Fast as F
 
 -- | Does the (continuous) function @f@ have a root on the interval
@@ -114,14 +116,22 @@ newton_iterations f f' x0 =
 --   >>> abs (f root) < 1/100000
 --   True
 --
+--   >>> import Data.Number.BigFloat
+--   >>> let eps = 1/(10^20) :: BigFloat Prec50
+--   >>> let Just root = newtons_method f f' eps 2
+--   >>> root
+--   1.13472413840151949260544605450647284028100785303643e0
+--   >>> abs (f root) < eps
+--   True
+--
 newtons_method :: (Fractional a, Ord a)
                  => (a -> a) -- ^ The function @f@ whose root we seek
                  -> (a -> a) -- ^ The derivative of @f@
                  -> a       -- ^ The tolerance epsilon
                  -> a       -- ^ Initial guess, x-naught
                  -> Maybe a
-newtons_method f f' epsilon x0
-  find (\x -> abs (f x) < epsilon) x_n
+newtons_method f f' epsilon x0 =
+  find (\x -> abs (f x) < epsilon) x_n
   where
     x_n = newton_iterations f f' x0
 
@@ -217,9 +227,9 @@ fixed_point_iterations f x0 =
 --   f(f(x0)),... such that the magnitude of the difference between it
 --   and the next element is less than epsilon.
 --
-fixed_point :: (Num a, Ord a)
+fixed_point :: (Num a, Vector a, RealFrac b)
             => (a -> a) -- ^ The function @f@ to iterate.
-            -> a       -- ^ The tolerance, @epsilon@.
+            -> b       -- ^ The tolerance, @epsilon@.
             -> a       -- ^ The initial value @x0@.
             -> a       -- ^ The fixed point.
 fixed_point f epsilon x0 =
@@ -228,8 +238,7 @@ fixed_point f epsilon x0 =
     xn = fixed_point_iterations f x0
     xn_plus_one = tail $ fixed_point_iterations f x0
 
-    abs_diff v w =
-      abs (v - w)
+    abs_diff v w = norm (v - w)
 
     -- The nth entry in this list is the absolute value of x_{n} -
     -- x_{n+1}.