Eigen 5.0.1-dev
IterativeSolverBase.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
11 #define EIGEN_ITERATIVE_SOLVER_BASE_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 template <typename MatrixType>
21 struct is_ref_compatible_impl {
22  private:
23  template <typename T0>
24  struct any_conversion {
25  template <typename T>
26  any_conversion(const volatile T&);
27  template <typename T>
28  any_conversion(T&);
29  };
30  struct yes {
31  int a[1];
32  };
33  struct no {
34  int a[2];
35  };
36 
37  template <typename T>
38  static yes test(const Ref<const T>&, int);
39  template <typename T>
40  static no test(any_conversion<T>, ...);
41 
42  public:
43  static MatrixType ms_from;
44  enum { value = sizeof(test<MatrixType>(ms_from, 0)) == sizeof(yes) };
45 };
46 
47 template <typename MatrixType>
48 struct is_ref_compatible {
49  enum { value = is_ref_compatible_impl<remove_all_t<MatrixType>>::value };
50 };
51 
52 template <typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
53 class generic_matrix_wrapper;
54 
55 // We have an explicit matrix at hand, compatible with Ref<>
56 template <typename MatrixType>
57 class generic_matrix_wrapper<MatrixType, false> {
58  public:
59  typedef Ref<const MatrixType> ActualMatrixType;
60  template <int UpLo>
61  struct ConstSelfAdjointViewReturnType {
62  typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
63  };
64 
65  enum { MatrixFree = false };
66 
67  generic_matrix_wrapper() : m_dummy(0, 0), m_matrix(m_dummy) {}
68 
69  template <typename InputType>
70  generic_matrix_wrapper(const InputType& mat) : m_matrix(mat) {}
71 
72  const ActualMatrixType& matrix() const { return m_matrix; }
73 
74  template <typename MatrixDerived>
75  void grab(const EigenBase<MatrixDerived>& mat) {
76  internal::destroy_at(&m_matrix);
77  internal::construct_at(&m_matrix, mat.derived());
78  }
79 
80  void grab(const Ref<const MatrixType>& mat) {
81  if (&(mat.derived()) != &m_matrix) {
82  internal::destroy_at(&m_matrix);
83  internal::construct_at(&m_matrix, mat);
84  }
85  }
86 
87  protected:
88  MatrixType m_dummy; // used to default initialize the Ref<> object
89  ActualMatrixType m_matrix;
90 };
91 
92 // MatrixType is not compatible with Ref<> -> matrix-free wrapper
93 template <typename MatrixType>
94 class generic_matrix_wrapper<MatrixType, true> {
95  public:
96  typedef MatrixType ActualMatrixType;
97  template <int UpLo>
98  struct ConstSelfAdjointViewReturnType {
99  typedef ActualMatrixType Type;
100  };
101 
102  enum { MatrixFree = true };
103 
104  generic_matrix_wrapper() : mp_matrix(0) {}
105 
106  generic_matrix_wrapper(const MatrixType& mat) : mp_matrix(&mat) {}
107 
108  const ActualMatrixType& matrix() const { return *mp_matrix; }
109 
110  void grab(const MatrixType& mat) { mp_matrix = &mat; }
111 
112  protected:
113  const ActualMatrixType* mp_matrix;
114 };
115 
116 } // namespace internal
117 
123 template <typename Derived>
124 class IterativeSolverBase : public SparseSolverBase<Derived> {
125  protected:
127  using Base::m_isInitialized;
128 
129  public:
130  typedef typename internal::traits<Derived>::MatrixType MatrixType;
131  typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
132  typedef typename MatrixType::Scalar Scalar;
133  typedef typename MatrixType::StorageIndex StorageIndex;
134  typedef typename MatrixType::RealScalar RealScalar;
135 
136  enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };
137 
138  public:
139  using Base::derived;
140 
142  IterativeSolverBase() { init(); }
143 
154  template <typename MatrixDerived>
155  explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A) : m_matrixWrapper(A.derived()) {
156  init();
157  compute(matrix());
158  }
159 
161 
162  ~IterativeSolverBase() {}
163 
169  template <typename MatrixDerived>
171  grab(A.derived());
172  m_preconditioner.analyzePattern(matrix());
173  m_isInitialized = true;
174  m_analysisIsOk = true;
175  m_info = m_preconditioner.info();
176  return derived();
177  }
178 
189  template <typename MatrixDerived>
190  Derived& factorize(const EigenBase<MatrixDerived>& A) {
191  eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
192  grab(A.derived());
193  m_preconditioner.factorize(matrix());
194  m_factorizationIsOk = true;
195  m_info = m_preconditioner.info();
196  return derived();
197  }
198 
209  template <typename MatrixDerived>
210  Derived& compute(const EigenBase<MatrixDerived>& A) {
211  grab(A.derived());
212  m_preconditioner.compute(matrix());
213  m_isInitialized = true;
214  m_analysisIsOk = true;
215  m_factorizationIsOk = true;
216  m_info = m_preconditioner.info();
217  return derived();
218  }
219 
221  constexpr Index rows() const noexcept { return matrix().rows(); }
222 
224  constexpr Index cols() const noexcept { return matrix().cols(); }
225 
229  RealScalar tolerance() const { return m_tolerance; }
230 
236  Derived& setTolerance(const RealScalar& tolerance) {
237  m_tolerance = tolerance;
238  return derived();
239  }
240 
242  Preconditioner& preconditioner() { return m_preconditioner; }
243 
245  const Preconditioner& preconditioner() const { return m_preconditioner; }
246 
251  Index maxIterations() const { return (m_maxIterations < 0) ? 2 * matrix().cols() : m_maxIterations; }
252 
256  Derived& setMaxIterations(Index maxIters) {
257  m_maxIterations = maxIters;
258  return derived();
259  }
260 
262  Index iterations() const {
263  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
264  return m_iterations;
265  }
266 
270  RealScalar error() const {
271  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
272  return m_error;
273  }
274 
280  template <typename Rhs, typename Guess>
281  inline const SolveWithGuess<Derived, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const {
282  eigen_assert(m_isInitialized && "Solver is not initialized.");
283  eigen_assert(derived().rows() == b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
284  return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
285  }
286 
289  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
290  return m_info;
291  }
292 
294  template <typename Rhs, typename DestDerived>
295  void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived>& aDest) const {
296  eigen_assert(rows() == b.rows());
297 
298  Index rhsCols = b.cols();
299  Index size = b.rows();
300  DestDerived& dest(aDest.derived());
301  typedef typename DestDerived::Scalar DestScalar;
304  // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
305  // For non square least-square problems, b and dest might not have the same size whereas they might alias
306  // each-other.
307  typename DestDerived::PlainObject tmp(cols(), rhsCols);
308  ComputationInfo global_info = Success;
309  for (Index k = 0; k < rhsCols; ++k) {
310  tb = b.col(k);
311  tx = dest.col(k);
312  derived()._solve_vector_with_guess_impl(tb, tx);
313  tmp.col(k) = tx.sparseView(0);
314 
315  // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
316  // we need to restore it to the worst value.
317  if (m_info == NumericalIssue)
318  global_info = NumericalIssue;
319  else if (m_info == NoConvergence)
320  global_info = NoConvergence;
321  }
322  m_info = global_info;
323  dest.swap(tmp);
324  }
325 
326  template <typename Rhs, typename DestDerived>
327  std::enable_if_t<Rhs::ColsAtCompileTime != 1 && DestDerived::ColsAtCompileTime != 1> _solve_with_guess_impl(
328  const Rhs& b, MatrixBase<DestDerived>& aDest) const {
329  eigen_assert(rows() == b.rows());
330 
331  Index rhsCols = b.cols();
332  DestDerived& dest(aDest.derived());
333  ComputationInfo global_info = Success;
334  for (Index k = 0; k < rhsCols; ++k) {
335  typename DestDerived::ColXpr xk(dest, k);
336  typename Rhs::ConstColXpr bk(b, k);
337  derived()._solve_vector_with_guess_impl(bk, xk);
338 
339  // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
340  // we need to restore it to the worst value.
341  if (m_info == NumericalIssue)
342  global_info = NumericalIssue;
343  else if (m_info == NoConvergence)
344  global_info = NoConvergence;
345  }
346  m_info = global_info;
347  }
348 
349  template <typename Rhs, typename DestDerived>
350  std::enable_if_t<Rhs::ColsAtCompileTime == 1 || DestDerived::ColsAtCompileTime == 1> _solve_with_guess_impl(
351  const Rhs& b, MatrixBase<DestDerived>& dest) const {
352  derived()._solve_vector_with_guess_impl(b, dest.derived());
353  }
354 
356  template <typename Rhs, typename Dest>
357  void _solve_impl(const Rhs& b, Dest& x) const {
358  x.setZero();
359  derived()._solve_with_guess_impl(b, x);
360  }
361 
362  protected:
363  void init() {
364  m_isInitialized = false;
365  m_analysisIsOk = false;
366  m_factorizationIsOk = false;
367  m_maxIterations = -1;
368  m_tolerance = NumTraits<Scalar>::epsilon();
369  }
370 
371  typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
372  typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
373 
374  const ActualMatrixType& matrix() const { return m_matrixWrapper.matrix(); }
375 
376  template <typename InputType>
377  void grab(const InputType& A) {
378  m_matrixWrapper.grab(A);
379  }
380 
381  MatrixWrapper m_matrixWrapper;
382  Preconditioner m_preconditioner;
383 
384  Index m_maxIterations;
385  RealScalar m_tolerance;
386 
387  mutable RealScalar m_error;
388  mutable Index m_iterations;
389  mutable ComputationInfo m_info;
390  mutable bool m_analysisIsOk, m_factorizationIsOk;
391 };
392 
393 } // end namespace Eigen
394 
395 #endif // EIGEN_ITERATIVE_SOLVER_BASE_H
constexpr Derived & derived()
Definition: EigenBase.h:49
Pseudo expression representing a solving operation.
Definition: SolveWithGuess.h:19
Derived & setTolerance(const RealScalar &tolerance)
Definition: IterativeSolverBase.h:236
ComputationInfo info() const
Definition: IterativeSolverBase.h:288
IterativeSolverBase(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:155
A base class for sparse solvers.
Definition: SparseSolverBase.h:67
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
const Preconditioner & preconditioner() const
Definition: IterativeSolverBase.h:245
Index iterations() const
Definition: IterativeSolverBase.h:262
Derived & factorize(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:190
RealScalar tolerance() const
Definition: IterativeSolverBase.h:229
Definition: EigenBase.h:33
RealScalar error() const
Definition: IterativeSolverBase.h:270
Base class of any sparse matrices or sparse expressions.
Definition: ForwardDeclarations.h:481
Derived & analyzePattern(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:170
Derived & compute(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:210
Definition: Constants.h:442
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const SolveWithGuess< Derived, Rhs, Guess > solveWithGuess(const MatrixBase< Rhs > &b, const Guess &x0) const
Definition: IterativeSolverBase.h:281
Derived & setMaxIterations(Index maxIters)
Definition: IterativeSolverBase.h:256
Definition: Constants.h:440
constexpr Derived & derived()
Definition: EigenBase.h:49
Preconditioner & preconditioner()
Definition: IterativeSolverBase.h:242
The matrix class, also used for vectors and row-vectors.
Definition: Matrix.h:186
IterativeSolverBase()
Definition: IterativeSolverBase.h:142
constexpr Index rows() const noexcept
Definition: EigenBase.h:59
ComputationInfo
Definition: Constants.h:438
Base class for linear iterative solvers.
Definition: IterativeSolverBase.h:124
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
Definition: Constants.h:444
Index maxIterations() const
Definition: IterativeSolverBase.h:251