{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Classical iterative methods to solve the system Ax = b.
module Linear.Iteration where

import Data.List (find)
import Data.Maybe (fromJust)
import Data.Vector.Fixed (Arity, N1, S)
import NumericPrelude hiding ((*))
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.RealField as RealField
import qualified Algebra.ToRational as ToRational
import qualified Prelude as P

import Linear.Matrix
import Linear.System
import Normed


-- | One step of a classical splitting method for @A x = b@.
--
-- The first argument builds the splitting matrix @M@ from @A@. With
-- @N = M - A@, the system @A x = b@ is equivalent to
-- @M x = N x + b@, so the next iterate is obtained by solving
--
-- > M * x_next = N * x_current + b
--
-- with 'forward_substitute'.
--
-- TODO: Should be solve below! M might not be lower-triangular.
-- 'forward_substitute' assumes a lower-triangular @M@. That is the
-- case for the splittings built in this module ('diagonal_part' for
-- Jacobi; a scaled 'diagonal_part' plus 'lt_part_strict' for SOR and
-- Gauss-Seidel). It is NOT guaranteed for an arbitrary @m_function@.
--
classical_iteration :: (Field.C a, Arity m)
                    => (Mat m m a -> Mat m m a)
                    -> Mat m m a
                    -> Mat m N1 a
                    -> Mat m N1 a
                    -> Mat m N1 a
classical_iteration m_function matrix b x_current =
  forward_substitute splitting_m (splitting_n*x_current + b)
  where
    splitting_m = m_function matrix
    splitting_n = splitting_m - matrix


-- | Perform one iteration of successive over-relaxation (SOR).
--
-- The splitting matrix is @M = (1/omega)*D + L@, where @D@ is
-- 'diagonal_part' of @A@ and @L@ is 'lt_part_strict' of @A@. Because
-- @omega@ is passed to 'recip', it must be nonzero.
--
sor_iteration :: forall m a. (Field.C a, Arity m)
              => a          -- ^ Omega
              -> Mat m m a  -- ^ Matrix A
              -> Mat m N1 a -- ^ Vector b
              -> Mat m N1 a -- ^ Vector x_current
              -> Mat m N1 a -- ^ Output vector x_next
sor_iteration omega = classical_iteration sor_splitting
  where
    sor_splitting :: Mat m m a -> Mat m m a
    sor_splitting coefficients =
      ((recip omega) *> (diagonal_part coefficients))
        + lt_part_strict coefficients


-- | Compute an infinite list of SOR iterations. The list starts with
-- the initial vector x0 itself, followed by each successive iterate.
sor_iterations :: (Field.C a, Arity m)
               => a
               -> Mat m m a
               -> Mat m N1 a
               -> Mat m N1 a
               -> [Mat m N1 a]
sor_iterations omega matrix b x0 =
  iterate (sor_iteration omega matrix b) x0


-- | Perform one iteration of Gauss-Seidel (SOR with omega = 1).
gauss_seidel_iteration :: (Field.C a, Arity m)
                       => Mat m m a
                       -> Mat m N1 a
                       -> Mat m N1 a
                       -> Mat m N1 a
gauss_seidel_iteration matrix b = sor_iteration one matrix b


-- | Compute an infinite list of Gauss-Seidel iterations. The list
-- starts with the initial vector x0 itself.
gauss_seidel_iterations :: (Field.C a, Arity m)
                        => Mat m m a
                        -> Mat m N1 a
                        -> Mat m N1 a
                        -> [Mat m N1 a]
gauss_seidel_iterations matrix b x0 =
  iterate (gauss_seidel_iteration matrix b) x0


-- | Perform one Jacobi iteration,
--
-- > x1 = M^(-1) * (N*x0 + b)
--
-- where @M@ is 'diagonal_part' of the matrix @A@ and @N = M - A@.
--
-- Examples:
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> jacobi_iteration m b x0
-- ((0.25),(0.5))
-- >>> let x1 = jacobi_iteration m b x0
-- >>> jacobi_iteration m b x1
-- ((0.0),(0.25))
--
jacobi_iteration :: (Field.C a, Arity m)
                 => Mat m m a
                 -> Mat m N1 a
                 -> Mat m N1 a
                 -> Mat m N1 a
jacobi_iteration matrix = classical_iteration diagonal_part matrix


-- | Compute an infinite list of Jacobi iterations. The list starts
-- with the initial vector x0 itself.
jacobi_iterations :: (Field.C a, Arity m)
                  => Mat m m a
                  -> Mat m N1 a
                  -> Mat m N1 a
                  -> [Mat m N1 a]
jacobi_iterations matrix b x0 =
  iterate (jacobi_iteration matrix b) x0


-- | Solve the system Ax = b using the Jacobi method. Iteration stops
-- at the first iterate whose distance from its predecessor is below
-- the tolerance. This will run forever if the iterations do not
-- converge.
--
-- Examples:
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-6)
-- >>> jacobi_method m b x0 epsilon
-- ((0.0),(0.4999995231628418))
--
jacobi_method :: (RealField.C a,
                  Algebraic.C a, -- Normed instance
                  ToRational.C a, -- Normed instance
                  Algebraic.C b,
                  RealField.C b,
                  Arity m,
                  Arity n, -- Normed instance
                  m ~ S n)
              => Mat m m a
              -> Mat m N1 a
              -> Mat m N1 a
              -> b
              -> Mat m N1 a
jacobi_method matrix b x0 epsilon =
  classical_method jacobi_iterations matrix b x0 epsilon


-- | Solve the system Ax = b using the Gauss-Seidel method. This will
-- run forever if the iterations do not converge.
--
-- Examples:
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-12)
-- >>> gauss_seidel_method m b x0 epsilon
-- ((4.547473508864641e-13),(0.49999999999954525))
--
gauss_seidel_method :: (RealField.C a,
                        Algebraic.C a, -- Normed instance
                        ToRational.C a, -- Normed instance
                        Algebraic.C b,
                        RealField.C b,
                        Arity m,
                        Arity n, -- Normed instance
                        m ~ S n)
                    => Mat m m a
                    -> Mat m N1 a
                    -> Mat m N1 a
                    -> b
                    -> Mat m N1 a
gauss_seidel_method = classical_method gauss_seidel_iterations


-- | Solve the system Ax = b using the Successive Over-Relaxation
-- (SOR) method. This will run forever if the iterations do not
-- converge. The relaxation parameter @omega@ is inverted (with
-- 'recip') by 'sor_iteration', so it must be nonzero.
--
-- Examples:
--
-- >>> let m = fromList [[4,2],[2,2]] :: Mat2 Double
-- >>> let x0 = vec2d (0, 0::Double)
-- >>> let b = vec2d (1, 1::Double)
-- >>> let epsilon = 10**(-12)
-- >>> sor_method 1.5 m b x0 epsilon
-- ((6.567246746413957e-13),(0.4999999999993727))
--
sor_method :: (RealField.C a,
               Algebraic.C a, -- Normed instance
               ToRational.C a, -- Normed instance
               Algebraic.C b,
               RealField.C b,
               Arity m,
               Arity n, -- Normed instance
               m ~ S n)
           => a
           -> Mat m m a
           -> Mat m N1 a
           -> Mat m N1 a
           -> b
           -> Mat m N1 a
sor_method omega = classical_method (sor_iterations omega)


-- | General implementation for all classical iteration methods. For
-- its first argument, it takes a function which generates the
-- sequence of iterates when supplied with the remaining arguments
-- (except for the tolerance).
--
-- The result is the first iterate @x_(k+1)@ for which
-- @norm (x_(k+1) - x_k) < epsilon@.
--
-- * If the sequence is infinite and never gets that close, this runs
--   forever.
-- * If the sequence is finite and runs out first (e.g. because the
--   supplied function caps the number of iterations), an error naming
--   this function is raised instead of an opaque 'fromJust' failure.
--
classical_method :: forall m n a b.
                    (RealField.C a,
                     Algebraic.C a, -- Normed instance
                     ToRational.C a, -- Normed instance
                     Algebraic.C b,
                     RealField.C b,
                     Arity m,
                     Arity n, -- Normed instance
                     m ~ S n)
                 => (Mat m m a -> Mat m N1 a -> Mat m N1 a -> [Mat m N1 a])
                 -> Mat m m a
                 -> Mat m N1 a
                 -> Mat m N1 a
                 -> b
                 -> Mat m N1 a
classical_method iterations_function matrix b x0 epsilon =
  first_converged (iterations_function matrix b x0)
  where
    -- Walk the iterates two at a time and return the later element of
    -- the first consecutive pair whose difference is below epsilon.
    -- Pattern matching keeps this total up to the explicit error below;
    -- on an infinite list that never converges it searches forever.
    first_converged :: [Mat m N1 a] -> Mat m N1 a
    first_converged (x_prev : rest@(x_cur : _))
      | norm (x_cur - x_prev) < epsilon = x_cur
      | otherwise = first_converged rest
    first_converged _ =
      error "Linear.Iteration.classical_method: iterates ran out before converging"


-- | Compute the Rayleigh quotient of @matrix@ and @vector@, i.e.
-- @(v . (A v)) / (v^T v)@.
--
-- The quotient is undefined for a zero @vector@, because the
-- denominator @v^T v@ is then zero.
--
-- Examples:
--
-- >>> let m = fromList [[3,1],[1,2]] :: Mat2 Rational
-- >>> let v = vec2d (1, 1::Rational)
-- >>> rayleigh_quotient m v
-- 7 % 2
--
rayleigh_quotient :: (RealField.C a, Arity m, Arity n, m ~ S n)
                  => (Mat m m a)
                  -> (Mat m N1 a)
                  -> a
rayleigh_quotient matrix vector =
  (vector `dot` (matrix * vector)) / (norm_squared vector)
  where
    -- We don't use the norm function here to avoid the algebraic
    -- requirement on our field.
    norm_squared v = ((transpose v) * v) !!! (0,0)