// XprHelper, primary template: selected when NeedEval is true, which by default
// means ExpressionType and PlainObjectType are different types. It keeps an
// owned PlainObjectType constructed from the expression and returns it from xpr().
// NOTE(review): the include guard, the #include and the start of the template
// head are fused onto the next line behind stray line numbers, and the
// `struct XprHelper {` head and closing `};` are absent here; as listed this is
// not valid C++ — restore these lines from the upstream header.
10 #ifndef EIGEN_SPARSE_PERMUTATION_H 11 #define EIGEN_SPARSE_PERMUTATION_H 16 #include "./InternalHeaderCheck.h" 22 template <
typename ExpressionType,
typename PlainObjectType,
23 bool NeedEval = !is_same<ExpressionType, PlainObjectType>::value>
// Constructs the owned copy from xpr.
25 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
// Read-only access to the owned copy.
26 inline const PlainObjectType& xpr()
const {
return m_xpr; }
// Owned copy of the expression (held by value, not by reference).
28 const PlainObjectType m_xpr;
30 template <
typename ExpressionType,
typename PlainObjectType>
31 struct XprHelper<ExpressionType, PlainObjectType, false> {
32 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
33 inline const PlainObjectType& xpr()
const {
return m_xpr; }
35 const PlainObjectType& m_xpr;
// PermHelper, primary template: used when NeedInverseEval is true. It evaluates
// perm.inverse() into a concrete PermutationMatrix (`type`) at construction and
// returns it from perm().
// NOTE(review): the `struct PermHelper {` head, the declaration of the m_perm
// member used below, and the closing `};` are not present in this block; the
// member presumably has type `const type` — confirm against upstream.
38 template <
typename PermDerived,
bool NeedInverseEval>
40 using IndicesType =
typename PermDerived::IndicesType;
41 using PermutationIndex =
typename IndicesType::Scalar;
// Permutation matrix with the same compile-time size, maximum size and index
// type as PermDerived's indices.
42 using type = PermutationMatrix<IndicesType::SizeAtCompileTime, IndicesType::MaxSizeAtCompileTime, PermutationIndex>;
// The inverse is computed once, here, and stored.
43 PermHelper(
const PermDerived& perm) : m_perm(perm.
inverse()) {}
// Read-only access to the stored inverse.
44 inline const type& perm()
const {
return m_perm; }
// PermHelper specialization for NeedInverseEval == false: `type` is PermDerived
// itself, and perm() hands back what was stored from the constructor argument.
// NOTE(review): the m_perm member declaration and the closing `};` are not
// present in this block; if m_perm is a reference (presumably `const type&`),
// the permutation passed in must outlive this helper — confirm against upstream.
48 template <
typename PermDerived>
49 struct PermHelper<PermDerived, false> {
50 using type = PermDerived;
51 PermHelper(
const PermDerived& perm) : m_perm(perm) {}
52 inline const type& perm()
const {
return m_perm; }
57 template <
typename ExpressionType,
int S
ide,
bool Transposed>
58 struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape> {
59 using MatrixType =
typename nested_eval<ExpressionType, 1>::type;
60 using MatrixTypeCleaned = remove_all_t<MatrixType>;
62 using Scalar =
typename MatrixTypeCleaned::Scalar;
63 using StorageIndex =
typename MatrixTypeCleaned::StorageIndex;
66 using ReturnType = SparseMatrix<Scalar, MatrixTypeCleaned::IsRowMajor ? RowMajor : ColMajor, StorageIndex>;
67 using TmpHelper = XprHelper<ExpressionType, ReturnType>;
69 static constexpr
bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side ==
OnTheLeft : Side ==
OnTheRight;
70 static constexpr
bool NeedInversePermutation = Transposed ? Side ==
OnTheLeft : Side ==
OnTheRight;
72 template <
typename Dest,
typename PermutationType>
73 static inline void permute_outer(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
77 const TmpHelper tmpHelper(xpr);
78 const ReturnType& tmp = tmpHelper.xpr();
80 ReturnType result(tmp.rows(), tmp.cols());
82 for (
Index j = 0; j < tmp.outerSize(); j++) {
83 Index jp = perm.indices().coeff(j);
84 Index jsrc = NeedInversePermutation ? jp : j;
85 Index jdst = NeedInversePermutation ? j : jp;
86 Index begin = tmp.outerIndexPtr()[jsrc];
87 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
88 result.outerIndexPtr()[jdst + 1] +=
end - begin;
91 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
92 result.resizeNonZeros(result.nonZeros());
94 for (
Index j = 0; j < tmp.outerSize(); j++) {
95 Index jp = perm.indices().coeff(j);
96 Index jsrc = NeedInversePermutation ? jp : j;
97 Index jdst = NeedInversePermutation ? j : jp;
98 Index begin = tmp.outerIndexPtr()[jsrc];
99 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
100 Index target = result.outerIndexPtr()[jdst];
101 smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() +
end, result.innerIndexPtr() + target);
102 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() +
end, result.valuePtr() + target);
104 dst = std::move(result);
107 template <
typename Dest,
typename PermutationType>
108 static inline void permute_inner(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
109 using InnerPermHelper = PermHelper<PermutationType, NeedInversePermutation>;
110 using InnerPermType =
typename InnerPermHelper::type;
115 const TmpHelper tmpHelper(xpr);
116 const ReturnType& tmp = tmpHelper.xpr();
120 const InnerPermHelper permHelper(perm);
121 const InnerPermType& innerPerm = permHelper.perm();
123 ReturnType result(tmp.rows(), tmp.cols());
125 for (
Index j = 0; j < tmp.outerSize(); j++) {
126 Index begin = tmp.outerIndexPtr()[j];
127 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
128 result.outerIndexPtr()[j + 1] +=
end - begin;
131 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
132 result.resizeNonZeros(result.nonZeros());
134 for (
Index j = 0; j < tmp.outerSize(); j++) {
135 Index begin = tmp.outerIndexPtr()[j];
136 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
137 Index target = result.outerIndexPtr()[j];
138 std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() +
end, result.innerIndexPtr() + target,
139 [&innerPerm](StorageIndex i) {
return innerPerm.indices().coeff(i); });
140 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() +
end, result.valuePtr() + target);
143 result.sortInnerIndices();
144 dst = std::move(result);
147 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation,
148 std::enable_if_t<DoOuter, int> = 0>
149 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
150 permute_outer(dst, perm, xpr);
153 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation,
154 std::enable_if_t<!DoOuter, int> = 0>
155 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
156 permute_inner(dst, perm, xpr);
// Storage-type promotion rule for the (Sparse, PermutationStorage) operand
// pair, for every ProductTag.
// NOTE(review): the members of this specialization (presumably a `ret` typedef
// naming the promoted storage) and its closing `};` are not present here —
// confirm against upstream.
164 template <
int ProductTag>
165 struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> {
// Storage-type promotion rule for the (PermutationStorage, Sparse) operand
// pair, for every ProductTag.
// NOTE(review): the members of this specialization (presumably a `ret` typedef
// naming the promoted storage) and its closing `};` are not present here —
// confirm against upstream.
168 template <
int ProductTag>
169 struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> {
// Evaluator for an alias-free product permutation * sparse (Lhs has
// PermutationShape, Rhs has SparseShape). The product is computed eagerly into
// m_result, and this object then acts as the evaluator of that plain sparse
// matrix through its base class.
177 template <
typename Lhs,
typename Rhs,
int ProductTag>
178 struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, PermutationShape, SparseShape>
179 :
public evaluator<typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType> {
180 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
181 typedef typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType PlainObject;
182 typedef evaluator<PlainObject> Base;
// Base is not named in the mem-initializer list, so it is default-initialized
// before m_result exists (bases are initialized before members). Once m_result
// has been created with the product's dimensions, Base is constructed again in
// place over `this`, on m_result.
// NOTE(review): evalTo fills m_result only after Base was built on it; this
// presumes Base refers to m_result rather than copying it — confirm.
186 explicit product_evaluator(
const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
187 internal::construct_at<Base>(
this, m_result);
188 generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
// Holds the evaluated product.
192 PlainObject m_result;
// Evaluator for an alias-free product sparse * permutation (Lhs has
// SparseShape, Rhs has PermutationShape). The product is computed eagerly into
// m_result, and this object then acts as the evaluator of that plain sparse
// matrix through its base class.
195 template <
typename Lhs,
typename Rhs,
int ProductTag>
196 struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, SparseShape, PermutationShape>
197 :
public evaluator<typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType> {
198 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
199 typedef typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType PlainObject;
200 typedef evaluator<PlainObject> Base;
// Base is not named in the mem-initializer list, so it is default-initialized
// before m_result exists (bases are initialized before members). Once m_result
// has been created with the product's dimensions, Base is constructed again in
// place over `this`, on m_result.
// NOTE(review): the permutation * sparse evaluator does the same with
// internal::construct_at<Base>(this, m_result); this one uses placement new —
// consider switching to construct_at for consistency.
204 explicit product_evaluator(
const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
205 ::new (static_cast<Base*>(
this)) Base(m_result);
206 generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
// Holds the evaluated product.
210 PlainObject m_result;
// Product operators combining sparse expressions (SparseDerived) with
// permutations (PermDerived / PermutationType).
// NOTE(review): only the first overload shows part of its declaration: it
// returns Product<SparseDerived, PermDerived, AliasFreeProduct> and is cut off
// at `operator*(`. The remaining entries are bare template heads. Parameter
// lists and bodies must be restored from the upstream header.
217 template <typename SparseDerived, typename PermDerived>
218 inline const
Product<SparseDerived, PermDerived, AliasFreeProduct> operator*(
225 template <
typename SparseDerived,
typename PermDerived>
233 template <
typename SparseDerived,
typename PermutationType>
241 template <
typename SparseDerived,
typename PermutationType>
249 #endif // EIGEN_SPARSE_PERMUTATION_H
Definition: IndexedViewHelper.h:79
const Product< MatrixDerived, PermutationDerived, DefaultProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:474
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:198
Definition: Constants.h:333
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Base class for permutations.
Definition: PermutationMatrix.h:49
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)
Base class of any sparse matrices or sparse expressions.
Definition: ForwardDeclarations.h:481
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:331
constexpr Derived & derived()
Definition: EigenBase.h:49
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:74