]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Roots/Fast.hs
src/Roots/Fast.hs: fix monomorphism restriction warnings.
[numerical-analysis.git] / src / Roots / Fast.hs
1 {-# LANGUAGE RebindableSyntax #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3
4 -- | The Roots.Fast module contains faster implementations of the
5 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
6 -- values to the next iteration of a function rather than passing
7 -- the function and the points at which to (re)evaluate it.
8
9 module Roots.Fast (
10 bisect,
11 fixed_point_iterations,
12 fixed_point_with_iterations,
13 has_root,
14 trisect )
15 where
16
17 import Data.List ( find )
18 import Data.Maybe ( fromMaybe )
19
20 import Normed ( Normed(..) )
21
22 import NumericPrelude hiding ( abs )
23 import qualified Algebra.Absolute as Absolute ( C )
24 import qualified Algebra.Additive as Additive ( C )
25 import qualified Algebra.Algebraic as Algebraic ( C )
26 import qualified Algebra.RealRing as RealRing ( C )
27 import qualified Algebra.RealField as RealField ( C )
28
29
-- | Decide whether @f@ appears to have a root on the interval [a,b]
--   by looking for a sign change (or a zero) at the endpoints of
--   progressively smaller subintervals.
--
--   The answer is 'True' as soon as the product of the signs of the
--   two endpoint values is not 1. Otherwise the interval is halved
--   and both halves are examined, until the interval under
--   consideration is no longer than @epsilon@, at which point the
--   answer for that interval is 'False'.
--
--   The value of @f@ at the midpoint is computed at most once per
--   level and handed to both recursive calls, so no point is ever
--   evaluated twice.
--
has_root :: (RealField.C a,
             RealRing.C b,
             Absolute.C b)
         => (a -> b) -- ^ The function @f@
         -> a        -- ^ The \"left\" endpoint, @a@
         -> a        -- ^ The \"right\" endpoint, @b@
         -> Maybe a  -- ^ The size of the smallest subinterval
                     --   we'll examine, @epsilon@
         -> Maybe b  -- ^ Precomputed f(a)
         -> Maybe b  -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b
  -- A sign product other than 1 means a sign change or a zero at an
  -- endpoint; we don't care about epsilon here.
  | (signum f_of_a') * (signum f_of_b') /= 1 = True
  -- The interval is already as small as we're willing to look at.
  | (b - a) <= epsilon' = False
  -- If either [a,c] or [c,b] has a root, then so do we. Both halves
  -- share the same (lazily computed) value of f(c).
  | otherwise =
      (has_root f a c (Just epsilon') (Just f_of_a') (Just f_of_c')) ||
      (has_root f c b (Just epsilon') (Just f_of_c') (Just f_of_b'))
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = fromMaybe (b - a) epsilon

    -- Compute f(a) and f(b) only if they were not handed to us.
    f_of_a' = fromMaybe (f a) f_of_a
    f_of_b' = fromMaybe (f b) f_of_b

    -- The midpoint of [a,b] and the value of f there. Laziness
    -- ensures f(c) is only computed if a recursive call needs it.
    c = (a + b) / 2
    f_of_c' = f c
57
58
-- | Search for a root of @f@ on [a,b] by repeated halving.
--
--   The result is 'Nothing' when 'has_root' finds no evidence of a
--   root on [a,b]. An endpoint whose function value is zero is
--   returned directly. Otherwise the midpoint is returned once it
--   lies within @epsilon@ of the right endpoint; until then we
--   descend into the left half if 'has_root' accepts it, and into
--   the right half if not.
--
bisect :: (RealField.C a,
           RealRing.C b,
           Absolute.C b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a        -- ^ The \"left\" endpoint of the interval, @a@
       -> a        -- ^ The \"right\" endpoint of the interval, @b@
       -> a        -- ^ The tolerance, @epsilon@
       -> Maybe b  -- ^ Precomputed f(a)
       -> Maybe b  -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  | not (root_on a b left_value right_value) = Nothing
  | left_value == 0 = Just a
  | right_value == 0 = Just b
  | (b - midpoint) < epsilon = Just midpoint
  | root_on a midpoint left_value mid_value =
      bisect f a midpoint epsilon (Just left_value) (Just mid_value)
  | otherwise =
      bisect f midpoint b epsilon (Just mid_value) (Just right_value)
  where
    -- Fall back to evaluating f at an endpoint only when the caller
    -- did not supply the value.
    left_value = fromMaybe (f a) f_of_a
    right_value = fromMaybe (f b) f_of_b

    midpoint = (a + b) / 2

    -- Lazily evaluated, so f is not called at the midpoint when one
    -- of the earlier guards already settles the answer.
    mid_value = f midpoint

    -- The tolerance is handed to 'has_root' because a result within
    -- epsilon of the true root requires knowing that there *is* a
    -- root within an interval of length epsilon.
    root_on lo hi f_lo f_hi =
      has_root f lo hi (Just epsilon) (Just f_lo) (Just f_hi)
89
90
91
-- | Search for a root of @f@ on [a,b] by repeatedly cutting the
--   interval into thirds.
--
--   The result is 'Nothing' when 'has_root' finds no evidence of a
--   root on [a,b]. An endpoint whose function value is zero is
--   returned directly. Once the interval is shorter than twice
--   @epsilon@ its midpoint is returned. Otherwise we recurse on the
--   right third if 'has_root' accepts it, then the middle third,
--   and the left third if neither of the others was accepted.
--
trisect :: (RealField.C a,
            RealRing.C b,
            Absolute.C b)
        => (a -> b) -- ^ The function @f@ whose root we seek
        -> a        -- ^ The \"left\" endpoint of the interval, @a@
        -> a        -- ^ The \"right\" endpoint of the interval, @b@
        -> a        -- ^ The tolerance, @epsilon@
        -> Maybe b  -- ^ Precomputed f(a)
        -> Maybe b  -- ^ Precomputed f(b)
        -> Maybe a
trisect f a b epsilon f_of_a f_of_b
  | not (root_on a b left_value right_value) = Nothing
  | left_value == 0 = Just a
  | right_value == 0 = Just b
  | (b - a) < 2*epsilon = Just ((b+a)/2)
  | otherwise =
      trisect f next_a next_b epsilon (Just next_fa) (Just next_fb)
  where
    -- Fall back to evaluating f at an endpoint only when the caller
    -- did not supply the value.
    left_value = fromMaybe (f a) f_of_a
    right_value = fromMaybe (f b) f_of_b

    -- The two interior points that split [a,b] into thirds, and the
    -- (lazily computed) values of f at those points.
    one_third = (2*a + b) / 3
    two_thirds = (a + 2*b) / 3
    value_one_third = f one_third
    value_two_thirds = f two_thirds

    -- The tolerance is handed to 'has_root' because a result within
    -- epsilon of the true root requires knowing that there *is* a
    -- root within an interval of length epsilon.
    root_on lo hi f_lo f_hi =
      has_root f lo hi (Just epsilon) (Just f_lo) (Just f_hi)

    -- The third on which to continue, together with the already
    -- known function values at its endpoints.
    (next_a, next_b, next_fa, next_fb)
      | root_on two_thirds b value_two_thirds right_value =
          (two_thirds, b, value_two_thirds, right_value)
      | root_on one_third two_thirds value_one_third value_two_thirds =
          (one_third, two_thirds, value_one_third, value_two_thirds)
      | otherwise =
          (a, one_third, left_value, value_one_third)
133
134
135
-- | Produce the infinite sequence @x0@, @f x0@, @f (f x0)@, ...
--   obtained by applying @f@ over and over, starting from the
--   initial guess @x0@, in hopes of approaching a fixed point.
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a        -- ^ The initial value @x0@.
                       -> [a]      -- ^ The resulting sequence of x_{n}.
fixed_point_iterations f x0 = iterate f x0
143
144
-- | Find a fixed point of the function @f@ with the search starting
--   at @x0@. This will find the first element in the chain @x0@,
--   f(x0), f(f(x0)),... such that the magnitude of the difference
--   between it and the next element is less than epsilon.
--
--   We also return the number of iterations required, that is, the
--   (zero-based) position of the returned element in the chain.
--
--   If no two consecutive elements ever come within epsilon of one
--   another, this function does not terminate.
--
fixed_point_with_iterations :: forall a b. (Normed a,
                                            Additive.C a,
                                            RealField.C b,
                                            Algebraic.C b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b        -- ^ The tolerance, @epsilon@.
                            -> a        -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  search 0 x0 (f x0)
  where
    -- Walk along the chain, carrying the index n, the element x_{n},
    -- and its successor x_{n+1}. The successor becomes the current
    -- element of the next step, so f is applied once per element.
    search :: Int -> a -> a -> (Int, a)
    search n xn xn_plus_one
      -- The norm is compared against epsilon, which fixes its type
      -- to "b".
      | norm (xn - xn_plus_one) < epsilon = (n, xn)
      | otherwise =
          -- Force the counter so that a long search does not build
          -- up a chain of unevaluated additions.
          let n' = n + 1
          in n' `seq` search n' xn_plus_one (f xn_plus_one)