]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Roots/Fast.hs
Update numeric-prelude and fixed-vector.
[numerical-analysis.git] / src / Roots / Fast.hs
1 {-# LANGUAGE RebindableSyntax #-}
2
3 -- | The Roots.Fast module contains faster implementations of the
4 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
5 -- values to the next iteration of a function rather than passing
6 -- the function and the points at which to (re)evaluate it.
7
8 module Roots.Fast
9 where
10
import Data.List (find)
import Data.Maybe (fromMaybe)

import Normed

import NumericPrelude hiding (abs)
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Additive as Additive
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealRing as RealRing
import qualified Algebra.RealField as RealField
21
-- | Determine whether the function @f@ has a root in the interval
--   [a, b].
--
--   If @f@ changes sign between @a@ and @b@, or vanishes at either
--   endpoint, a root exists and we return 'True' immediately.
--   Otherwise the interval is halved and each half is checked
--   recursively, until the subintervals are no longer than @epsilon@.
--   A 'False' result only means that no sign change was found at that
--   resolution; a root where @f@ touches zero without changing sign
--   may be missed.
--
--   Supplying the precomputed values f(a) and f(b) avoids evaluating
--   @f@ again at points we have already visited.
has_root :: (RealField.C a,
             RealRing.C b,
             Absolute.C b)
         => (a -> b) -- ^ The function @f@
         -> a        -- ^ The \"left\" endpoint, @a@
         -> a        -- ^ The \"right\" endpoint, @b@
         -> Maybe a  -- ^ The size of the smallest subinterval
                     --   we'll examine, @epsilon@. When 'Nothing',
                     --   only the endpoints of [a,b] are checked.
         -> Maybe b  -- ^ Precomputed f(a)
         -> Maybe b  -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b
  -- A sign change (or a zero at an endpoint): there's definitely a
  -- root, regardless of epsilon.
  | signum f_of_a' * signum f_of_b' /= 1 = True
  -- The interval is already as small as we're willing to look; give up.
  | (b - a) <= epsilon' = False
  -- If either [a,c] or [c,b] has a root, we do too.
  | otherwise =
      has_root f a c (Just epsilon') (Just f_of_a') Nothing ||
      has_root f c b (Just epsilon') Nothing (Just f_of_b')
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = fromMaybe (b - a) epsilon

    -- Compute f(a) and f(b) only if they weren't supplied.
    f_of_a' = fromMaybe (f a) f_of_a
    f_of_b' = fromMaybe (f b) f_of_b

    -- The midpoint of [a,b].
    c = (a + b) / 2
62
63
-- | Search for a root of @f@ in [a,b] by repeated bisection.
--
--   Returns 'Nothing' when 'has_root' finds no root in [a,b] at
--   resolution @epsilon@. Otherwise, an endpoint at which @f@ is
--   exactly zero is returned as-is. If neither endpoint is a zero and
--   the half-width of the interval is already below @epsilon@, the
--   midpoint is returned. In every other case the search continues in
--   the left half when 'has_root' reports a root there, and in the
--   right half otherwise.
bisect :: (RealField.C a,
           RealRing.C b,
           Absolute.C b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a        -- ^ The \"left\" endpoint of the interval, @a@
       -> a        -- ^ The \"right\" endpoint of the interval, @b@
       -> a        -- ^ The tolerance, @epsilon@
       -> Maybe b  -- ^ Precomputed f(a)
       -> Maybe b  -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  -- To return something within epsilon of the true root, we must know
  -- that a root lies in some interval of length epsilon; hence the
  -- tolerance is handed to 'has_root'.
  | not root_somewhere = Nothing
  | fa == 0 = Just a
  | fb == 0 = Just b
  | half_width < epsilon = Just mid
  | root_on_left = bisect f a mid epsilon (Just fa) (Just fmid)
  | otherwise = bisect f mid b epsilon (Just fmid) (Just fb)
  where
    -- Use a value of f that the caller already computed, if any;
    -- otherwise evaluate f at the given point.
    cached _ (Just v) = v
    cached x Nothing = f x

    fa = cached a f_of_a
    fb = cached b f_of_b

    mid = (a + b) / 2
    half_width = b - mid

    -- Only evaluated if we actually need to recurse.
    fmid = f mid

    root_somewhere = has_root f a b (Just epsilon) (Just fa) (Just fb)
    root_on_left = has_root f a mid (Just epsilon) (Just fa) (Just fmid)
99
100
101
-- | Search for a root of @f@ in [a,b] by repeated trisection.
--
--   Returns 'Nothing' when 'has_root' finds no root in [a,b] at
--   resolution @epsilon@. An endpoint at which @f@ is exactly zero is
--   returned as-is. If the interval is narrower than @2*epsilon@, its
--   midpoint is returned. Otherwise [a,b] is split at its one-third
--   point @c@ and two-thirds point @d@. The search continues in [d,b]
--   if 'has_root' reports a root there, else in [c,d] if it reports
--   one there, and else in [a,c].
trisect :: (RealField.C a,
            RealRing.C b,
            Absolute.C b)
        => (a -> b) -- ^ The function @f@ whose root we seek
        -> a        -- ^ The \"left\" endpoint of the interval, @a@
        -> a        -- ^ The \"right\" endpoint of the interval, @b@
        -> a        -- ^ The tolerance, @epsilon@
        -> Maybe b  -- ^ Precomputed f(a)
        -> Maybe b  -- ^ Precomputed f(b)
        -> Maybe a
trisect f a b epsilon f_of_a f_of_b
  -- To return something within epsilon of the true root, we must know
  -- that a root lies in some interval of length epsilon; hence the
  -- tolerance is handed to 'has_root'.
  | not (has_root f a b (Just epsilon) (Just fa) (Just fb)) = Nothing
  | fa == 0 = Just a
  | fb == 0 = Just b
  | width < 2*epsilon = Just ((b + a) / 2)
  | otherwise = trisect f lo hi epsilon (Just flo) (Just fhi)
  where
    -- Use a value of f that the caller already computed, if any;
    -- otherwise evaluate f at the given point.
    cached _ (Just v) = v
    cached x Nothing = f x

    fa = cached a f_of_a
    fb = cached b f_of_b

    width = b - a

    -- The one-third and two-thirds points of [a,b], and f at each.
    c = (2*a + b) / 3
    d = (a + 2*b) / 3
    fc = f c
    fd = f d

    -- The next interval to search, checked right-to-left.
    (lo, hi, flo, fhi)
      | has_root f d b (Just epsilon) (Just fd) (Just fb) = (d, b, fd, fb)
      | has_root f c d (Just epsilon) (Just fc) (Just fd) = (c, d, fc, fd)
      | otherwise = (a, c, fa, fc)
148
149
150
-- | The infinite sequence of iterates x0, f(x0), f(f(x0)), ... of
--   the function @f@, which we hope converges to a fixed point of @f@.
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a        -- ^ The initial value @x0@.
                       -> [a]      -- ^ The resulting sequence of x_{n}.
fixed_point_iterations = iterate
158
159
-- | Find a fixed point of the function @f@, with the search starting
--   at @x0@. Among the iterates x0, f(x0), f(f(x0)), ..., this returns
--   the first x_{n} whose distance, as measured by 'norm', to the next
--   iterate x_{n+1} is less than @epsilon@.
--
--   It is returned together with its index @n@, the number of
--   applications of @f@ needed to reach it. If the iterates never get
--   that close together, this does not terminate.
--
fixed_point_with_iterations :: (Normed a,
                                Additive.C a,
                                RealField.C b,
                                Algebraic.C b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b        -- ^ The tolerance, @epsilon@.
                            -> a        -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  case find (\(_, diff) -> diff < epsilon) pairs of
    Just (winner, _) -> winner
    -- Unreachable: 'fixed_point_iterations' is built on 'iterate', so
    -- @pairs@ is infinite and 'find' either succeeds or loops forever.
    Nothing -> error "fixed_point_with_iterations: ran out of iterates"
  where
    xn = fixed_point_iterations f x0
    xn_plus_one = drop 1 xn

    -- The nth entry in this list is norm(x_{n} - x_{n+1}).
    differences = zipWith (\v w -> norm (v - w)) xn xn_plus_one

    -- Pair each iterate with its index n, so that we can report the
    -- number of iterations required.
    numbered_xn = zip [0..] xn

    -- A list of pairs ((n, x_{n}), norm(x_{n} - x_{n+1})).
    pairs = zip numbered_xn differences