$darkmode
Eigen-unsupported  5.0.1-dev
KroneckerTensorProduct.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
5 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
6 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef KRONECKER_TENSOR_PRODUCT_H
13 #define KRONECKER_TENSOR_PRODUCT_H
14 
15 // IWYU pragma: private
16 #include "./InternalHeaderCheck.h"
17 
18 namespace Eigen {
19 
27 template <typename Derived>
28 class KroneckerProductBase : public ReturnByValue<Derived> {
29  private:
30  typedef typename internal::traits<Derived> Traits;
31  typedef typename Traits::Scalar Scalar;
32 
33  protected:
34  typedef typename Traits::Lhs Lhs;
35  typedef typename Traits::Rhs Rhs;
36 
37  public:
39  KroneckerProductBase(const Lhs& A, const Rhs& B) : m_A(A), m_B(B) {}
40 
41  inline Index rows() const { return m_A.rows() * m_B.rows(); }
42  inline Index cols() const { return m_A.cols() * m_B.cols(); }
43 
48  Scalar coeff(Index row, Index col) const {
49  return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * m_B.coeff(row % m_B.rows(), col % m_B.cols());
50  }
51 
56  Scalar coeff(Index i) const {
57  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
58  return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
59  }
60 
61  protected:
62  typename Lhs::Nested m_A;
63  typename Rhs::Nested m_B;
64 };
65 
78 template <typename Lhs, typename Rhs>
79 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs, Rhs> > {
80  private:
82  using Base::m_A;
83  using Base::m_B;
84 
85  public:
87  KroneckerProduct(const Lhs& A, const Rhs& B) : Base(A, B) {}
88 
90  template <typename Dest>
91  void evalTo(Dest& dst) const;
92 };
93 
109 template <typename Lhs, typename Rhs>
110 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs, Rhs> > {
111  private:
113  using Base::m_A;
114  using Base::m_B;
115 
116  public:
118  KroneckerProductSparse(const Lhs& A, const Rhs& B) : Base(A, B) {}
119 
121  template <typename Dest>
122  void evalTo(Dest& dst) const;
123 };
124 
125 template <typename Lhs, typename Rhs>
126 template <typename Dest>
127 void KroneckerProduct<Lhs, Rhs>::evalTo(Dest& dst) const {
128  const int BlockRows = Rhs::RowsAtCompileTime, BlockCols = Rhs::ColsAtCompileTime;
129  const Index Br = m_B.rows(), Bc = m_B.cols();
130  for (Index i = 0; i < m_A.rows(); ++i)
131  for (Index j = 0; j < m_A.cols(); ++j)
132  Block<Dest, BlockRows, BlockCols>(dst, i * Br, j * Bc, Br, Bc) = m_A.coeff(i, j) * m_B;
133 }
134 
135 template <typename Lhs, typename Rhs>
136 template <typename Dest>
138  Index Br = m_B.rows(), Bc = m_B.cols();
139  dst.resize(this->rows(), this->cols());
140  dst.resizeNonZeros(0);
141 
142  // 1 - evaluate the operands if needed:
143  typedef typename internal::nested_eval<Lhs, Dynamic>::type Lhs1;
144  typedef internal::remove_all_t<Lhs1> Lhs1Cleaned;
145  const Lhs1 lhs1(m_A);
146  typedef typename internal::nested_eval<Rhs, Dynamic>::type Rhs1;
147  typedef internal::remove_all_t<Rhs1> Rhs1Cleaned;
148  const Rhs1 rhs1(m_B);
149 
150  // 2 - construct respective iterators
151  typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
152  typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
153 
154  // compute number of non-zeros per innervectors of dst
155  {
156  // TODO VectorXi is not necessarily big enough!
157  VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
158  for (Index kA = 0; kA < m_A.outerSize(); ++kA)
159  for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
160 
161  VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
162  for (Index kB = 0; kB < m_B.outerSize(); ++kB)
163  for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
164 
165  Matrix<int, Dynamic, Dynamic, ColMajor> nnzAB = nnzB * nnzA.transpose();
166  dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
167  }
168 
169  for (Index kA = 0; kA < m_A.outerSize(); ++kA) {
170  for (Index kB = 0; kB < m_B.outerSize(); ++kB) {
171  for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) {
172  for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) {
173  Index i = itA.row() * Br + itB.row(), j = itA.col() * Bc + itB.col();
174  dst.insert(i, j) = itA.value() * itB.value();
175  }
176  }
177  }
178  }
179 }
180 
181 namespace internal {
182 
183 template <typename Lhs_, typename Rhs_>
184 struct traits<KroneckerProduct<Lhs_, Rhs_> > {
185  typedef remove_all_t<Lhs_> Lhs;
186  typedef remove_all_t<Rhs_> Rhs;
188  typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
189 
190  enum {
191  Rows = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
192  Cols = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
193  MaxRows = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
194  MaxCols = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime)
195  };
196 
197  typedef Matrix<Scalar, Rows, Cols> ReturnType;
198 };
199 
200 template <typename Lhs_, typename Rhs_>
201 struct traits<KroneckerProductSparse<Lhs_, Rhs_> > {
202  typedef MatrixXpr XprKind;
203  typedef remove_all_t<Lhs_> Lhs;
204  typedef remove_all_t<Rhs_> Rhs;
205  typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
206  typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind,
207  scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret
208  StorageKind;
209  typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
210 
211  enum {
212  LhsFlags = Lhs::Flags,
213  RhsFlags = Rhs::Flags,
214 
215  RowsAtCompileTime = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
216  ColsAtCompileTime = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
217  MaxRowsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
218  MaxColsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
219 
220  EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) & RowMajorBit),
221  RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
222 
223  Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits) | EvalBeforeNestingBit,
224  CoeffReadCost = HugeCost
225  };
226 
227  typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
228 };
229 
230 } // end namespace internal
231 
251 template <typename A, typename B>
253  return KroneckerProduct<A, B>(a.derived(), b.derived());
254 }
255 
277 template <typename A, typename B>
279  return KroneckerProductSparse<A, B>(a.derived(), b.derived());
280 }
281 
282 } // end namespace Eigen
283 
284 #endif // KRONECKER_TENSOR_PRODUCT_H
const int HugeCost
Scalar coeff(Index row, Index col) const
Definition: KroneckerTensorProduct.h:48
Namespace containing all symbols from the Eigen library.
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition: KroneckerTensorProduct.h:127
const unsigned int RowMajorBit
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:118
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:39
constexpr Derived & derived()
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:87
const unsigned int EvalBeforeNestingBit
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition: KroneckerTensorProduct.h:137
Kronecker tensor product helper class for dense matrices.
Definition: KroneckerTensorProduct.h:79
The base class of dense and sparse Kronecker product.
Definition: KroneckerTensorProduct.h:28
Kronecker tensor product helper class for sparse matrices.
Definition: KroneckerTensorProduct.h:110
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
Definition: KroneckerTensorProduct.h:252
Scalar coeff(Index i) const
Definition: KroneckerTensorProduct.h:56