$darkmode
Eigen  5.0.1-dev
BlasUtil.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_BLASUTIL_H
11 #define EIGEN_BLASUTIL_H
12 
13 // This file contains many lightweight helper classes used to
14 // implement and control fast level 2 and level 3 BLAS-like routines.
15 
16 // IWYU pragma: private
17 #include "../InternalHeaderCheck.h"
18 
19 namespace Eigen {
20 
21 namespace internal {
22 
23 // forward declarations
24 template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
25  bool ConjugateLhs = false, bool ConjugateRhs = false>
26 struct gebp_kernel;
27 
28 template <typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false,
29  bool PanelMode = false>
30 struct gemm_pack_rhs;
31 
32 template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, int StorageOrder,
33  bool Conjugate = false, bool PanelMode = false>
34 struct gemm_pack_lhs;
35 
36 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
37  int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride>
38 struct general_matrix_matrix_product;
39 
40 template <typename Index, typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
41  typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version = Specialized>
42 struct general_matrix_vector_product;
43 
44 template <typename From, typename To>
45 struct get_factor {
46  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
47 };
48 
49 template <typename Scalar>
50 struct get_factor<Scalar, typename NumTraits<Scalar>::Real> {
51  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) {
52  return numext::real(x);
53  }
54 };
55 
56 template <typename Scalar, typename Index>
57 class BlasVectorMapper {
58  public:
59  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar* data) : m_data(data) {}
60 
61  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { return m_data[i]; }
62  template <typename Packet, int AlignmentType>
63  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
64  return ploadt<Packet, AlignmentType>(m_data + i);
65  }
66 
67  template <typename Packet>
68  EIGEN_DEVICE_FUNC bool aligned(Index i) const {
69  return (std::uintptr_t(m_data + i) % sizeof(Packet)) == 0;
70  }
71 
72  protected:
73  Scalar* m_data;
74 };
75 
76 template <typename Scalar, typename Index, int AlignmentType, int Incr = 1>
77 class BlasLinearMapper;
78 
79 template <typename Scalar, typename Index, int AlignmentType>
80 class BlasLinearMapper<Scalar, Index, AlignmentType> {
81  public:
82  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar* data, Index incr = 1) : m_data(data) {
83  EIGEN_ONLY_USED_FOR_DEBUG(incr);
84  eigen_assert(incr == 1);
85  }
86 
87  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i) const { internal::prefetch(&operator()(i)); }
88 
89  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { return m_data[i]; }
90 
91  template <typename PacketType>
92  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
93  return ploadt<PacketType, AlignmentType>(m_data + i);
94  }
95 
96  template <typename PacketType>
97  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index n, Index offset = 0) const {
98  return ploadt_partial<PacketType, AlignmentType>(m_data + i, n, offset);
99  }
100 
101  template <typename PacketType, int AlignmentT>
102  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType load(Index i) const {
103  return ploadt<PacketType, AlignmentT>(m_data + i);
104  }
105 
106  template <typename PacketType>
107  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType& p) const {
108  pstoret<Scalar, PacketType, AlignmentType>(m_data + i, p);
109  }
110 
111  template <typename PacketType>
112  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, const PacketType& p, Index n,
113  Index offset = 0) const {
114  pstoret_partial<Scalar, PacketType, AlignmentType>(m_data + i, p, n, offset);
115  }
116 
117  protected:
118  Scalar* m_data;
119 };
120 
121 // Lightweight helper class to access matrix coefficients.
122 template <typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
123 class blas_data_mapper;
124 
125 // TMP to help PacketBlock store implementation.
126 // There's currently no known use case for PacketBlock load.
127 // The default implementation assumes ColMajor order.
128 // It always store each packet sequentially one `stride` apart.
129 template <typename Index, typename Scalar, typename Packet, int n, int idx, int StorageOrder>
130 struct PacketBlockManagement {
131  PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, StorageOrder> pbm;
132  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
133  const PacketBlock<Packet, n>& block) const {
134  pbm.store(to, stride, i, j, block);
135  pstoreu<Scalar>(to + i + (j + idx) * stride, block.packet[idx]);
136  }
137 };
138 
139 // PacketBlockManagement specialization to take care of RowMajor order without ifs.
140 template <typename Index, typename Scalar, typename Packet, int n, int idx>
141 struct PacketBlockManagement<Index, Scalar, Packet, n, idx, RowMajor> {
142  PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, RowMajor> pbm;
143  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
144  const PacketBlock<Packet, n>& block) const {
145  pbm.store(to, stride, i, j, block);
146  pstoreu<Scalar>(to + j + (i + idx) * stride, block.packet[idx]);
147  }
148 };
149 
150 template <typename Index, typename Scalar, typename Packet, int n, int StorageOrder>
151 struct PacketBlockManagement<Index, Scalar, Packet, n, -1, StorageOrder> {
152  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
153  const PacketBlock<Packet, n>& block) const {
154  EIGEN_UNUSED_VARIABLE(to);
155  EIGEN_UNUSED_VARIABLE(stride);
156  EIGEN_UNUSED_VARIABLE(i);
157  EIGEN_UNUSED_VARIABLE(j);
158  EIGEN_UNUSED_VARIABLE(block);
159  }
160 };
161 
162 template <typename Index, typename Scalar, typename Packet, int n>
163 struct PacketBlockManagement<Index, Scalar, Packet, n, -1, RowMajor> {
164  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
165  const PacketBlock<Packet, n>& block) const {
166  EIGEN_UNUSED_VARIABLE(to);
167  EIGEN_UNUSED_VARIABLE(stride);
168  EIGEN_UNUSED_VARIABLE(i);
169  EIGEN_UNUSED_VARIABLE(j);
170  EIGEN_UNUSED_VARIABLE(block);
171  }
172 };
173 
174 template <typename Scalar, typename Index, int StorageOrder, int AlignmentType>
175 class blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, 1> {
176  public:
177  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
178  typedef blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> SubMapper;
179  typedef BlasVectorMapper<Scalar, Index> VectorMapper;
180 
181  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr = 1)
182  : m_data(data), m_stride(stride) {
183  EIGEN_ONLY_USED_FOR_DEBUG(incr);
184  eigen_assert(incr == 1);
185  }
186 
187  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
188  return SubMapper(&operator()(i, j), m_stride);
189  }
190 
191  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
192  return LinearMapper(&operator()(i, j));
193  }
194 
195  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
196  return VectorMapper(&operator()(i, j));
197  }
198 
199  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i, Index j) const { internal::prefetch(&operator()(i, j)); }
200 
201  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
202  return m_data[StorageOrder == RowMajor ? j + i * m_stride : i + j * m_stride];
203  }
204 
205  template <typename PacketType>
206  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
207  return ploadt<PacketType, AlignmentType>(&operator()(i, j));
208  }
209 
210  template <typename PacketType>
211  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index j, Index n,
212  Index offset = 0) const {
213  return ploadt_partial<PacketType, AlignmentType>(&operator()(i, j), n, offset);
214  }
215 
216  template <typename PacketT, int AlignmentT>
217  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
218  return ploadt<PacketT, AlignmentT>(&operator()(i, j));
219  }
220 
221  template <typename PacketType>
222  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType& p) const {
223  pstoret<Scalar, PacketType, AlignmentType>(&operator()(i, j), p);
224  }
225 
226  template <typename PacketType>
227  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, Index j, const PacketType& p, Index n,
228  Index offset = 0) const {
229  pstoret_partial<Scalar, PacketType, AlignmentType>(&operator()(i, j), p, n, offset);
230  }
231 
232  template <typename SubPacket>
233  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket& p) const {
234  pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
235  }
236 
237  template <typename SubPacket>
238  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
239  return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
240  }
241 
242  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
243  EIGEN_DEVICE_FUNC const Index incr() const { return 1; }
244  EIGEN_DEVICE_FUNC constexpr const Scalar* data() const { return m_data; }
245 
246  EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
247  if (std::uintptr_t(m_data) % sizeof(Scalar)) {
248  return -1;
249  }
250  return internal::first_default_aligned(m_data, size);
251  }
252 
253  template <typename SubPacket, int n>
254  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j,
255  const PacketBlock<SubPacket, n>& block) const {
256  PacketBlockManagement<Index, Scalar, SubPacket, n, n - 1, StorageOrder> pbm;
257  pbm.store(m_data, m_stride, i, j, block);
258  }
259 
260  protected:
261  Scalar* EIGEN_RESTRICT m_data;
262  const Index m_stride;
263 };
264 
265 // Implementation of non-natural increment (i.e. inner-stride != 1)
266 // The exposed API is not complete yet compared to the Incr==1 case
267 // because some features makes less sense in this case.
268 template <typename Scalar, typename Index, int AlignmentType, int Incr>
269 class BlasLinearMapper {
270  public:
271  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar* data, Index incr) : m_data(data), m_incr(incr) {}
272 
273  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { internal::prefetch(&operator()(i)); }
274 
275  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { return m_data[i * m_incr.value()]; }
276 
277  template <typename PacketType>
278  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
279  return pgather<Scalar, PacketType>(m_data + i * m_incr.value(), m_incr.value());
280  }
281 
282  template <typename PacketType>
283  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index n, Index /*offset*/ = 0) const {
284  return pgather_partial<Scalar, PacketType>(m_data + i * m_incr.value(), m_incr.value(), n);
285  }
286 
287  template <typename PacketType>
288  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType& p) const {
289  pscatter<Scalar, PacketType>(m_data + i * m_incr.value(), p, m_incr.value());
290  }
291 
292  template <typename PacketType>
293  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, const PacketType& p, Index n,
294  Index /*offset*/ = 0) const {
295  pscatter_partial<Scalar, PacketType>(m_data + i * m_incr.value(), p, m_incr.value(), n);
296  }
297 
298  protected:
299  Scalar* m_data;
300  const internal::variable_if_dynamic<Index, Incr> m_incr;
301 };
302 
303 template <typename Scalar, typename Index, int StorageOrder, int AlignmentType, int Incr>
304 class blas_data_mapper {
305  public:
306  typedef BlasLinearMapper<Scalar, Index, AlignmentType, Incr> LinearMapper;
307  typedef blas_data_mapper SubMapper;
308 
309  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr)
310  : m_data(data), m_stride(stride), m_incr(incr) {}
311 
312  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
313  return SubMapper(&operator()(i, j), m_stride, m_incr.value());
314  }
315 
316  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
317  return LinearMapper(&operator()(i, j), m_incr.value());
318  }
319 
320  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i, Index j) const { internal::prefetch(&operator()(i, j)); }
321 
322  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
323  return m_data[StorageOrder == RowMajor ? j * m_incr.value() + i * m_stride : i * m_incr.value() + j * m_stride];
324  }
325 
326  template <typename PacketType>
327  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
328  return pgather<Scalar, PacketType>(&operator()(i, j), m_incr.value());
329  }
330 
331  template <typename PacketType>
332  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index j, Index n,
333  Index /*offset*/ = 0) const {
334  return pgather_partial<Scalar, PacketType>(&operator()(i, j), m_incr.value(), n);
335  }
336 
337  template <typename PacketT, int AlignmentT>
338  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
339  return pgather<Scalar, PacketT>(&operator()(i, j), m_incr.value());
340  }
341 
342  template <typename PacketType>
343  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType& p) const {
344  pscatter<Scalar, PacketType>(&operator()(i, j), p, m_incr.value());
345  }
346 
347  template <typename PacketType>
348  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, Index j, const PacketType& p, Index n,
349  Index /*offset*/ = 0) const {
350  pscatter_partial<Scalar, PacketType>(&operator()(i, j), p, m_incr.value(), n);
351  }
352 
353  template <typename SubPacket>
354  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket& p) const {
355  pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
356  }
357 
358  template <typename SubPacket>
359  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
360  return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
361  }
362 
363  // storePacketBlock_helper defines a way to access values inside the PacketBlock, this is essentially required by the
364  // Complex types.
365  template <typename SubPacket, typename Scalar_, int n, int idx>
366  struct storePacketBlock_helper {
367  storePacketBlock_helper<SubPacket, Scalar_, n, idx - 1> spbh;
368  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
369  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
370  const PacketBlock<SubPacket, n>& block) const {
371  spbh.store(sup, i, j, block);
372  sup->template storePacket<SubPacket>(i, j + idx, block.packet[idx]);
373  }
374  };
375 
376  template <typename SubPacket, int n, int idx>
377  struct storePacketBlock_helper<SubPacket, std::complex<float>, n, idx> {
378  storePacketBlock_helper<SubPacket, std::complex<float>, n, idx - 1> spbh;
379  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
380  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
381  const PacketBlock<SubPacket, n>& block) const {
382  spbh.store(sup, i, j, block);
383  sup->template storePacket<SubPacket>(i, j + idx, block.packet[idx]);
384  }
385  };
386 
387  template <typename SubPacket, int n, int idx>
388  struct storePacketBlock_helper<SubPacket, std::complex<double>, n, idx> {
389  storePacketBlock_helper<SubPacket, std::complex<double>, n, idx - 1> spbh;
390  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
391  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
392  const PacketBlock<SubPacket, n>& block) const {
393  spbh.store(sup, i, j, block);
394  for (int l = 0; l < unpacket_traits<SubPacket>::size; l++) {
395  std::complex<double>* v = &sup->operator()(i + l, j + idx);
396  v->real(block.packet[idx].v[2 * l + 0]);
397  v->imag(block.packet[idx].v[2 * l + 1]);
398  }
399  }
400  };
401 
402  template <typename SubPacket, typename Scalar_, int n>
403  struct storePacketBlock_helper<SubPacket, Scalar_, n, -1> {
404  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
405  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
406  const PacketBlock<SubPacket, n>&) const {}
407  };
408 
409  template <typename SubPacket, int n>
410  struct storePacketBlock_helper<SubPacket, std::complex<float>, n, -1> {
411  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
412  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
413  const PacketBlock<SubPacket, n>&) const {}
414  };
415 
416  template <typename SubPacket, int n>
417  struct storePacketBlock_helper<SubPacket, std::complex<double>, n, -1> {
418  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
419  const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
420  const PacketBlock<SubPacket, n>&) const {}
421  };
422  // This function stores a PacketBlock on m_data, this approach is really quite slow compare to Incr=1 and should be
423  // avoided when possible.
424  template <typename SubPacket, int n>
425  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j,
426  const PacketBlock<SubPacket, n>& block) const {
427  storePacketBlock_helper<SubPacket, Scalar, n, n - 1> spb;
428  spb.store(this, i, j, block);
429  }
430 
431  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
432  EIGEN_DEVICE_FUNC const Index incr() const { return m_incr.value(); }
433  EIGEN_DEVICE_FUNC constexpr Scalar* data() const { return m_data; }
434 
435  protected:
436  Scalar* EIGEN_RESTRICT m_data;
437  const Index m_stride;
438  const internal::variable_if_dynamic<Index, Incr> m_incr;
439 };
440 
441 // lightweight helper class to access matrix coefficients (const version)
442 template <typename Scalar, typename Index, int StorageOrder>
443 class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
444  public:
445  typedef const_blas_data_mapper<Scalar, Index, StorageOrder> SubMapper;
446 
447  EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar* data, Index stride)
448  : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
449 
450  EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
451  return SubMapper(&(this->operator()(i, j)), this->m_stride);
452  }
453 };
454 
455 /* Helper class to analyze the factors of a Product expression.
456  * In particular it allows to pop out operator-, scalar multiples,
457  * and conjugate */
458 template <typename XprType>
459 struct blas_traits {
460  typedef typename traits<XprType>::Scalar Scalar;
461  typedef const XprType& ExtractType;
462  typedef XprType ExtractType_;
463  enum {
464  IsComplex = NumTraits<Scalar>::IsComplex,
465  IsTransposed = false,
466  NeedToConjugate = false,
467  HasUsableDirectAccess =
468  ((int(XprType::Flags) & DirectAccessBit) &&
469  (bool(XprType::IsVectorAtCompileTime) || int(inner_stride_at_compile_time<XprType>::ret) == 1))
470  ? 1
471  : 0,
472  HasScalarFactor = false
473  };
474  typedef std::conditional_t<bool(HasUsableDirectAccess), ExtractType, typename ExtractType_::PlainObject>
475  DirectLinearAccessType;
476  EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return x; }
477  EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC const Scalar extractScalarFactor(const XprType&) {
478  return Scalar(1);
479  }
480 };
481 
482 // pop conjugate
483 template <typename Scalar, typename NestedXpr>
484 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > : blas_traits<NestedXpr> {
485  typedef blas_traits<NestedXpr> Base;
486  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
487  typedef typename Base::ExtractType ExtractType;
488 
489  enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex };
490  EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
491  EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
492  return conj(Base::extractScalarFactor(x.nestedExpression()));
493  }
494 };
495 
496 // pop scalar multiple
497 template <typename Scalar, typename NestedXpr, typename Plain>
498 struct blas_traits<
499  CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain>, NestedXpr> >
500  : blas_traits<NestedXpr> {
501  enum { HasScalarFactor = true };
502  typedef blas_traits<NestedXpr> Base;
503  typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain>, NestedXpr>
504  XprType;
505  typedef typename Base::ExtractType ExtractType;
506  EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) {
507  return Base::extract(x.rhs());
508  }
509  EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC Scalar extractScalarFactor(const XprType& x) {
510  return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs());
511  }
512 };
513 template <typename Scalar, typename NestedXpr, typename Plain>
514 struct blas_traits<
515  CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain> > >
516  : blas_traits<NestedXpr> {
517  enum { HasScalarFactor = true };
518  typedef blas_traits<NestedXpr> Base;
519  typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain> >
520  XprType;
521  typedef typename Base::ExtractType ExtractType;
522  EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
523  EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
524  return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other;
525  }
526 };
527 template <typename Scalar, typename Plain1, typename Plain2>
528 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain1>,
529  const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain2> > >
530  : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>, Plain1> > {};
531 
532 // pop opposite
533 template <typename Scalar, typename NestedXpr>
534 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > : blas_traits<NestedXpr> {
535  enum { HasScalarFactor = true };
536  typedef blas_traits<NestedXpr> Base;
537  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
538  typedef typename Base::ExtractType ExtractType;
539  EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
540  EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
541  return -Base::extractScalarFactor(x.nestedExpression());
542  }
543 };
544 
545 // pop/push transpose
546 template <typename NestedXpr>
547 struct blas_traits<Transpose<NestedXpr> > : blas_traits<NestedXpr> {
548  typedef typename NestedXpr::Scalar Scalar;
549  typedef blas_traits<NestedXpr> Base;
550  typedef Transpose<NestedXpr> XprType;
551  typedef Transpose<const typename Base::ExtractType_>
552  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
553  typedef Transpose<const typename Base::ExtractType_> ExtractType_;
554  typedef std::conditional_t<bool(Base::HasUsableDirectAccess), ExtractType, typename ExtractType::PlainObject>
555  DirectLinearAccessType;
556  enum { IsTransposed = Base::IsTransposed ? 0 : 1 };
557  EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) {
558  return ExtractType(Base::extract(x.nestedExpression()));
559  }
560  EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
561  return Base::extractScalarFactor(x.nestedExpression());
562  }
563 };
564 
565 template <typename T>
566 struct blas_traits<const T> : blas_traits<T> {};
567 
568 template <typename T, bool HasUsableDirectAccess = blas_traits<T>::HasUsableDirectAccess>
569 struct extract_data_selector {
570  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename T::Scalar* run(const T& m) {
571  return blas_traits<T>::extract(m).data();
572  }
573 };
574 
575 template <typename T>
576 struct extract_data_selector<T, false> {
577  EIGEN_DEVICE_FUNC static typename T::Scalar* run(const T&) { return 0; }
578 };
579 
580 template <typename T>
581 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename T::Scalar* extract_data(const T& m) {
582  return extract_data_selector<T>::run(m);
583 }
584 
589 template <typename ResScalar, typename Lhs, typename Rhs>
591  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs) {
592  return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
593  }
594  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) {
595  return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
596  }
597 };
598 template <typename Lhs, typename Rhs>
599 struct combine_scalar_factors_impl<bool, Lhs, Rhs> {
600  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs) {
601  return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
602  }
603  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs) {
604  return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
605  }
606 };
607 
608 template <typename ResScalar, typename Lhs, typename Rhs>
609 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs,
610  const Rhs& rhs) {
611  return combine_scalar_factors_impl<ResScalar, Lhs, Rhs>::run(alpha, lhs, rhs);
612 }
613 template <typename ResScalar, typename Lhs, typename Rhs>
614 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs) {
615  return combine_scalar_factors_impl<ResScalar, Lhs, Rhs>::run(lhs, rhs);
616 }
617 
618 } // end namespace internal
619 
620 } // end namespace Eigen
621 
622 #endif // EIGEN_BLASUTIL_H
const unsigned int DirectAccessBit
Definition: Constants.h:159
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Definition: BFloat16.h:231
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
AlignmentType
Definition: Constants.h:234
Definition: Constants.h:320