// Include guard for this header, the internal header check include, and the
// two size-category constants (Large = 2, Small = 3). These constants are the
// values produced by product_size_category and consumed as template arguments
// by the product_type_selector specializations further down.
11 #ifndef EIGEN_GENERAL_PRODUCT_H 12 #define EIGEN_GENERAL_PRODUCT_H 15 #include "./InternalHeaderCheck.h" 19 enum { Large = 2, Small = 3 };
// EIGEN_GEMM_TO_COEFFBASED_THRESHOLD defaults to 20 unless the macro is
// already defined before this point.
//
// Forward declaration of product_type_selector: a class template keyed on
// three integers (Rows, Cols, Depth). Its specializations expose the chosen
// product implementation as the enumerator `ret`.
26 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 28 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20 33 template <
int Rows,
int Cols,
int Depth>
34 struct product_type_selector;
// Classifies a single dimension, given its compile-time size (Size) and its
// compile-time maximum size (MaxSize), into a size category exposed as `value`.
36 template <
int Size,
int MaxSize>
37 struct product_size_category {
// Outside of the GPU compile phase, a dimension is "large" when any of these
// holds:
//   - MaxSize is Dynamic,
//   - Size reaches EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD,
//   - Size is Dynamic and MaxSize reaches that same threshold.
// The definition used during the GPU compile phase is not visible here.
39 #ifndef EIGEN_GPU_COMPILE_PHASE 40 is_large = MaxSize ==
Dynamic || Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD ||
41 (Size ==
Dynamic && MaxSize >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
// `value` is Large when is_large holds; the remaining alternatives of this
// conditional are cut off in this listing.
45 value = is_large ? Large
// Computes, for a product of an Lhs expression by an Rhs expression, which
// product implementation to use. The result is exposed as `value` / `ret`.
// (The struct header line is not visible in this listing; the name
// product_type is the one used by the debug call in the product operator.)
51 template <
typename Lhs,
typename Rhs>
// Operand types with all qualifiers/references stripped by remove_all_t.
53 typedef remove_all_t<Lhs> Lhs_;
54 typedef remove_all_t<Rhs> Rhs_;
// Compile-time shape of the product: rows come from the left operand, columns
// from the right operand.
56 MaxRows = traits<Lhs_>::MaxRowsAtCompileTime,
57 Rows = traits<Lhs_>::RowsAtCompileTime,
58 MaxCols = traits<Rhs_>::MaxColsAtCompileTime,
59 Cols = traits<Rhs_>::ColsAtCompileTime,
// The inner dimension is described by both operands (lhs columns, rhs rows);
// min_size_prefer_fixed combines the two descriptions into one.
60 MaxDepth = min_size_prefer_fixed(traits<Lhs_>::MaxColsAtCompileTime, traits<Rhs_>::MaxRowsAtCompileTime),
61 Depth = min_size_prefer_fixed(traits<Lhs_>::ColsAtCompileTime, traits<Rhs_>::RowsAtCompileTime)
// Size category of each of the three dimensions.
68 rows_select = product_size_category<Rows, MaxRows>::value,
69 cols_select = product_size_category<Cols, MaxCols>::value,
70 depth_select = product_size_category<Depth, MaxDepth>::value
// The three categories index the product_type_selector table.
72 typedef product_type_selector<rows_select, cols_select, depth_select> selector;
// `value` and `ret` are two names for the same selected product mode.
75 enum { value = selector::ret, ret = selector::ret };
// Only when EIGEN_DEBUG_PRODUCT is defined: report the inputs of the selection
// and its result through EIGEN_DEBUG_VAR.
76 #ifdef EIGEN_DEBUG_PRODUCT 78 EIGEN_DEBUG_VAR(Rows);
79 EIGEN_DEBUG_VAR(Cols);
80 EIGEN_DEBUG_VAR(Depth);
81 EIGEN_DEBUG_VAR(rows_select);
82 EIGEN_DEBUG_VAR(cols_select);
83 EIGEN_DEBUG_VAR(depth_select);
84 EIGEN_DEBUG_VAR(value);
// Selection table: each specialization of
// product_type_selector<rows, cols, depth> stores the chosen product
// implementation in `ret`. The arguments are either the literal 1, one of the
// size categories Small / Large, or a free template parameter (M, N, Depth).
// Most `template <...>` headers and closing braces are missing from this
// listing.

// Depth fixed to 1: generic M x N is an outer product; when additionally one
// of the outer dimensions is 1 the lazy coefficient-based mode is used.
93 template <
int M,
int N>
94 struct product_type_selector<M, N, 1> {
95 enum { ret = OuterProduct };
98 struct product_type_selector<M, 1, 1> {
99 enum { ret = LazyCoeffBasedProductMode };
102 struct product_type_selector<1, N, 1> {
103 enum { ret = LazyCoeffBasedProductMode };
// 1 x 1 result: inner product, for any depth (including depth 1).
106 struct product_type_selector<1, 1, Depth> {
107 enum { ret = InnerProduct };
110 struct product_type_selector<1, 1, 1> {
111 enum { ret = InnerProduct };
// All involved dimensions Small (or 1): coefficient-based product.
114 struct product_type_selector<Small, 1, Small> {
115 enum { ret = CoeffBasedProductMode };
118 struct product_type_selector<1, Small, Small> {
119 enum { ret = CoeffBasedProductMode };
122 struct product_type_selector<Small, Small, Small> {
123 enum { ret = CoeffBasedProductMode };
// Depth 1 with Small/Large outer dimensions: lazy coefficient-based product.
126 struct product_type_selector<Small, Small, 1> {
127 enum { ret = LazyCoeffBasedProductMode };
130 struct product_type_selector<Small, Large, 1> {
131 enum { ret = LazyCoeffBasedProductMode };
134 struct product_type_selector<Large, Small, 1> {
135 enum { ret = LazyCoeffBasedProductMode };
// Row-vector result (rows == 1): GemvProduct only when both cols and depth are
// Large, coefficient-based otherwise.
138 struct product_type_selector<1, Large, Small> {
139 enum { ret = CoeffBasedProductMode };
142 struct product_type_selector<1, Large, Large> {
143 enum { ret = GemvProduct };
146 struct product_type_selector<1, Small, Large> {
147 enum { ret = CoeffBasedProductMode };
// Column-vector result (cols == 1): GemvProduct only when both rows and depth
// are Large, coefficient-based otherwise.
150 struct product_type_selector<Large, 1, Small> {
151 enum { ret = CoeffBasedProductMode };
154 struct product_type_selector<Large, 1, Large> {
155 enum { ret = GemvProduct };
158 struct product_type_selector<Small, 1, Large> {
159 enum { ret = CoeffBasedProductMode };
// Matrix result with Large depth: always GemmProduct.
162 struct product_type_selector<Small, Small, Large> {
163 enum { ret = GemmProduct };
166 struct product_type_selector<Large, Small, Large> {
167 enum { ret = GemmProduct };
170 struct product_type_selector<Small, Large, Large> {
171 enum { ret = GemmProduct };
174 struct product_type_selector<Large, Large, Large> {
175 enum { ret = GemmProduct };
// Matrix result with Small depth: coefficient-based unless both outer
// dimensions are Large, in which case GemmProduct is used.
178 struct product_type_selector<Large, Small, Small> {
179 enum { ret = CoeffBasedProductMode };
182 struct product_type_selector<Small, Large, Small> {
183 enum { ret = CoeffBasedProductMode };
186 struct product_type_selector<Large, Large, Small> {
187 enum { ret = GemmProduct };
/**
 * Forward declaration of the dense matrix-vector product kernel selector.
 *
 * \tparam Side            side selector; the specializations in this file use
 *                         OnTheLeft and OnTheRight for this argument.
 * \tparam StorageOrder    storage order the specialization is written for.
 * \tparam BlasCompatible  compile-time flag distinguishing the specializations
 *                         that forward to general_matrix_vector_product from
 *                         the coefficient-based fallbacks.
 *
 * Only the specializations are defined; each one provides a static
 * run(lhs, rhs, dest, alpha) member.
 */
template <int Side, int StorageOrder, bool BlasCompatible>
struct gemv_dense_selector;
// Forward declaration of gemv_static_vector_if: a helper parameterized on the
// scalar type, a compile-time size, a compile-time maximum size, and a boolean
// condition. The specializations that follow are selected on Cond (and on
// MaxSize being Dynamic) and each provide a data() accessor.
227 template <
typename Scalar,
int Size,
int MaxSize,
bool Cond>
228 struct gemv_static_vector_if;
// Specialization for Cond == false. Its data() accessor starts with an
// internal assertion that always fails ("should never be called"), i.e. this
// variant is not meant to have data() invoked. The remainder of the body is
// not visible in this listing.
230 template <
typename Scalar,
int Size,
int MaxSize>
231 struct gemv_static_vector_if<Scalar, Size, MaxSize, false> {
232 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr Scalar* data() {
233 eigen_internal_assert(
false &&
"should never be called");
// Specialization for Cond == true with MaxSize == Dynamic: no static storage
// is held and data() returns a null pointer (0).
238 template <
typename Scalar,
int Size>
239 struct gemv_static_vector_if<Scalar, Size,
Dynamic, true> {
240 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr Scalar* data() {
return 0; }
// Specialization for Cond == true with a non-Dynamic MaxSize: holds a
// plain_array member and data() returns a pointer into it.
243 template <
typename Scalar,
int Size,
int MaxSize>
244 struct gemv_static_vector_if<Scalar, Size, MaxSize, true> {
// Variant used when EIGEN_MAX_STATIC_ALIGN_BYTES != 0: the array is declared
// with the AlignedMax option and data() returns the array itself.
245 #if EIGEN_MAX_STATIC_ALIGN_BYTES != 0 246 internal::plain_array<Scalar, internal::min_size_prefer_fixed(Size, MaxSize), 0, AlignedMax> m_data;
247 EIGEN_STRONG_INLINE constexpr Scalar* data() {
return m_data.array; }
// Alternative variant (the preprocessor line separating the two variants is
// not visible in this listing): the array size argument is increased by
// EIGEN_MAX_ALIGN_BYTES and data() computes its result by clearing the low
// bits of the array address with the mask ~(EIGEN_MAX_ALIGN_BYTES - 1) and
// then adding EIGEN_MAX_ALIGN_BYTES. Assuming EIGEN_MAX_ALIGN_BYTES is a power
// of two (TODO confirm), that is the first EIGEN_MAX_ALIGN_BYTES-aligned
// address strictly above the start of the array.
251 internal::plain_array<Scalar, internal::min_size_prefer_fixed(Size, MaxSize) + EIGEN_MAX_ALIGN_BYTES, 0> m_data;
252 EIGEN_STRONG_INLINE constexpr Scalar* data() {
253 return reinterpret_cast<Scalar*
>((std::uintptr_t(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES - 1))) +
254 EIGEN_MAX_ALIGN_BYTES);
// Specialization for Side == OnTheLeft, for any storage order and either value
// of BlasCompatible. It does no computation of its own: it reduces the problem
// to the OnTheRight selector by transposition.
260 template <
int StorageOrder,
bool BlasCompatible>
261 struct gemv_dense_selector<
OnTheLeft, StorageOrder, BlasCompatible> {
// run(lhs, rhs, dest, alpha): same parameter list as the other selectors.
262 template <
typename Lhs,
typename Rhs,
typename Dest>
263 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
// Transposed view of the destination, written to by the forwarded call.
264 Transpose<Dest> destT(dest);
// Forward with the operands transposed and swapped (rhs^T becomes the left
// operand, lhs^T the right one). OtherStorageOrder is defined on a line that
// is not visible here.
// NOTE(review): presumably the opposite of StorageOrder -- confirm.
// The end of this call is cut off in this listing.
266 gemv_dense_selector<OnTheRight, OtherStorageOrder, BlasCompatible>::run(rhs.transpose(), lhs.transpose(), destT,
// Matrix-vector kernel entry point that forwards to
// general_matrix_vector_product with a ColMajor left-hand-side mapper.
// The enclosing struct header is not visible in this listing.
// NOTE(review): presumably the OnTheRight / ColMajor / BLAS-compatible
// specialization of gemv_dense_selector -- confirm.
//
// lhs, rhs : product operands (scalar factors and conjugation are unwrapped
//            through blas_traits below).
// dest     : destination the kernel result is added into.
// alpha    : scaling factor, combined with the operands' own scalar factors.
273 template <
typename Lhs,
typename Rhs,
typename Dest>
274 static inline void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
275 typedef typename Lhs::Scalar LhsScalar;
276 typedef typename Rhs::Scalar RhsScalar;
277 typedef typename Dest::Scalar ResScalar;
279 typedef internal::blas_traits<Lhs> LhsBlasTraits;
280 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
281 typedef internal::blas_traits<Rhs> RhsBlasTraits;
282 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
// Map type used below (under the name MappedDest) to view a raw ResScalar
// buffer as a dynamic-size column vector; the end of this typedef is cut off.
284 typedef Map<Matrix<ResScalar, Dynamic, 1>, plain_enum_min(AlignedMax, internal::packet_traits<ResScalar>::size)>
// Operands as extracted by blas_traits.
287 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
288 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
// Overall scalar factor built from alpha and both operands.
290 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// Destination viewed as a vector: Dest itself when it is a compile-time
// vector, otherwise its column expression type.
293 typedef std::conditional_t<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr> ActualDest;
// EvalToDestAtCompileTime: the destination has unit inner stride.
// ComplexByReal: complex lhs scalar combined with a real rhs scalar.
// MightCannotUseDest: writing straight into dest may be impossible (non-unit
// stride or the complex-by-real case), and dest is not statically empty.
298 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime == 1),
299 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
300 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime != 0)
303 typedef const_blas_data_mapper<LhsScalar, Index, ColMajor> LhsMapper;
304 typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;
// alpha converted to the rhs scalar type, which is what the kernel receives.
305 RhsScalar compatibleAlpha = get_factor<ResScalar, RhsScalar>::run(actualAlpha);
// Fast path: the kernel writes directly into dest.data() with result stride 1.
307 if (!MightCannotUseDest) {
310 general_matrix_vector_product<
Index, LhsScalar, LhsMapper,
ColMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
311 RhsMapper, RhsBlasTraits::NeedToConjugate>::run(actualLhs.rows(), actualLhs.cols(),
312 LhsMapper(actualLhs.data(),
313 actualLhs.outerStride()),
314 RhsMapper(actualRhs.data(),
315 actualRhs.innerStride()),
316 dest.data(), 1, compatibleAlpha);
// General path (the line opening this branch is not visible in this listing).
// Static scratch storage for the destination; the end of this declaration is
// cut off, the object is used below as static_dest.
318 gemv_static_vector_if<ResScalar, ActualDest::SizeAtCompileTime, ActualDest::MaxSizeAtCompileTime,
// alpha can be handed to the kernel as an RhsScalar unless we are in the
// complex-by-real case and alpha has a non-zero imaginary part.
322 const bool alphaIsCompatible = (!ComplexByReal) || (numext::is_exactly_zero(numext::imag(actualAlpha)));
// Write directly into dest only if the stride allows it and alpha is usable.
323 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
// Declares actualDestPtr, a buffer of dest.size() ResScalar values; the last
// argument supplies dest's own data or the static scratch buffer.
// NOTE(review): presumably the macro allocates only when that pointer is
// null -- confirm against the macro definition.
325 ei_declare_aligned_stack_constructed_variable(ResScalar, actualDestPtr, dest.size(),
326 evalToDest ? dest.data() : static_dest.data());
// Optional user plugin hook; Size and size are defined only for its benefit.
329 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 330 constexpr
int Size = Dest::SizeAtCompileTime;
331 Index size = dest.size();
332 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Incompatible alpha: start from a zeroed buffer and let the kernel compute
// the unscaled product (factor 1); alpha is applied afterwards.
334 if (!alphaIsCompatible) {
335 MappedDest(actualDestPtr, dest.size()).setZero();
336 compatibleAlpha = RhsScalar(1);
// Compatible alpha: seed the buffer with the current contents of dest so the
// kernel accumulates on top of them (the `else` line is not visible here).
338 MappedDest(actualDestPtr, dest.size()) = dest;
// Same kernel call as the fast path, targeting actualDestPtr instead.
341 general_matrix_vector_product<
Index, LhsScalar, LhsMapper,
ColMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
342 RhsMapper, RhsBlasTraits::NeedToConjugate>::run(actualLhs.rows(), actualLhs.cols(),
343 LhsMapper(actualLhs.data(),
344 actualLhs.outerStride()),
345 RhsMapper(actualRhs.data(),
346 actualRhs.innerStride()),
347 actualDestPtr, 1, compatibleAlpha);
// Write-back from the buffer (the guarding line is not visible in this
// listing): with an incompatible alpha the unscaled product is scaled by
// actualAlpha and added to dest; otherwise the buffer simply replaces dest.
350 if (!alphaIsCompatible)
351 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
353 dest = MappedDest(actualDestPtr, dest.size());
// Matrix-vector kernel entry point that forwards to
// general_matrix_vector_product with a RowMajor left-hand-side mapper.
// The enclosing struct header is not visible in this listing.
// NOTE(review): presumably the OnTheRight / RowMajor / BLAS-compatible
// specialization of gemv_dense_selector -- confirm.
//
// lhs, rhs : product operands, unwrapped through blas_traits below.
// dest     : destination; the kernel writes through dest.data() using the
//            inner stride of dest's first column.
// alpha    : scaling factor, combined with the operands' own scalar factors.
361 template <
typename Lhs,
typename Rhs,
typename Dest>
362 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
363 typedef typename Lhs::Scalar LhsScalar;
364 typedef typename Rhs::Scalar RhsScalar;
365 typedef typename Dest::Scalar ResScalar;
367 typedef internal::blas_traits<Lhs> LhsBlasTraits;
368 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
369 typedef internal::blas_traits<Rhs> RhsBlasTraits;
370 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
371 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
// Operands as extracted by blas_traits, held as const.
373 std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
374 std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
// Overall scalar factor built from alpha and both operands.
376 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// Condition under which the rhs coefficients can be read in place: unit inner
// stride, or a statically empty rhs. The flag is used below as DirectlyUseRhs;
// the surrounding enum lines are not visible in this listing.
382 ActualRhsTypeCleaned::InnerStrideAtCompileTime == 1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime == 0
// Static scratch storage for a packed copy of the rhs, enabled only when the
// rhs cannot be used directly; used below as static_rhs (declaration cut off).
385 gemv_static_vector_if<RhsScalar, ActualRhsTypeCleaned::SizeAtCompileTime,
386 ActualRhsTypeCleaned::MaxSizeAtCompileTime, !DirectlyUseRhs>
// Declares actualRhsPtr: either the rhs's own data (const cast away) or the
// static scratch buffer.
// NOTE(review): presumably the macro allocates only when that pointer is
// null -- confirm against the macro definition.
389 ei_declare_aligned_stack_constructed_variable(
390 RhsScalar, actualRhsPtr, actualRhs.size(),
391 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
// When the rhs is not directly usable, copy it into the contiguous buffer.
393 if (!DirectlyUseRhs) {
// Optional user plugin hook; Size and size are defined only for its benefit.
394 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 395 constexpr
int Size = ActualRhsTypeCleaned::SizeAtCompileTime;
396 Index size = actualRhs.size();
397 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
399 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
402 typedef const_blas_data_mapper<LhsScalar, Index, RowMajor> LhsMapper;
403 typedef const_blas_data_mapper<RhsScalar, Index, ColMajor> RhsMapper;
// Kernel call: the rhs is read from actualRhsPtr with stride 1. The trailing
// arguments of this call are cut off in this listing.
404 general_matrix_vector_product<
Index, LhsScalar, LhsMapper,
RowMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
405 RhsMapper, RhsBlasTraits::NeedToConjugate>::
406 run(actualLhs.rows(), actualLhs.cols(), LhsMapper(actualLhs.data(), actualLhs.outerStride()),
407 RhsMapper(actualRhsPtr, 1), dest.data(),
408 dest.col(0).innerStride(),
// Coefficient-based matrix-vector product that accumulates one column of lhs
// at a time: dest += (alpha * rhs[k]) * lhs.col(k) for every k.
// The enclosing struct header is not visible in this listing.
// NOTE(review): presumably the OnTheRight / ColMajor / non-BLAS-compatible
// specialization of gemv_dense_selector -- confirm.
416 template <
typename Lhs,
typename Rhs,
typename Dest>
417 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
// Compile-time check that lhs does not require evaluation into a temporary.
418 EIGEN_STATIC_ASSERT((!nested_eval<Lhs, 1>::Evaluate),
419 EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
// rhs in the form chosen by nested_eval for a single access per coefficient.
422 typename nested_eval<Rhs, 1>::type actual_rhs(rhs);
// Loop over the rhs rows; each iteration adds a scaled lhs column to dest.
423 const Index size = rhs.rows();
424 for (
Index k = 0; k < size; ++k) dest += (alpha * actual_rhs.coeff(k)) * lhs.col(k);
// Coefficient-based matrix-vector product that computes one destination
// coefficient at a time: dest[i] += alpha * sum(lhs.row(i) .* rhs^T).
// The enclosing struct header is not visible in this listing.
// NOTE(review): presumably the OnTheRight / RowMajor / non-BLAS-compatible
// specialization of gemv_dense_selector -- confirm.
430 template <
typename Lhs,
typename Rhs,
typename Dest>
431 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
// Compile-time check that lhs does not require evaluation into a temporary.
432 EIGEN_STATIC_ASSERT((!nested_eval<Lhs, 1>::Evaluate),
433 EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
// rhs in the form chosen by nested_eval given that each of its coefficients is
// read Lhs::RowsAtCompileTime times (once per destination row).
434 typename nested_eval<Rhs, Lhs::RowsAtCompileTime>::type actual_rhs(rhs);
// One dot product per destination row, scaled by alpha and accumulated.
435 const Index rows = dest.rows();
436 for (
Index i = 0; i < rows; ++i)
437 dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(actual_rhs.transpose())).sum();
// Out-of-class definition of a member template of a class templated on
// Derived, taking another expression type OtherDerived. The signature line is
// not visible in this listing.
// NOTE(review): presumably MatrixBase<Derived>::operator* -- confirm.
// The visible part performs compile-time validation of the operand shapes.
453 template <
typename Derived>
454 template <
typename OtherDerived>
// ProductIsValid: the inner dimensions agree at compile time, or at least one
// of them is Dynamic (so the check is deferred).
462 ProductIsValid = Derived::ColsAtCompileTime ==
Dynamic || OtherDerived::RowsAtCompileTime ==
Dynamic ||
463 int(Derived::ColsAtCompileTime) == int(OtherDerived::RowsAtCompileTime),
// AreVectors: both operands are compile-time vectors.
// SameSizes: both operands have the same compile-time matrix size.
464 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
465 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived, OtherDerived)
// Three static assertions with increasingly generic messages:
//  1. two same-sized vectors with mismatched inner dimensions: the user
//     probably wanted a dot or coefficient-wise product;
//  2. two same-sized non-vector matrices with mismatched inner dimensions:
//     the user probably wanted a coefficient-wise product;
//  3. any other invalid combination: generic invalid-product error.
// (The opening of the first assertion is cut off in this listing.)
471 ProductIsValid || !(AreVectors && SameSizes),
472 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
473 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
474 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
475 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
// Only when EIGEN_DEBUG_PRODUCT is defined: call the product_type debug hook
// for this operand pair.
476 #ifdef EIGEN_DEBUG_PRODUCT 477 internal::product_type<Derived, OtherDerived>::debug();
// Out-of-class definition of a member template of a class templated on
// Derived, taking another expression type OtherDerived. The signature line is
// not visible in this listing.
// NOTE(review): presumably MatrixBase<Derived>::lazyProduct -- confirm.
// The visible part performs the same compile-time shape validation as the
// product operator defined just before it.
494 template <
typename Derived>
495 template <
typename OtherDerived>
// ProductIsValid: the inner dimensions agree at compile time, or at least one
// of them is Dynamic (so the check is deferred).
499 ProductIsValid = Derived::ColsAtCompileTime ==
Dynamic || OtherDerived::RowsAtCompileTime ==
Dynamic ||
500 int(Derived::ColsAtCompileTime) == int(OtherDerived::RowsAtCompileTime),
// AreVectors: both operands are compile-time vectors.
// SameSizes: both operands have the same compile-time matrix size.
501 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
502 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived, OtherDerived)
// Static assertions, from the most specific diagnostic (vector-vector misuse)
// to the generic invalid-product error. The opening of the first assertion is
// cut off in this listing.
508 ProductIsValid || !(AreVectors && SameSizes),
509 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
510 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
511 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
512 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
519 #endif // EIGEN_GENERAL_PRODUCT_H constexpr Derived & derived()
Definition: EigenBase.h:49
Definition: Constants.h:318
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:198
Definition: Constants.h:333
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
const Product< Derived, OtherDerived > operator*(const MatrixBase< OtherDerived > &other) const
Definition: GeneralProduct.h:455
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:331
Definition: Constants.h:320
const int Dynamic
Definition: Constants.h:25
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
const Product< Derived, OtherDerived, LazyProduct > lazyProduct(const MatrixBase< OtherDerived > &other) const
Definition: GeneralProduct.h:497