gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Roots/Fast.hs
Remove assumptions on the Normed class.
[numerical-analysis.git] / src / Roots / Fast.hs
1 {-# LANGUAGE RebindableSyntax #-}
2
3 -- | The Roots.Fast module contains faster implementations of the
4 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
5 -- values to the next iteration of a function rather than passing
6 -- the function and the points at which to (re)evaluate it.
7
8 module Roots.Fast
9 where
10
11 import Data.List (find)
12
13 import Normed
14
15 import NumericPrelude hiding (abs)
16 import qualified Algebra.Absolute as Absolute
17 import qualified Algebra.Additive as Additive
18 import qualified Algebra.Algebraic as Algebraic
19 import qualified Algebra.Field as Field
20 import qualified Algebra.RealRing as RealRing
21 import qualified Algebra.RealField as RealField
22
-- | Determine whether the function @f@ has a root on the interval
--   [a,b].
--
--   If f(a) and f(b) are not both nonzero with the same sign, we
--   report a root immediately (assuming @f@ is continuous, a sign
--   change or a zero endpoint guarantees one). Otherwise the interval
--   is bisected, and each half is checked the same way. This continues
--   until the subintervals are no longer than @epsilon@. A 'False'
--   result therefore only means that no sign change was found at that
--   resolution.
--
--   The value of @f@ at each midpoint is computed once and shared
--   by the two halves.
--
has_root :: (RealField.C a,
             RealRing.C b,
             Absolute.C b)
         => (a -> b) -- ^ The function @f@
         -> a        -- ^ The \"left\" endpoint, @a@
         -> a        -- ^ The \"right\" endpoint, @b@
         -> Maybe a  -- ^ The size of the smallest subinterval
                     --   we'll examine, @epsilon@
         -> Maybe b  -- ^ Precomputed f(a)
         -> Maybe b  -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b
  -- f(a) and f(b) differ in sign, or one of them is zero. We don't
  -- care about epsilon here, there's definitely a root!
  | not ((signum f_of_a') * (signum f_of_b') == 1) = True
  -- The interval is already as small as we're willing to examine;
  -- give up and return false.
  | (b - a) <= epsilon' = False
  -- If either [a,c] or [c,b] has a root, we do too. Both halves get
  -- the same f(c), so it is evaluated at most once.
  | otherwise =
      (has_root f a c (Just epsilon') (Just f_of_a') (Just f_of_c)) ||
      (has_root f c b (Just epsilon') (Just f_of_c) (Just f_of_b'))
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = case epsilon of
                 Nothing  -> (b - a)
                 Just eps -> eps

    -- Compute f(a) and f(b) only if needed.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v  -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v  -> v

    c = (a + b) / 2

    -- Shared by both recursive calls above.
    f_of_c = f c
63
64
-- | Find a root of @f@ in the interval [a,b] using the bisection
--   method.
--
--   Returns 'Nothing' if 'has_root' finds no root in [a,b] when
--   examining subintervals of length @epsilon@. Otherwise, returns
--   @Just x@, where either f(x) is exactly zero or x is the midpoint
--   of a subinterval that is less than 2*epsilon wide and was
--   determined to contain a root.
--
--   'has_root' is consulted once on the whole interval. After that,
--   each step only checks the left half; the right half is then known
--   to contain a root (see the invariant on @narrow@).
--
bisect :: (RealField.C a,
           RealRing.C b,
           Absolute.C b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a        -- ^ The \"left\" endpoint of the interval, @a@
       -> a        -- ^ The \"right\" endpoint of the interval, @b@
       -> a        -- ^ The tolerance, @epsilon@
       -> Maybe b  -- ^ Precomputed f(a)
       -> Maybe b  -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  -- We pass @epsilon@ to the 'has_root' function because if we want a
  -- result within epsilon of the true root, we need to know that
  -- there *is* a root within an interval of length epsilon.
  | not (has_root f a b (Just epsilon) (Just f_of_a') (Just f_of_b')) = Nothing
  | otherwise = Just (narrow a b f_of_a' f_of_b')
  where
    -- Compute f(a) and f(b) only if needed.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v  -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v  -> v

    -- Invariant: f_lo = f(lo), f_hi = f(hi), and
    -- has_root f lo hi (Just epsilon) ... is True.
    --
    -- If [lo,mid] has no root, then [mid,hi] must have one. Either
    -- f(lo) and f(hi) differ in sign, in which case f(mid) shares
    -- f(lo)'s sign and so differs from f(hi). Or the root was found
    -- by has_root's own split of [lo,hi] at this same midpoint. So
    -- there is no need to re-run 'has_root' on the half we descend
    -- into.
    narrow lo hi f_lo f_hi
      | f_lo == 0 = lo
      | f_hi == 0 = hi
      | (hi - mid) < epsilon = mid
      | has_root f lo mid (Just epsilon) (Just f_lo) (Just f_mid) =
          narrow lo mid f_lo f_mid
      | otherwise =
          narrow mid hi f_mid f_hi
      where
        mid = (lo + hi) / 2
        f_mid = f mid
100
101
102
103
-- | The sequence of iterates x0, f(x0), f(f(x0)), ... obtained by
--   repeatedly applying @f@ to the initial guess @x0@. If we're
--   lucky, this sequence approaches a fixed point of @f@.
--
--   The list is infinite; consumers must decide when to stop.
--
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a        -- ^ The initial value @x0@.
                       -> [a]      -- ^ The resulting sequence of x_{n}.
fixed_point_iterations = iterate
111
112
-- | Find a fixed point of the function @f@ with the search starting
--   at @x0@. Write x_{0} = x0 and x_{n+1} = f(x_{n}). This returns the
--   first element x_{n} in the chain x0, f(x0), f(f(x0)),... such that
--   the magnitude of the difference between it and the next element,
--   @norm (x_{n} - x_{n+1})@, is less than @epsilon@.
--
--   We also return the number of iterations required, @n@. For
--   example, if @x0@ itself satisfies the criterion, the result is
--   @(0, x0)@.
--
--   If the criterion is never met, this loops forever.
--
fixed_point_with_iterations :: (Normed a,
                                Algebraic.C a,
                                RealField.C b,
                                Algebraic.C b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b        -- ^ The tolerance, @epsilon@.
                            -> a        -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  go 0 x0
  where
    -- Walk the chain, carrying the index @n@ of the current element
    -- @xn@. Each f(x_{n}) is computed once and becomes the next x_{n}.
    -- This is total: there is no partial pattern match on a search
    -- result, and the only way to "fail" is to never converge.
    go n xn
      | norm (xn - xn_plus_one) < epsilon = (n, xn)
      | otherwise =
          -- Force the counter so a long iteration doesn't build up a
          -- chain of (+ 1) thunks.
          let n' = n + 1 in n' `seq` go n' xn_plus_one
      where
        xn_plus_one = f xn
149 -- "safe" since the list is infinite. We'll succeed or loop
150 -- forever.
151 Just winning_pair = find (\(_, diff) -> diff < epsilon) pairs