$darkmode
Eigen  5.0.1-dev
KLUSupport.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Kyle Macfarlan <kyle.macfarlan@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_KLUSUPPORT_H
11 #define EIGEN_KLUSUPPORT_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 /* TODO extract L, extract U, compute det, etc... */
19 
36 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[],
37  klu_common *Common, double) {
38  return klu_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B,
39  Common);
40 }
41 
42 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double> B[],
43  klu_common *Common, std::complex<double>) {
44  return klu_z_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs),
45  &numext::real_ref(B[0]), Common);
46 }
47 
48 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[],
49  klu_common *Common, double) {
50  return klu_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B,
51  Common);
52 }
53 
54 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double> B[],
55  klu_common *Common, std::complex<double>) {
56  return klu_z_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs),
57  &numext::real_ref(B[0]), 0, Common);
58 }
59 
60 inline klu_numeric *klu_factor(int Ap[], int Ai[], double Ax[], klu_symbolic *Symbolic, klu_common *Common, double) {
61  return klu_factor(Ap, Ai, Ax, Symbolic, Common);
62 }
63 
64 inline klu_numeric *klu_factor(int Ap[], int Ai[], std::complex<double> Ax[], klu_symbolic *Symbolic,
65  klu_common *Common, std::complex<double>) {
66  return klu_z_factor(Ap, Ai, &numext::real_ref(Ax[0]), Symbolic, Common);
67 }
68 
69 template <typename MatrixType_>
70 class KLU : public SparseSolverBase<KLU<MatrixType_> > {
71  protected:
72  typedef SparseSolverBase<KLU<MatrixType_> > Base;
73  using Base::m_isInitialized;
74 
75  public:
76  using Base::_solve_impl;
77  typedef MatrixType_ MatrixType;
78  typedef typename MatrixType::Scalar Scalar;
79  typedef typename MatrixType::RealScalar RealScalar;
80  typedef typename MatrixType::StorageIndex StorageIndex;
81  typedef Matrix<Scalar, Dynamic, 1> Vector;
82  typedef Matrix<int, 1, MatrixType::ColsAtCompileTime> IntRowVectorType;
83  typedef Matrix<int, MatrixType::RowsAtCompileTime, 1> IntColVectorType;
84  typedef SparseMatrix<Scalar> LUMatrixType;
85  typedef SparseMatrix<Scalar, ColMajor, int> KLUMatrixType;
86  typedef Ref<const KLUMatrixType, StandardCompressedFormat> KLUMatrixRef;
87  enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };
88 
89  public:
90  KLU() : m_dummy(0, 0), mp_matrix(m_dummy) { init(); }
91 
92  template <typename InputMatrixType>
93  explicit KLU(const InputMatrixType &matrix) : mp_matrix(matrix) {
94  init();
95  compute(matrix);
96  }
97 
98  ~KLU() {
99  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
100  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
101  }
102 
103  constexpr Index rows() const noexcept { return mp_matrix.rows(); }
104  constexpr Index cols() const noexcept { return mp_matrix.cols(); }
105 
111  ComputationInfo info() const {
112  eigen_assert(m_isInitialized && "Decomposition is not initialized.");
113  return m_info;
114  }
115 #if 0 // not implemented yet
116  inline const LUMatrixType& matrixL() const
117  {
118  if (m_extractedDataAreDirty) extractData();
119  return m_l;
120  }
121 
122  inline const LUMatrixType& matrixU() const
123  {
124  if (m_extractedDataAreDirty) extractData();
125  return m_u;
126  }
127 
128  inline const IntColVectorType& permutationP() const
129  {
130  if (m_extractedDataAreDirty) extractData();
131  return m_p;
132  }
133 
134  inline const IntRowVectorType& permutationQ() const
135  {
136  if (m_extractedDataAreDirty) extractData();
137  return m_q;
138  }
139 #endif
140 
144  template <typename InputMatrixType>
145  void compute(const InputMatrixType &matrix) {
146  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
147  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
148  grab(matrix.derived());
149  analyzePattern_impl();
150  factorize_impl();
151  }
152 
159  template <typename InputMatrixType>
160  void analyzePattern(const InputMatrixType &matrix) {
161  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
162  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
163 
164  grab(matrix.derived());
165 
166  analyzePattern_impl();
167  }
168 
173  inline const klu_common &kluCommon() const { return m_common; }
174 
181  inline klu_common &kluCommon() { return m_common; }
182 
189  template <typename InputMatrixType>
190  void factorize(const InputMatrixType &matrix) {
191  eigen_assert(m_analysisIsOk && "KLU: you must first call analyzePattern()");
192  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
193 
194  grab(matrix.derived());
195 
196  factorize_impl();
197  }
198 
200  template <typename BDerived, typename XDerived>
201  bool _solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const;
202 
203 #if 0 // not implemented yet
204  Scalar determinant() const;
205 
206  void extractData() const;
207 #endif
208 
209  protected:
210  void init() {
211  m_info = InvalidInput;
212  m_isInitialized = false;
213  m_numeric = 0;
214  m_symbolic = 0;
215  m_extractedDataAreDirty = true;
216 
217  klu_defaults(&m_common);
218  }
219 
220  void analyzePattern_impl() {
221  m_info = InvalidInput;
222  m_analysisIsOk = false;
223  m_factorizationIsOk = false;
224  m_symbolic = klu_analyze(internal::convert_index<int>(mp_matrix.rows()),
225  const_cast<StorageIndex *>(mp_matrix.outerIndexPtr()),
226  const_cast<StorageIndex *>(mp_matrix.innerIndexPtr()), &m_common);
227  if (m_symbolic) {
228  m_isInitialized = true;
229  m_info = Success;
230  m_analysisIsOk = true;
231  m_extractedDataAreDirty = true;
232  }
233  }
234 
235  void factorize_impl() {
236  m_numeric = klu_factor(const_cast<StorageIndex *>(mp_matrix.outerIndexPtr()),
237  const_cast<StorageIndex *>(mp_matrix.innerIndexPtr()),
238  const_cast<Scalar *>(mp_matrix.valuePtr()), m_symbolic, &m_common, Scalar());
239 
240  m_info = m_numeric ? Success : NumericalIssue;
241  m_factorizationIsOk = m_numeric ? 1 : 0;
242  m_extractedDataAreDirty = true;
243  }
244 
245  template <typename MatrixDerived>
246  void grab(const EigenBase<MatrixDerived> &A) {
247  internal::destroy_at(&mp_matrix);
248  internal::construct_at(&mp_matrix, A.derived());
249  }
250 
251  void grab(const KLUMatrixRef &A) {
252  if (&(A.derived()) != &mp_matrix) {
253  internal::destroy_at(&mp_matrix);
254  internal::construct_at(&mp_matrix, A);
255  }
256  }
257 
258  // cached data to reduce reallocation, etc.
259 #if 0 // not implemented yet
260  mutable LUMatrixType m_l;
261  mutable LUMatrixType m_u;
262  mutable IntColVectorType m_p;
263  mutable IntRowVectorType m_q;
264 #endif
265 
266  KLUMatrixType m_dummy;
267  KLUMatrixRef mp_matrix;
268 
269  klu_numeric *m_numeric;
270  klu_symbolic *m_symbolic;
271  klu_common m_common;
272  mutable ComputationInfo m_info;
273  int m_factorizationIsOk;
274  int m_analysisIsOk;
275  mutable bool m_extractedDataAreDirty;
276 
277  private:
278  KLU(const KLU &) {}
279 };
280 
// NOTE(review): this disabled draft was adapted from the UmfPack backend and still calls
// umfpack_get_lunz / umfpack_get_numeric, which do not operate on KLU handles. It must be rewritten
// against KLU's own factor-extraction API (presumably klu_extract -- TODO confirm) before enabling.
#if 0  // not implemented yet
template<typename MatrixType>
void KLU<MatrixType>::extractData() const
{
  // Nothing to do while the cached factors are up to date.
  if (!m_extractedDataAreDirty)
    return;

  eigen_assert(false && "KLU: extractData Not Yet Implemented");

  // query the dimensions and nonzero counts of the factors
  int nnzL, nnzU, nRows, nCols, nnzUdiag;
  umfpack_get_lunz(&nnzL, &nnzU, &nRows, &nCols, &nnzUdiag, m_numeric, Scalar());

  const int diagSize = (std::min)(nRows, nCols);

  // size the output containers
  m_l.resize(nRows, diagSize);
  m_l.resizeNonZeros(nnzL);
  m_u.resize(diagSize, nCols);
  m_u.resizeNonZeros(nnzU);
  m_p.resize(nRows);
  m_q.resize(nCols);

  // copy the factors and permutations out
  umfpack_get_numeric(m_l.outerIndexPtr(), m_l.innerIndexPtr(), m_l.valuePtr(),
                      m_u.outerIndexPtr(), m_u.innerIndexPtr(), m_u.valuePtr(),
                      m_p.data(), m_q.data(), 0, 0, 0, m_numeric);

  m_extractedDataAreDirty = false;
}

// Placeholder: always asserts in debug builds and returns a value-initialized Scalar.
template<typename MatrixType>
typename KLU<MatrixType>::Scalar KLU<MatrixType>::determinant() const
{
  eigen_assert(false && "KLU: extractData Not Yet Implemented");
  return Scalar();
}
#endif
319 
320 template <typename MatrixType>
321 template <typename BDerived, typename XDerived>
322 bool KLU<MatrixType>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const {
323  Index rhsCols = b.cols();
324  EIGEN_STATIC_ASSERT((XDerived::Flags & RowMajorBit) == 0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
325  eigen_assert(m_factorizationIsOk &&
326  "The decomposition is not in a valid state for solving, you must first call either compute() or "
327  "analyzePattern()/factorize()");
328 
329  x = b;
330  int info = klu_solve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(),
331  const_cast<klu_common *>(&m_common), Scalar());
332 
333  m_info = info != 0 ? Success : NumericalIssue;
334  return true;
335 }
336 
337 } // end namespace Eigen
338 
339 #endif // EIGEN_KLUSUPPORT_H
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double)
A sparse LU factorization and solver based on KLU.
Definition: KLUSupport.h:36
const unsigned int RowMajorBit
Definition: Constants.h:70
Matrix< Type, Size, 1 > Vector
[c++11] Size×1 vector of type Type.
Definition: Matrix.h:522
Definition: Constants.h:442
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:447
Definition: Constants.h:440
ComputationInfo
Definition: Constants.h:438