$darkmode
Eigen  5.0.1-dev
GeneralMatrixMatrixTriangular.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
11 #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
12 
13 // IWYU pragma: private
14 #include "../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
19 struct selfadjoint_rank1_update;
20 
21 namespace internal {
22 
23 /**********************************************************************
24  * This file implements a general A * B product that evaluates
25  * only one triangular part of the result.
26  * It generalizes the self-adjoint product (C += A A^T) computed
27  * by the level 3 BLAS routine SYRK.
28  **********************************************************************/
29 
30 // forward declaration (defined later in this file)
31 template <typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs,
32  int ResInnerStride, int UpLo>
33 struct tribb_kernel;
34 
35 /* Optimized matrix-matrix product evaluating only one triangular half */
36 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
37  int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride, int UpLo,
38  int Version = Specialized>
39 struct general_matrix_matrix_triangular_product;
40 
41 // as usual if the result is row major => we transpose the product
42 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
43  int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride, int UpLo, int Version>
44 struct general_matrix_matrix_triangular_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
45  RhsStorageOrder, ConjugateRhs, RowMajor, ResInnerStride, UpLo,
46  Version> {
47  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
48  static EIGEN_STRONG_INLINE void run(Index size, Index depth, const LhsScalar* lhs, Index lhsStride,
49  const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr,
50  Index resStride, const ResScalar& alpha,
51  level3_blocking<RhsScalar, LhsScalar>& blocking) {
52  general_matrix_matrix_triangular_product<Index, RhsScalar, RhsStorageOrder == RowMajor ? ColMajor : RowMajor,
53  ConjugateRhs, LhsScalar, LhsStorageOrder == RowMajor ? ColMajor : RowMajor,
54  ConjugateLhs, ColMajor, ResInnerStride,
55  UpLo == Lower ? Upper : Lower>::run(size, depth, rhs, rhsStride, lhs,
56  lhsStride, res, resIncr, resStride,
57  alpha, blocking);
58  }
59 };
60 
61 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
62  int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride, int UpLo, int Version>
63 struct general_matrix_matrix_triangular_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
64  RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride, UpLo,
65  Version> {
66  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
67  static EIGEN_STRONG_INLINE void run(Index size, Index depth, const LhsScalar* lhs_, Index lhsStride,
68  const RhsScalar* rhs_, Index rhsStride, ResScalar* res_, Index resIncr,
69  Index resStride, const ResScalar& alpha,
70  level3_blocking<LhsScalar, RhsScalar>& blocking) {
71  if (size == 0) {
72  return;
73  }
74 
75  typedef gebp_traits<LhsScalar, RhsScalar> Traits;
76 
77  typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
78  typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
79  typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
80  LhsMapper lhs(lhs_, lhsStride);
81  RhsMapper rhs(rhs_, rhsStride);
82  ResMapper res(res_, resStride, resIncr);
83 
84  Index kc = blocking.kc();
85  // Ensure that mc >= nr and <= size
86  Index mc = (std::min)(size, (std::max)(static_cast<decltype(blocking.mc())>(Traits::nr), blocking.mc()));
87 
88  // !!! mc must be a multiple of nr
89  if (mc > Traits::nr) {
90  using UnsignedIndex = typename make_unsigned<Index>::type;
91  mc = (UnsignedIndex(mc) / Traits::nr) * Traits::nr;
92  }
93 
94  std::size_t sizeA = kc * mc;
95  std::size_t sizeB = kc * size;
96 
97  ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
98  ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
99 
100  gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
101  LhsStorageOrder>
102  pack_lhs;
103  gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
104  gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
105  tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo>
106  sybb;
107 
108  for (Index k2 = 0; k2 < depth; k2 += kc) {
109  const Index actual_kc = (std::min)(k2 + kc, depth) - k2;
110 
111  // note that the actual rhs is the transpose/adjoint of mat
112  pack_rhs(blockB, rhs.getSubMapper(k2, 0), actual_kc, size);
113 
114  for (Index i2 = 0; i2 < size; i2 += mc) {
115  const Index actual_mc = (std::min)(i2 + mc, size) - i2;
116 
117  pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
118 
119  // the selected actual_mc * size panel of res is split into three different part:
120  // 1 - before the diagonal => processed with gebp or skipped
121  // 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
122  // 3 - after the diagonal => processed with gebp or skipped
123  if (UpLo == Lower)
124  gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, (std::min)(size, i2), alpha, -1, -1, 0,
125  0);
126 
127  sybb(res_ + resStride * i2 + resIncr * i2, resIncr, resStride, blockA, blockB + actual_kc * i2, actual_mc,
128  actual_kc, alpha);
129 
130  if (UpLo == Upper) {
131  Index j2 = i2 + actual_mc;
132  gebp(res.getSubMapper(i2, j2), blockA, blockB + actual_kc * j2, actual_mc, actual_kc,
133  (std::max)(Index(0), size - j2), alpha, -1, -1, 0, 0);
134  }
135  }
136  }
137  }
138 };
139 
140 // Optimized packed Block * packed Block product kernel evaluating only one given triangular part
141 // This kernel is built on top of the gebp kernel:
142 // - the current destination block is processed per panel of actual_mc x BlockSize
143 // where BlockSize is set to the minimal value allowing gebp to be as fast as possible
144 // - then, as usual, each panel is split into three parts along the diagonal,
145 // the sub blocks above and below the diagonal are processed as usual,
146 // while the triangular block overlapping the diagonal is evaluated into a
147 // small temporary buffer which is then accumulated into the result using a
148 // triangular traversal.
149 template <typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs,
150  int ResInnerStride, int UpLo>
151 struct tribb_kernel {
152  typedef gebp_traits<LhsScalar, RhsScalar, ConjLhs, ConjRhs> Traits;
153  typedef typename Traits::ResScalar ResScalar;
154 
155  enum { BlockSize = meta_least_common_multiple<plain_enum_max(mr, nr), plain_enum_min(mr, nr)>::ret };
156  void operator()(ResScalar* res_, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB,
157  Index size, Index depth, const ResScalar& alpha) {
158  typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
159  typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
160  ResMapper res(res_, resStride, resIncr);
161  gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
162  gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
163 
164  Matrix<ResScalar, BlockSize, BlockSize, ColMajor> buffer;
165 
166  // let's process the block per panel of actual_mc x BlockSize,
167  // again, each is split into three parts, etc.
168  for (Index j = 0; j < size; j += BlockSize) {
169  Index actualBlockSize = std::min<Index>(BlockSize, size - j);
170  const RhsScalar* actual_b = blockB + j * depth;
171 
172  if (UpLo == Upper)
173  gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, -1, -1, 0, 0);
174 
175  // selfadjoint micro block
176  {
177  Index i = j;
178  buffer.setZero();
179  // 1 - apply the kernel on the temporary buffer
180  gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA + depth * i, actual_b, actualBlockSize, depth,
181  actualBlockSize, alpha, -1, -1, 0, 0);
182 
183  // 2 - triangular accumulation
184  for (Index j1 = 0; j1 < actualBlockSize; ++j1) {
185  typename ResMapper::LinearMapper r = res.getLinearMapper(i, j + j1);
186  for (Index i1 = UpLo == Lower ? j1 : 0; UpLo == Lower ? i1 < actualBlockSize : i1 <= j1; ++i1)
187  r(i1) += buffer(i1, j1);
188  }
189  }
190 
191  if (UpLo == Lower) {
192  Index i = j + actualBlockSize;
193  gebp_kernel1(res.getSubMapper(i, j), blockA + depth * i, actual_b, size - i, depth, actualBlockSize, alpha, -1,
194  -1, 0, 0);
195  }
196  }
197  }
198 };
199 
200 } // end namespace internal
201 
202 // high level API
203 
204 template <typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
205 struct general_product_to_triangular_selector;
206 
207 template <typename MatrixType, typename ProductType, int UpLo>
208 struct general_product_to_triangular_selector<MatrixType, ProductType, UpLo, true> {
209  static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) {
210  typedef typename MatrixType::Scalar Scalar;
211 
212  typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
213  typedef internal::blas_traits<Lhs> LhsBlasTraits;
214  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
215  typedef internal::remove_all_t<ActualLhs> ActualLhs_;
216  internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
217 
218  typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
219  typedef internal::blas_traits<Rhs> RhsBlasTraits;
220  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
221  typedef internal::remove_all_t<ActualRhs> ActualRhs_;
222  internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
223 
224  Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
225  RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
226 
227  if (!beta) mat.template triangularView<UpLo>().setZero();
228 
229  enum {
230  StorageOrder = (internal::traits<MatrixType>::Flags & RowMajorBit) ? RowMajor : ColMajor,
231  UseLhsDirectly = ActualLhs_::InnerStrideAtCompileTime == 1,
232  UseRhsDirectly = ActualRhs_::InnerStrideAtCompileTime == 1
233  };
234 
235  internal::gemv_static_vector_if<Scalar, Lhs::SizeAtCompileTime, Lhs::MaxSizeAtCompileTime, !UseLhsDirectly>
236  static_lhs;
237  ei_declare_aligned_stack_constructed_variable(
238  Scalar, actualLhsPtr, actualLhs.size(),
239  (UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data()));
240  if (!UseLhsDirectly) Map<typename ActualLhs_::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
241 
242  internal::gemv_static_vector_if<Scalar, Rhs::SizeAtCompileTime, Rhs::MaxSizeAtCompileTime, !UseRhsDirectly>
243  static_rhs;
244  ei_declare_aligned_stack_constructed_variable(
245  Scalar, actualRhsPtr, actualRhs.size(),
246  (UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data()));
247  if (!UseRhsDirectly) Map<typename ActualRhs_::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
248 
249  selfadjoint_rank1_update<
250  Scalar, Index, StorageOrder, UpLo, LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
251  RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>::run(actualLhs.size(), mat.data(),
252  mat.outerStride(), actualLhsPtr,
253  actualRhsPtr, actualAlpha);
254  }
255 };
256 
257 template <typename MatrixType, typename ProductType, int UpLo>
258 struct general_product_to_triangular_selector<MatrixType, ProductType, UpLo, false> {
259  static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) {
260  typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
261  typedef internal::blas_traits<Lhs> LhsBlasTraits;
262  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
263  typedef internal::remove_all_t<ActualLhs> ActualLhs_;
264  internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
265 
266  typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
267  typedef internal::blas_traits<Rhs> RhsBlasTraits;
268  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
269  typedef internal::remove_all_t<ActualRhs> ActualRhs_;
270  internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
271 
272  typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
273  RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
274 
275  if (!beta) mat.template triangularView<UpLo>().setZero();
276 
277  enum {
278  IsRowMajor = (internal::traits<MatrixType>::Flags & RowMajorBit) ? 1 : 0,
279  LhsIsRowMajor = ActualLhs_::Flags & RowMajorBit ? 1 : 0,
280  RhsIsRowMajor = ActualRhs_::Flags & RowMajorBit ? 1 : 0,
281  SkipDiag = (UpLo & (UnitDiag | ZeroDiag)) != 0
282  };
283 
284  Index size = mat.cols();
285  if (SkipDiag) size--;
286  Index depth = actualLhs.cols();
287 
288  typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor, typename Lhs::Scalar, typename Rhs::Scalar,
289  MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
290  ActualRhs_::MaxColsAtCompileTime>
291  BlockingType;
292 
293  BlockingType blocking(size, size, depth, 1, false);
294 
295  internal::general_matrix_matrix_triangular_product<
296  Index, typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
297  typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
298  IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime,
299  UpLo&(Lower | Upper)>::run(size, depth, &actualLhs.coeffRef(SkipDiag && (UpLo & Lower) == Lower ? 1 : 0, 0),
300  actualLhs.outerStride(),
301  &actualRhs.coeffRef(0, SkipDiag && (UpLo & Upper) == Upper ? 1 : 0),
302  actualRhs.outerStride(),
303  mat.data() +
304  (SkipDiag ? (bool(IsRowMajor) != ((UpLo & Lower) == Lower) ? mat.innerStride()
305  : mat.outerStride())
306  : 0),
307  mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
308  }
309 };
310 
311 template <typename MatrixType_, unsigned int Mode_>
312 template <typename ProductType>
313 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename TriangularViewImpl<MatrixType_, Mode_, Dense>::TriangularViewType&
314 TriangularViewImpl<MatrixType_, Mode_, Dense>::_assignProduct(
315  const ProductType& prod, const typename TriangularViewImpl<MatrixType_, Mode_, Dense>::Scalar& alpha, bool beta) {
316  EIGEN_STATIC_ASSERT((Mode_ & UnitDiag) == 0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
317  eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
318 
319  general_product_to_triangular_selector<MatrixType_, ProductType, Mode_,
320  internal::traits<ProductType>::InnerSize == 1>::run(derived()
321  .nestedExpression()
322  .const_cast_derived(),
323  prod, alpha, beta);
324 
325  return derived();
326 }
327 
328 } // end namespace Eigen
329 
330 #endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
Definition: Constants.h:318
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
const unsigned int RowMajorBit
Definition: Constants.h:70
Definition: Constants.h:211
Definition: Constants.h:215
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:213
Definition: Constants.h:217
Definition: Constants.h:320