$darkmode
Eigen  5.0.1-dev
SparsePermutation.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2012 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSE_PERMUTATION_H
11 #define EIGEN_SPARSE_PERMUTATION_H
12 
13 // This file implements sparse * permutation products
14 
15 // IWYU pragma: private
16 #include "./InternalHeaderCheck.h"
17 
18 namespace Eigen {
19 
20 namespace internal {
21 
22 template <typename ExpressionType, typename PlainObjectType,
23  bool NeedEval = !is_same<ExpressionType, PlainObjectType>::value>
24 struct XprHelper {
25  XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
26  inline const PlainObjectType& xpr() const { return m_xpr; }
27  // this is a new PlainObjectType initialized by xpr
28  const PlainObjectType m_xpr;
29 };
30 template <typename ExpressionType, typename PlainObjectType>
31 struct XprHelper<ExpressionType, PlainObjectType, false> {
32  XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
33  inline const PlainObjectType& xpr() const { return m_xpr; }
34  // this is a reference to xpr
35  const PlainObjectType& m_xpr;
36 };
37 
38 template <typename PermDerived, bool NeedInverseEval>
39 struct PermHelper {
40  using IndicesType = typename PermDerived::IndicesType;
41  using PermutationIndex = typename IndicesType::Scalar;
42  using type = PermutationMatrix<IndicesType::SizeAtCompileTime, IndicesType::MaxSizeAtCompileTime, PermutationIndex>;
43  PermHelper(const PermDerived& perm) : m_perm(perm.inverse()) {}
44  inline const type& perm() const { return m_perm; }
45  // this is a new PermutationMatrix initialized by perm.inverse()
46  const type m_perm;
47 };
48 template <typename PermDerived>
49 struct PermHelper<PermDerived, false> {
50  using type = PermDerived;
51  PermHelper(const PermDerived& perm) : m_perm(perm) {}
52  inline const type& perm() const { return m_perm; }
53  // this is a reference to perm
54  const type& m_perm;
55 };
56 
57 template <typename ExpressionType, int Side, bool Transposed>
58 struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape> {
59  using MatrixType = typename nested_eval<ExpressionType, 1>::type;
60  using MatrixTypeCleaned = remove_all_t<MatrixType>;
61 
62  using Scalar = typename MatrixTypeCleaned::Scalar;
63  using StorageIndex = typename MatrixTypeCleaned::StorageIndex;
64 
65  // the actual "return type" is `Dest`. this is a temporary type
66  using ReturnType = SparseMatrix<Scalar, MatrixTypeCleaned::IsRowMajor ? RowMajor : ColMajor, StorageIndex>;
67  using TmpHelper = XprHelper<ExpressionType, ReturnType>;
68 
69  static constexpr bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side == OnTheLeft : Side == OnTheRight;
70  static constexpr bool NeedInversePermutation = Transposed ? Side == OnTheLeft : Side == OnTheRight;
71 
72  template <typename Dest, typename PermutationType>
73  static inline void permute_outer(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
74  // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
75  // otherwise, just reference `xpr`
76  // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
77  const TmpHelper tmpHelper(xpr);
78  const ReturnType& tmp = tmpHelper.xpr();
79 
80  ReturnType result(tmp.rows(), tmp.cols());
81 
82  for (Index j = 0; j < tmp.outerSize(); j++) {
83  Index jp = perm.indices().coeff(j);
84  Index jsrc = NeedInversePermutation ? jp : j;
85  Index jdst = NeedInversePermutation ? j : jp;
86  Index begin = tmp.outerIndexPtr()[jsrc];
87  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
88  result.outerIndexPtr()[jdst + 1] += end - begin;
89  }
90 
91  std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
92  result.resizeNonZeros(result.nonZeros());
93 
94  for (Index j = 0; j < tmp.outerSize(); j++) {
95  Index jp = perm.indices().coeff(j);
96  Index jsrc = NeedInversePermutation ? jp : j;
97  Index jdst = NeedInversePermutation ? j : jp;
98  Index begin = tmp.outerIndexPtr()[jsrc];
99  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
100  Index target = result.outerIndexPtr()[jdst];
101  smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target);
102  smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
103  }
104  dst = std::move(result);
105  }
106 
107  template <typename Dest, typename PermutationType>
108  static inline void permute_inner(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
109  using InnerPermHelper = PermHelper<PermutationType, NeedInversePermutation>;
110  using InnerPermType = typename InnerPermHelper::type;
111 
112  // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
113  // otherwise, just reference `xpr`
114  // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
115  const TmpHelper tmpHelper(xpr);
116  const ReturnType& tmp = tmpHelper.xpr();
117 
118  // if inverse permutation of inner indices is requested, calculate perm.inverse() (allocation)
119  // otherwise, just reference `perm`
120  const InnerPermHelper permHelper(perm);
121  const InnerPermType& innerPerm = permHelper.perm();
122 
123  ReturnType result(tmp.rows(), tmp.cols());
124 
125  for (Index j = 0; j < tmp.outerSize(); j++) {
126  Index begin = tmp.outerIndexPtr()[j];
127  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
128  result.outerIndexPtr()[j + 1] += end - begin;
129  }
130 
131  std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
132  result.resizeNonZeros(result.nonZeros());
133 
134  for (Index j = 0; j < tmp.outerSize(); j++) {
135  Index begin = tmp.outerIndexPtr()[j];
136  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
137  Index target = result.outerIndexPtr()[j];
138  std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target,
139  [&innerPerm](StorageIndex i) { return innerPerm.indices().coeff(i); });
140  smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
141  }
142  // the inner indices were permuted, and must be sorted
143  result.sortInnerIndices();
144  dst = std::move(result);
145  }
146 
147  template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
148  std::enable_if_t<DoOuter, int> = 0>
149  static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
150  permute_outer(dst, perm, xpr);
151  }
152 
153  template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
154  std::enable_if_t<!DoOuter, int> = 0>
155  static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
156  permute_inner(dst, perm, xpr);
157  }
158 };
159 
160 } // namespace internal
161 
162 namespace internal {
163 
164 template <int ProductTag>
165 struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> {
166  typedef Sparse ret;
167 };
168 template <int ProductTag>
169 struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> {
170  typedef Sparse ret;
171 };
172 
173 // TODO: the following two specializations are only needed to define the right temporary type through
174 // typename permutation_matrix_product<Xpr, Side, false, SparseShape>::ReturnType,
175 // whereas it should be correctly handled by traits<Product<> >::PlainObject
176 
177 template <typename Lhs, typename Rhs, int ProductTag>
178 struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, PermutationShape, SparseShape>
179  : public evaluator<typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType> {
180  typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
181  typedef typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType PlainObject;
182  typedef evaluator<PlainObject> Base;
183 
184  enum { Flags = Base::Flags | EvalBeforeNestingBit };
185 
186  explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
187  internal::construct_at<Base>(this, m_result);
188  generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
189  }
190 
191  protected:
192  PlainObject m_result;
193 };
194 
195 template <typename Lhs, typename Rhs, int ProductTag>
196 struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, SparseShape, PermutationShape>
197  : public evaluator<typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType> {
198  typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
199  typedef typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType PlainObject;
200  typedef evaluator<PlainObject> Base;
201 
202  enum { Flags = Base::Flags | EvalBeforeNestingBit };
203 
204  explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
205  ::new (static_cast<Base*>(this)) Base(m_result);
206  generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
207  }
208 
209  protected:
210  PlainObject m_result;
211 };
212 
213 } // end namespace internal
214 
217 template <typename SparseDerived, typename PermDerived>
218 inline const Product<SparseDerived, PermDerived, AliasFreeProduct> operator*(
219  const SparseMatrixBase<SparseDerived>& matrix, const PermutationBase<PermDerived>& perm) {
220  return Product<SparseDerived, PermDerived, AliasFreeProduct>(matrix.derived(), perm.derived());
221 }
222 
225 template <typename SparseDerived, typename PermDerived>
229 }
230 
233 template <typename SparseDerived, typename PermutationType>
235  const SparseMatrixBase<SparseDerived>& matrix, const InverseImpl<PermutationType, PermutationStorage>& tperm) {
236  return Product<SparseDerived, Inverse<PermutationType>, AliasFreeProduct>(matrix.derived(), tperm.derived());
237 }
238 
241 template <typename SparseDerived, typename PermutationType>
242 inline const Product<Inverse<PermutationType>, SparseDerived, AliasFreeProduct> operator*(
243  const InverseImpl<PermutationType, PermutationStorage>& tperm, const SparseMatrixBase<SparseDerived>& matrix) {
244  return Product<Inverse<PermutationType>, SparseDerived, AliasFreeProduct>(tperm.derived(), matrix.derived());
245 }
246 
247 } // end namespace Eigen
248 
249 #endif // EIGEN_SPARSE_PERMUTATION_H
static constexpr lastp1_t end
Definition: IndexedViewHelper.h:79
const Product< MatrixDerived, PermutationDerived, DefaultProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:474
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:198
Definition: Constants.h:333
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Base class for permutations.
Definition: PermutationMatrix.h:49
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)
Base class of any sparse matrices or sparse expressions.
Definition: ForwardDeclarations.h:481
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:331
constexpr Derived & derived()
Definition: EigenBase.h:49
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:74