{-# LANGUAGE RebindableSyntax #-}

-- | The Roots.Fast module contains faster implementations of the
--   'Roots.Simple' algorithms. Generally, we will pass precomputed
--   values to the next iteration of a function rather than passing
--   the function and the points at which to (re)evaluate it.
module Roots.Fast
where

import Data.List (find)

import Normed

import NumericPrelude hiding (abs)
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Additive as Additive
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.RealRing as RealRing
import qualified Algebra.RealField as RealField


-- | Decide whether @f@ shows evidence of a root on the interval
--   [a,b]. The answer is 'True' as soon as the product of the signs
--   of f(a) and f(b) is anything other than 1 (a sign change, or a
--   zero at an endpoint). Otherwise the interval is halved and both
--   halves are examined, until the interval is no longer than
--   @epsilon@, at which point the answer is 'False'.
--
--   The endpoint values may be supplied by the caller to avoid
--   re-evaluating @f@; they are computed here only when absent.
has_root :: (RealField.C a, RealRing.C b, Absolute.C b)
         => (a -> b) -- ^ The function @f@
         -> a        -- ^ The \"left\" endpoint, @a@
         -> a        -- ^ The \"right\" endpoint, @b@
         -> Maybe a  -- ^ The size of the smallest subinterval
                     --   we'll examine, @epsilon@
         -> Maybe b  -- ^ Precomputed f(a)
         -> Maybe b  -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b =
  if not ((signum (f_of_a')) * (signum (f_of_b')) == 1) then
    -- We don't care about epsilon here, there's definitely a root!
    True
  else
    if (b - a) <= epsilon' then
      -- Give up, return false.
      False
    else
      -- If either [a,c] or [c,b] have roots, we do too. The midpoint
      -- value is shared between the two halves so that f(c) is
      -- evaluated at most once per subdivision.
      (has_root f a c (Just epsilon') (Just f_of_a') (Just f_of_c)) ||
      (has_root f c b (Just epsilon') (Just f_of_c) (Just f_of_b'))
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = case epsilon of
                 Nothing -> (b - a)
                 Just eps -> eps

    -- Compute f(a) and f(b) only if needed.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v -> v

    -- The midpoint of [a,b].
    c = (a + b) / 2

    -- Evaluated lazily: only forced once we actually subdivide.
    f_of_c = f c


-- | Search for a root of @f@ on [a,b] by bisection. The result is
--   'Nothing' when 'has_root' reports no root on [a,b]. Otherwise we
--   return an endpoint if @f@ vanishes there, the midpoint once half
--   the interval is shorter than @epsilon@, or else recurse into
--   whichever half 'has_root' accepts (the left half is tried
--   first).
bisect :: (RealField.C a, RealRing.C b, Absolute.C b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a        -- ^ The \"left\" endpoint of the interval, @a@
       -> a        -- ^ The \"right\" endpoint of the interval, @b@
       -> a        -- ^ The tolerance, @epsilon@
       -> Maybe b  -- ^ Precomputed f(a)
       -> Maybe b  -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  -- We pass @epsilon@ to the 'has_root' function because if we want a
  -- result within epsilon of the true root, we need to know that
  -- there *is* a root within an interval of length epsilon.
  | not (has_root f a b (Just epsilon) (Just f_of_a') (Just f_of_b')) = Nothing
  | f_of_a' == 0 = Just a
  | f_of_b' == 0 = Just b
  | (b - c) < epsilon = Just c
  | otherwise =
      -- Use a 'prime' just for consistency.
      let f_of_c' = f c in
        if (has_root f a c (Just epsilon) (Just f_of_a') (Just f_of_c'))
        then bisect f a c epsilon (Just f_of_a') (Just f_of_c')
        else bisect f c b epsilon (Just f_of_c') (Just f_of_b')
  where
    -- Compute f(a) and f(b) only if needed.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v -> v

    -- The midpoint of [a,b].
    c = (a + b) / 2


-- | Iterate the function @f@ with the initial guess @x0@ in hopes of
--   finding a fixed point.
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a        -- ^ The initial value @x0@.
                       -> [a]      -- ^ The resulting sequence of x_{n}.
fixed_point_iterations f x0 =
  iterate f x0


-- | Find a fixed point of the function @f@ with the search starting
--   at x0. This will find the first element in the chain x0, f(x0),
--   f(f(x0)),... such that the magnitude of the difference between it
--   and the next element is less than epsilon.
--
--   We also return the number of iterations required.
--
--   If no such element exists, this function does not terminate.
--
fixed_point_with_iterations :: (Normed a,
                                Algebraic.C a,
                                RealField.C b,
                                Algebraic.C b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b        -- ^ The tolerance, @epsilon@.
                            -> a        -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  -- Pick the first pair ((n, xn), |x_{n} - x_{n+1}|) whose difference
  -- is below epsilon and keep only the (n, xn) part.
  case find (\(_, diff) -> diff < epsilon) pairs of
    Just (numbered, _) -> numbered
    -- The list of pairs is infinite, so 'find' either succeeds or
    -- loops forever; it can never run off the end of the list.
    Nothing ->
      error "fixed_point_with_iterations: unreachable, the iteration sequence is infinite"
  where
    xn = fixed_point_iterations f x0
    xn_plus_one = drop 1 xn

    abs_diff v w = norm (v - w)

    -- The nth entry in this list is the absolute value of x_{n} -
    -- x_{n+1}.
    differences = zipWith abs_diff xn xn_plus_one

    -- This produces the list [(n, xn)] so that we can determine
    -- the number of iterations required.
    numbered_xn = zip [0..] xn

    -- A list of pairs, ((n, xn), |x_{n} - x_{n+1}|).
    pairs = zip numbered_xn differences