]> gitweb.michael.orlitzky.com - numerical-analysis.git/blob - src/Roots/Fast.hs
Fix a "norm" in Roots.Fast.
[numerical-analysis.git] / src / Roots / Fast.hs
1 -- | The Roots.Fast module contains faster implementations of the
2 -- 'Roots.Simple' algorithms. Generally, we will pass precomputed
3 -- values to the next iteration of a function rather than passing
4 -- the function and the points at which to (re)evaluate it.
5
6 module Roots.Fast
7 where
8
9 import Data.List (find)
10
11 import Vector
12
13
-- | Determine whether the function @f@ has a root in the closed
--   interval [@a@, @b@].
--
--   If f(a) and f(b) differ in sign (or either is zero), there is
--   definitely a root (for continuous @f@) and we return 'True'
--   immediately. Otherwise we split the interval at its midpoint @c@
--   and check both halves, giving up on any subinterval whose length
--   is at most @epsilon@. A 'False' result therefore only means that no
--   sign change was found at that resolution.
--
--   The values f(a) and f(b) may be supplied to avoid recomputing
--   them. When an interval is split, f(c) is computed once and handed
--   to both halves.
has_root :: (Fractional a, Ord a, Ord b, Num b)
         => (a -> b) -- ^ The function @f@
         -> a        -- ^ The \"left\" endpoint, @a@
         -> a        -- ^ The \"right\" endpoint, @b@
         -> Maybe a  -- ^ The size of the smallest subinterval
                     --   we'll examine, @epsilon@
         -> Maybe b  -- ^ Precomputed f(a)
         -> Maybe b  -- ^ Precomputed f(b)
         -> Bool
has_root f a b epsilon f_of_a f_of_b
  -- We don't care about epsilon here: a zero or a sign change at the
  -- endpoints means there's definitely a root.
  | signum f_of_a' * signum f_of_b' /= 1 = True
  -- Give up: the interval is too small to subdivide any further.
  | (b - a) <= epsilon' = False
  -- If either [a,c] or [c,b] has a root, we do too. Both halves share
  -- the single evaluation of f(c).
  | otherwise =
      has_root f a c (Just epsilon') (Just f_of_a') (Just f_of_c) ||
      has_root f c b (Just epsilon') (Just f_of_c) (Just f_of_b')
  where
    -- If the size of the smallest subinterval is not specified,
    -- assume we just want to check once on all of [a,b].
    epsilon' = case epsilon of
                 Nothing  -> b - a
                 Just eps -> eps

    -- Compute f(a) and f(b) only if they were not supplied.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v  -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v  -> v

    -- The midpoint, and f evaluated there (lazily, only if we split).
    c = (a + b) / 2
    f_of_c = f c
52
53
54
-- | Find a root of the function @f@ in the interval [@a@, @b@] using
--   the bisection method.
--
--   Returns 'Nothing' when 'has_root' (searching subintervals down to
--   length @epsilon@) cannot establish that [@a@, @b@] contains a root.
--   Otherwise, returns an endpoint at which @f@ is exactly zero, or the
--   midpoint of an interval that 'has_root' reports as containing a
--   root and whose half-width is less than @epsilon@. For termination,
--   @epsilon@ should be positive.
--
--   The values f(a) and f(b) may be supplied to avoid re-evaluating
--   @f@; pass 'Nothing' to have them computed here.
bisect :: (Fractional a, Ord a, Num b, Ord b)
       => (a -> b) -- ^ The function @f@ whose root we seek
       -> a        -- ^ The \"left\" endpoint of the interval, @a@
       -> a        -- ^ The \"right\" endpoint of the interval, @b@
       -> a        -- ^ The tolerance, @epsilon@
       -> Maybe b  -- ^ Precomputed f(a)
       -> Maybe b  -- ^ Precomputed f(b)
       -> Maybe a
bisect f a b epsilon f_of_a f_of_b
  -- We pass @epsilon@ to the 'has_root' function because if we want a
  -- result within epsilon of the true root, we need to know that
  -- there *is* a root within an interval of length epsilon.
  | not (has_root f a b (Just epsilon) (Just f_of_a') (Just f_of_b')) = Nothing
  | otherwise = go a b f_of_a' f_of_b'
  where
    -- Compute f(a) and f(b) only if they were not supplied.
    f_of_a' = case f_of_a of
                Nothing -> f a
                Just v  -> v

    f_of_b' = case f_of_b of
                Nothing -> f b
                Just v  -> v

    -- Bisect [x, y], where fx = f(x) and fy = f(y).
    --
    -- Invariant: 'has_root' has already reported a root in [x, y],
    -- either at the entry point above or in the parent call, so we
    -- don't ask again. We descend left only after 'has_root' confirms
    -- [x, m]. Otherwise we descend right. Once fx and fy are nonzero,
    -- 'has_root' succeeds on [x, y] only through a sign change at the
    -- endpoints or a root in one of its halves. If [x, m] has none,
    -- [m, y] must have one.
    go x y fx fy
      | fx == 0 = Just x
      | fy == 0 = Just y
      | (y - m) < epsilon = Just m
      | has_root f x m (Just epsilon) (Just fx) (Just fm) = go x m fx fm
      | otherwise = go m y fm fy
      where
        m = (x + y) / 2
        -- Evaluated at most once and shared by both branches.
        fm = f m
88
89
90
-- | The infinite sequence x0, f(x0), f(f(x0)), ... produced by
--   repeatedly applying @f@ to the initial guess @x0@. The hope is that
--   this sequence converges to a fixed point of @f@.
fixed_point_iterations :: (a -> a) -- ^ The function @f@ to iterate.
                       -> a        -- ^ The initial value @x0@.
                       -> [a]      -- ^ The resulting sequence of x_{n}.
fixed_point_iterations = iterate
98
99
-- | Find a fixed point of the function @f@, starting the search at
--   @x0@. Walk the sequence x0, f(x0), f(f(x0)), ... and return the
--   first element x_{n} whose distance from the next element x_{n+1},
--   as measured by 'norm_2', is less than @epsilon@. Note that this may
--   be @x0@ itself (with n = 0).
--
--   We also return the number of iterations @n@ required. If no such
--   element exists, this loops forever.
fixed_point_with_iterations :: (Vector a, RealFrac b)
                            => (a -> a) -- ^ The function @f@ to iterate.
                            -> b        -- ^ The tolerance, @epsilon@.
                            -> a        -- ^ The initial value @x0@.
                            -> (Int, a) -- ^ The (iterations, fixed point) pair
fixed_point_with_iterations f epsilon x0 =
  go 0 x0
  where
    -- Walk the sequence directly rather than searching an infinite
    -- list with 'find', which required a partial pattern match on the
    -- 'Just' result. Each x_{n+1} = f(x_{n}) is computed exactly once.
    go n x_n
      | norm_2 (x_n - x_n_plus_one) < epsilon = (n, x_n)
      -- Force the counter so we don't accumulate a chain of (+ 1)
      -- thunks over many iterations.
      | otherwise = n `seq` go (n + 1) x_n_plus_one
      where
        x_n_plus_one = f x_n
136