1 {-# LANGUAGE RebindableSyntax #-}
3 -- | The Roots.Fast module contains faster implementations of the
4 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
5 -- values to the next iteration of a function rather than making that
6 -- iteration re-evaluate the function at points we have already seen.
11 import Data.List (find)
15 import NumericPrelude hiding (abs)
16 import qualified Algebra.Absolute as Absolute
17 import qualified Algebra.Additive as Additive
18 import qualified Algebra.Algebraic as Algebraic
19 import qualified Algebra.RealRing as RealRing
20 import qualified Algebra.RealField as RealField
-- | Decide whether the function @f@ has a root in the interval [a,b].
--
-- If @signum f(a) * signum f(b)@ is anything other than 1 (the signs
-- differ, or one of the values is zero), a root is guaranteed. Otherwise,
-- unless [a,b] is already no wider than @epsilon@, the interval is split
-- at @c@ and both halves are searched recursively. When @epsilon@ is
-- 'Nothing', only the single check on all of [a,b] is intended. (@c@ is
-- presumably the midpoint of [a,b] -- TODO confirm.)
22 has_root :: (RealField.C a,
25 => (a -> b) -- ^ The function @f@
26 -> a -- ^ The \"left\" endpoint, @a@
27 -> a -- ^ The \"right\" endpoint, @b@
28 -> Maybe a -- ^ The size of the smallest subinterval
29 -- we'll examine, @epsilon@
30 -> Maybe b -- ^ Precomputed f(a), if already known
31 -> Maybe b -- ^ Precomputed f(b), if already known
33 has_root f a b epsilon f_of_a f_of_b =
34 if not ((signum (f_of_a')) * (signum (f_of_b')) == 1) then
35 -- The endpoint signs differ (or one value is zero), so there is
-- definitely a root; epsilon is irrelevant here.
38 if (b - a) <= epsilon' then
39 -- The interval is too small to split further: give up and report
-- that no root was found.
42 -- If either [a,c] or [c,b] have roots, we do too.
-- NOTE(review): both calls pass 'Nothing' for f(c), so f(c) appears to
-- be evaluated again by the [c,b] call whenever the [a,c] call returns
-- False ('||' only evaluates its right operand in that case). Computing
-- f(c) once and passing it to both calls would avoid the extra work.
43 (has_root f a c (Just epsilon') (Just f_of_a') Nothing) ||
44 (has_root f c b (Just epsilon') Nothing (Just f_of_b'))
46 -- If the size of the smallest subinterval is not specified,
47 -- assume we just want to check once on all of [a,b].
48 epsilon' = case epsilon of
52 -- Compute f(a) and f(b) only if they were not passed in precomputed.
53 f_of_a' = case f_of_a of
57 f_of_b' = case f_of_b of
-- | Search for a root of the function @f@ in [a,b] by bisection.
--
-- Returns 'Nothing' if 'has_root' finds no root in [a,b] at resolution
-- @epsilon@. Returns @Just a@ or @Just b@ if @f@ vanishes at that endpoint.
-- Returns @Just c@ once @b - c@ is smaller than @epsilon@. Otherwise it
-- recurses into [a,c] if 'has_root' reports a root there, and into [c,b]
-- if not. Each recursive call receives the function values already
-- computed. (@c@ is presumably the midpoint of [a,b] -- TODO confirm.)
--
-- NOTE(review): every step calls 'has_root' with @Just epsilon@. When the
-- endpoint signs agree, 'has_root' may subdivide its interval all the way
-- down to width @epsilon@, so one bisection step can cost far more than a
-- single evaluation of @f@. Confirm this is intended.
64 bisect :: (RealField.C a,
67 => (a -> b) -- ^ The function @f@ whose root we seek
68 -> a -- ^ The \"left\" endpoint of the interval, @a@
69 -> a -- ^ The \"right\" endpoint of the interval, @b@
70 -> a -- ^ The tolerance, @epsilon@
71 -> Maybe b -- ^ Precomputed f(a), if already known
72 -> Maybe b -- ^ Precomputed f(b), if already known
74 bisect f a b epsilon f_of_a f_of_b
75 -- We pass @epsilon@ to the 'has_root' function because if we want a
76 -- result within epsilon of the true root, we need to know that
77 -- there *is* a root within an interval of length epsilon.
78 | not (has_root f a b (Just epsilon) (Just f_of_a') (Just f_of_b')) = Nothing
79 | f_of_a' == 0 = Just a
80 | f_of_b' == 0 = Just b
81 | (b - c) < epsilon = Just c
83 -- Name f(c) with a 'prime' just for consistency with f_of_a' and f_of_b'.
85 if (has_root f a c (Just epsilon) (Just f_of_a') (Just f_of_c'))
86 then bisect f a c epsilon (Just f_of_a') (Just f_of_c')
87 else bisect f c b epsilon (Just f_of_c') (Just f_of_b')
89 -- Compute f(a) and f(b) only if they were not passed in precomputed.
90 f_of_a' = case f_of_a of
94 f_of_b' = case f_of_b of
103 -- | Iterate the function @f@ with the initial guess @x0@ in hopes of
104 -- finding a fixed point. The caller 'fixed_point_with_iterations'
-- treats the returned list as infinite; its partial 'Just' match relies
-- on that.
105 fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
106 -> a -- ^ The initial value @x0@.
107 -> [a] -- ^ The resulting sequence of x_{n}.
108 fixed_point_iterations f x0 =
112 -- | Find a fixed point of the function @f@ with the search starting
113 -- at @x0@. This will find the first element in the chain f(x0),
114 -- f(f(x0)),... such that the magnitude (norm) of the difference between
115 -- it and the next element is less than @epsilon@.
117 -- We also return the number of iterations required.
119 fixed_point_with_iterations :: (Normed a,
123 => (a -> a) -- ^ The function @f@ to iterate.
124 -> b -- ^ The tolerance, @epsilon@.
125 -> a -- ^ The initial value @x0@.
126 -> (Int, a) -- ^ The (iterations, fixed point) pair
127 fixed_point_with_iterations f epsilon x0 =
130 xn = fixed_point_iterations f x0
131 xn_plus_one = tail xn
133 abs_diff v w = norm (v - w)
135 -- The nth entry in this list is the norm of x_{n} - x_{n+1}.
137 differences = zipWith abs_diff xn xn_plus_one
139 -- This produces the list [(n, xn)] so that we can determine
140 -- the number of iterations required.
141 numbered_xn = zip [0..] xn
143 -- A list of pairs, ((n, x_{n}), |x_{n} - x_{n+1}|).
144 pairs = zip numbered_xn differences
146 -- The first pair ((n, x_{n}), |x_{n} - x_{n+1}|) with
147 -- |x_{n} - x_{n+1}| < epsilon. The pattern match on 'Just' is
148 -- "safe" since the list is infinite. We'll succeed or loop
-- forever: 'find' cannot return 'Nothing' on an infinite list.
150 Just winning_pair = find (\(_, diff) -> diff < epsilon) pairs