]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Roots/Fast.hs
Clean up imports everywhere.
[numerical-analysis.git] / src / Roots / Fast.hs
1 {-# LANGUAGE RebindableSyntax #-}
2
3 -- | The Roots.Fast module contains faster implementations of the
4 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
5 -- values to the next iteration of a function rather than passing
6 -- the function and the points at which to (re)evaluate it.
7
8 module Roots.Fast (
9 bisect,
10 fixed_point_iterations,
11 fixed_point_with_iterations,
12 has_root,
13 trisect )
14 where
15
16 import Data.List ( find )
17 import Data.Maybe ( fromMaybe )
18
19 import Normed ( Normed(..) )
20
21 import NumericPrelude hiding ( abs )
22 import qualified Algebra.Absolute as Absolute ( C )
23 import qualified Algebra.Additive as Additive ( C )
24 import qualified Algebra.Algebraic as Algebraic ( C )
25 import qualified Algebra.RealRing as RealRing ( C )
26 import qualified Algebra.RealField as RealField ( C )
27
28
-- | Determine whether the function @f@ has a root in the interval
--   [@a@, @b@].
--
--   We report a root as soon as @f(a)@ and @f(b)@ do not share the
--   same nonzero sign; this includes the case where either of them is
--   zero. Otherwise, if the interval is wider than @epsilon@, we split
--   it at its midpoint @c@ and check the left half, then the right
--   half. Subintervals no wider than @epsilon@ are never split.
--
--   Whichever of @f(a)@ and @f(b)@ is not supplied by the caller is
--   computed here. The midpoint value @f(c)@ is computed at most once
--   per split and shared between both halves.
--
--   A 'False' result does not prove that there is no root. Roots that
--   produce no sign change at the sampled points go undetected, e.g.
--   two roots inside a single subinterval no wider than @epsilon@.
has_root :: (RealField.C a,
             RealRing.C b,
             Absolute.C b)
         => (a -> b) -- ^ The function @f@
         -> a -- ^ The \"left\" endpoint, @a@
         -> a -- ^ The \"right\" endpoint, @b@
         -> Maybe a -- ^ The size of the smallest subinterval
                    --   we'll examine, @epsilon@
         -> Maybe b -- ^ Precomputed f(a)
         -> Maybe b -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b
  | (signum f_of_a') * (signum f_of_b') /= 1 = True
  | (b - a) <= epsilon' = False
  | otherwise =
      (has_root f a c (Just epsilon') (Just f_of_a') (Just f_of_c')) ||
      (has_root f c b (Just epsilon') (Just f_of_c') (Just f_of_b'))
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = fromMaybe (b - a) epsilon

    -- Compute f(a) and f(b) only if needed.
    f_of_a' = fromMaybe (f a) f_of_a
    f_of_b' = fromMaybe (f b) f_of_b

    c = (a + b) / 2

    -- Evaluate f at the midpoint once and hand the same value to both
    -- halves. Previously each half recomputed it. It is only demanded
    -- in the 'otherwise' branch, so the early returns never compute it.
    f_of_c' = f c
57
-- | Find a root of @f@ on the interval [@a@, @b@] by bisection.
--
--   Returns 'Nothing' when 'has_root' cannot detect a root on
--   [@a@, @b@] at resolution @epsilon@. If @f@ is exactly zero at an
--   endpoint, that endpoint is returned.
--
--   Otherwise, once the half-width of the interval drops below
--   @epsilon@, its midpoint is returned. Until then we recurse into
--   the left half if 'has_root' reports a root there, and into the
--   right half otherwise. Known values of @f@ are passed along so
--   that no endpoint is evaluated twice.
bisect :: (RealField.C a,
           RealRing.C b,
           Absolute.C b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a -- ^ The \"left\" endpoint of the interval, @a@
       -> a -- ^ The \"right\" endpoint of the interval, @b@
       -> a -- ^ The tolerance, @epsilon@
       -> Maybe b -- ^ Precomputed f(a)
       -> Maybe b -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  -- We hand @epsilon@ to 'has_root' because a result within epsilon of
  -- the true root is only meaningful if there *is* a root within an
  -- interval of length epsilon.
  | not (root_between a b fa fb) = Nothing
  | fa == 0 = Just a
  | fb == 0 = Just b
  | (b - mid) < epsilon = Just mid
  | root_between a mid fa f_mid =
      bisect f a mid epsilon (Just fa) (Just f_mid)
  | otherwise =
      bisect f mid b epsilon (Just f_mid) (Just fb)
  where
    -- Fall back to evaluating f only when the caller gave us nothing.
    fa = fromMaybe (f a) f_of_a
    fb = fromMaybe (f b) f_of_b

    mid = (a + b) / 2

    -- Demanded only once we actually need to pick a half.
    f_mid = f mid

    -- Ask 'has_root' about [p,q] at our tolerance, reusing f(p), f(q).
    root_between p q fp fq =
      has_root f p q (Just epsilon) (Just fp) (Just fq)
88
89
90
-- | Find a root of @f@ on the interval [@a@, @b@] by repeatedly
--   cutting it into thirds.
--
--   Returns 'Nothing' when 'has_root' cannot detect a root on
--   [@a@, @b@] at resolution @epsilon@. If @f@ is exactly zero at an
--   endpoint, that endpoint is returned. Once the interval is
--   narrower than @2*epsilon@, its midpoint is returned.
--
--   Otherwise we recurse into one of the thirds. We try the right
--   third first, then the middle third, and fall back to the left
--   third if 'has_root' reports a root in neither.
trisect :: (RealField.C a,
            RealRing.C b,
            Absolute.C b)
        => (a -> b) -- ^ The function @f@ whose root we seek
        -> a -- ^ The \"left\" endpoint of the interval, @a@
        -> a -- ^ The \"right\" endpoint of the interval, @b@
        -> a -- ^ The tolerance, @epsilon@
        -> Maybe b -- ^ Precomputed f(a)
        -> Maybe b -- ^ Precomputed f(b)
        -> Maybe a
trisect f a b epsilon f_of_a f_of_b
  -- We hand @epsilon@ to 'has_root' because a result within epsilon of
  -- the true root is only meaningful if there *is* a root within an
  -- interval of length epsilon.
  | not (root_between a b fa fb) = Nothing
  | fa == 0 = Just a
  | fb == 0 = Just b
  | (b - a) < 2*epsilon = Just ((b + a) / 2)
  | otherwise = trisect f lo hi epsilon (Just f_lo) (Just f_hi)
  where
    -- Fall back to evaluating f only when the caller gave us nothing.
    fa = fromMaybe (f a) f_of_a
    fb = fromMaybe (f b) f_of_b

    -- The interior points c < d that cut [a,b] into thirds, together
    -- with the values of f there (computed only when demanded).
    c = (2*a + b) / 3
    d = (a + 2*b) / 3
    fc = f c
    fd = f d

    -- Ask 'has_root' about [p,q] at our tolerance, reusing f(p), f(q).
    root_between p q fp fq =
      has_root f p q (Just epsilon) (Just fp) (Just fq)

    -- The third we recurse into, with the known values of f at its
    -- endpoints.
    (lo, hi, f_lo, f_hi)
      | root_between d b fd fb = (d, b, fd, fb)
      | root_between c d fc fd = (c, d, fc, fd)
      | otherwise = (a, c, fa, fc)
132
133
134
-- | The sequence of iterates of @f@ starting from the initial guess
--   @x0@, namely @[x0, f x0, f (f x0), ...]@. The list is infinite; the
--   consumer decides when the iterates are close enough to stop.
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a -- ^ The initial value @x0@.
                       -> [a] -- ^ The resulting sequence of x_{n}.
fixed_point_iterations f x0 =
  x0 : fixed_point_iterations f (f x0)
142
143
-- | Find a fixed point of the function @f@, starting the search at
--   @x0@. We walk the sequence of iterates x0, f(x0), f(f(x0)), ...
--   and stop at the first element x_{n} whose distance (as measured
--   by 'norm') to the next element x_{n+1} is less than @epsilon@.
--
--   We also return the number of iterations required, @n@. It is zero
--   when @x0@ itself already satisfies the criterion.
--
--   If no two consecutive iterates ever get that close, this does not
--   terminate.
fixed_point_with_iterations :: (Normed a,
                                Additive.C a,
                                RealField.C b,
                                Algebraic.C b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b -- ^ The tolerance, @epsilon@.
                            -> a -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  case find (\(_, diff) -> diff < epsilon) pairs of
    Just (winner, _) -> winner
    -- Unreachable: 'pairs' is infinite, so 'find' either succeeds or
    -- loops forever. Spelled out to avoid a partial pattern binding.
    Nothing -> error ("fixed_point_with_iterations: impossible, "
                      ++ "the iterates form an infinite list")
  where
    xn = fixed_point_iterations f x0

    -- 'drop 1' instead of the partial 'tail'; the list is infinite.
    xn_plus_one = drop 1 xn

    abs_diff v w = norm (v - w)

    -- The nth entry in this list is the norm of x_{n} - x_{n+1}.
    differences = zipWith abs_diff xn xn_plus_one

    -- This produces the list [(n, xn)] so that we can determine
    -- the number of iterations required.
    numbered_xn = zip [0..] xn

    -- A list of pairs ((n, x_{n}), |x_{n} - x_{n+1}|).
    pairs = zip numbered_xn differences