$darkmode
Eigen  5.0.1-dev
SparseSparseProductWithPruning.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
21 template <typename Lhs, typename Rhs, typename ResultType>
22 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res,
23  const typename ResultType::RealScalar& tolerance) {
24  // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
25 
26  typedef typename remove_all_t<Rhs>::Scalar RhsScalar;
27  typedef typename remove_all_t<ResultType>::Scalar ResScalar;
28  typedef typename remove_all_t<Lhs>::StorageIndex StorageIndex;
29 
30  // make sure to call innerSize/outerSize since we fake the storage order.
31  Index rows = lhs.innerSize();
32  Index cols = rhs.outerSize();
33  // Index size = lhs.outerSize();
34  eigen_assert(lhs.outerSize() == rhs.innerSize());
35 
36  // allocate a temporary buffer
37  AmbiVector<ResScalar, StorageIndex> tempVector(rows);
38 
39  // mimics a resizeByInnerOuter:
40  if (ResultType::IsRowMajor)
41  res.resize(cols, rows);
42  else
43  res.resize(rows, cols);
44 
45  evaluator<Lhs> lhsEval(lhs);
46  evaluator<Rhs> rhsEval(rhs);
47 
48  // estimate the number of non zero entries
49  // given a rhs column containing Y non zeros, we assume that the respective Y columns
50  // of the lhs differs in average of one non zeros, thus the number of non zeros for
51  // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
52  // per column of the lhs.
53  // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
54  Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
55 
56  res.reserve(estimated_nnz_prod);
57  double ratioColRes = double(estimated_nnz_prod) / (double(lhs.rows()) * double(rhs.cols()));
58  for (Index j = 0; j < cols; ++j) {
59  // FIXME:
60  // double ratioColRes = (double(rhs.innerVector(j).nonZeros()) +
61  // double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
62  // let's do a more accurate determination of the nnz ratio for the current column j of res
63  tempVector.init(ratioColRes);
64  tempVector.setZero();
65  for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) {
66  // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
67  tempVector.restart();
68  RhsScalar x = rhsIt.value();
69  for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) {
70  tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
71  }
72  }
73  res.startVec(j);
74  for (typename AmbiVector<ResScalar, StorageIndex>::Iterator it(tempVector, tolerance); it; ++it)
75  res.insertBackByOuterInner(j, it.index()) = it.value();
76  }
77  res.finalize();
78 }
79 
80 template <typename Lhs, typename Rhs, typename ResultType, int LhsStorageOrder = traits<Lhs>::Flags & RowMajorBit,
81  int RhsStorageOrder = traits<Rhs>::Flags & RowMajorBit,
82  int ResStorageOrder = traits<ResultType>::Flags & RowMajorBit>
83 struct sparse_sparse_product_with_pruning_selector;
84 
85 template <typename Lhs, typename Rhs, typename ResultType>
86 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, ColMajor> {
87  typedef typename ResultType::RealScalar RealScalar;
88 
89  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
90  remove_all_t<ResultType> res_(res.rows(), res.cols());
91  internal::sparse_sparse_product_with_pruning_impl<Lhs, Rhs, ResultType>(lhs, rhs, res_, tolerance);
92  res.swap(res_);
93  }
94 };
95 
96 template <typename Lhs, typename Rhs, typename ResultType>
97 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, RowMajor> {
98  typedef typename ResultType::RealScalar RealScalar;
99  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
100  // we need a col-major matrix to hold the result
101  typedef SparseMatrix<typename ResultType::Scalar, ColMajor, typename ResultType::StorageIndex> SparseTemporaryType;
102  SparseTemporaryType res_(res.rows(), res.cols());
103  internal::sparse_sparse_product_with_pruning_impl<Lhs, Rhs, SparseTemporaryType>(lhs, rhs, res_, tolerance);
104  res = res_;
105  }
106 };
107 
108 template <typename Lhs, typename Rhs, typename ResultType>
109 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, RowMajor> {
110  typedef typename ResultType::RealScalar RealScalar;
111  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
112  // let's transpose the product to get a column x column product
113  remove_all_t<ResultType> res_(res.rows(), res.cols());
114  internal::sparse_sparse_product_with_pruning_impl<Rhs, Lhs, ResultType>(rhs, lhs, res_, tolerance);
115  res.swap(res_);
116  }
117 };
118 
119 template <typename Lhs, typename Rhs, typename ResultType>
120 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, ColMajor> {
121  typedef typename ResultType::RealScalar RealScalar;
122  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
123  typedef SparseMatrix<typename Lhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixLhs;
124  typedef SparseMatrix<typename Rhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixRhs;
125  ColMajorMatrixLhs colLhs(lhs);
126  ColMajorMatrixRhs colRhs(rhs);
127  internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs, ColMajorMatrixRhs, ResultType>(colLhs, colRhs,
128  res, tolerance);
129 
130  // let's transpose the product to get a column x column product
131  // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
132  // SparseTemporaryType res_(res.cols(), res.rows());
133  // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, res_);
134  // res = res_.transpose();
135  }
136 };
137 
138 template <typename Lhs, typename Rhs, typename ResultType>
139 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, RowMajor> {
140  typedef typename ResultType::RealScalar RealScalar;
141  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
142  typedef SparseMatrix<typename Lhs::Scalar, RowMajor, typename Lhs::StorageIndex> RowMajorMatrixLhs;
143  RowMajorMatrixLhs rowLhs(lhs);
144  sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs, Rhs, ResultType, RowMajor, RowMajor>(rowLhs, rhs,
145  res, tolerance);
146  }
147 };
148 
149 template <typename Lhs, typename Rhs, typename ResultType>
150 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, RowMajor> {
151  typedef typename ResultType::RealScalar RealScalar;
152  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
153  typedef SparseMatrix<typename Rhs::Scalar, RowMajor, typename Lhs::StorageIndex> RowMajorMatrixRhs;
154  RowMajorMatrixRhs rowRhs(rhs);
155  sparse_sparse_product_with_pruning_selector<Lhs, RowMajorMatrixRhs, ResultType, RowMajor, RowMajor, RowMajor>(
156  lhs, rowRhs, res, tolerance);
157  }
158 };
159 
160 template <typename Lhs, typename Rhs, typename ResultType>
161 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, ColMajor> {
162  typedef typename ResultType::RealScalar RealScalar;
163  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
164  typedef SparseMatrix<typename Rhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixRhs;
165  ColMajorMatrixRhs colRhs(rhs);
166  internal::sparse_sparse_product_with_pruning_impl<Lhs, ColMajorMatrixRhs, ResultType>(lhs, colRhs, res, tolerance);
167  }
168 };
169 
170 template <typename Lhs, typename Rhs, typename ResultType>
171 struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, ColMajor> {
172  typedef typename ResultType::RealScalar RealScalar;
173  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
174  typedef SparseMatrix<typename Lhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixLhs;
175  ColMajorMatrixLhs colLhs(lhs);
176  internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs, Rhs, ResultType>(colLhs, rhs, res, tolerance);
177  }
178 };
179 
180 } // end namespace internal
181 
182 } // end namespace Eigen
183 
184 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
Definition: Constants.h:318
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
const unsigned int RowMajorBit
Definition: Constants.h:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:320