$darkmode
Eigen-unsupported  5.0.1-dev
TensorContractionMapper.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
// Operand selector used as the `side` template argument of the mappers below:
// Rhs is 0 and Lhs is the next value, 1.
enum { Rhs = 0, Lhs = Rhs + 1 };
21 
22 /*
23  * Implementation of the Eigen blas_data_mapper class for tensors.
24  */
27 template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
28 struct CoeffLoader;
29 
30 template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
31  int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
32  template <class> class MakePointer_ = MakePointer>
33 class BaseTensorContractionMapper;
34 
35 template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
36 struct CoeffLoader {
37  enum { DirectOffsets = false };
38 
39  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) {}
40 
41  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
42  eigen_assert(false && "unsupported");
43  }
44 
45  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
46  eigen_assert(false && "unsupported");
47  return NULL;
48  }
49 
50  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
51  return m_tensor.coeff(index);
52  }
53 
54  template <int LoadMode>
55  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
56  return m_tensor.template packet<LoadMode>(index);
57  }
58 
59  private:
60  const Tensor m_tensor;
61 };
62 
63 template <typename Tensor, template <class> class MakePointer_>
64 struct CoeffLoader<Tensor, true, MakePointer_> {
65  enum { DirectOffsets = true };
66 
67  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
68 
69  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) { m_data += offset; }
70 
71  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
72  return m_data;
73  }
74 
75  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
76  return loadConstant(m_data + index);
77  }
78 
79  template <int LoadMode>
80  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
81  return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
82  }
83 
84  private:
85  typedef typename Tensor::Scalar Scalar;
86 
87  typename MakePointer_<const Scalar>::Type m_data;
88 };
89 
// Presents one operand of a tensor contraction as a 2-D matrix and translates
// matrix coordinates (row, col) into linear offsets into the tensor.
//
// For the Lhs (side == Lhs) the row index enumerates the non-contracting
// ("nocontract") dimensions and the column index the contracting ones; for the
// Rhs the roles are swapped (see computeIndex). A flattened row/column index is
// split into per-dimension coordinates by dividing by m_ij_strides /
// m_k_strides, and each coordinate is scaled by the matching
// m_nocontract_strides / m_contract_strides entry to build the linear offset.
//
// Element reads go through a CoeffLoader (m_tensor); for raw-access tensors the
// loader's base pointer can also be shifted via offsetBuffer.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
 public:
  // nocontract_strides / contract_strides: per-dimension steps in the tensor's
  //   linear index for the non-contracting / contracting dimensions.
  // ij_strides / k_strides: per-dimension divisors used to split the flattened
  //   non-contracting / contracting matrix index into coordinates.
  EIGEN_DEVICE_FUNC SimpleTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                  const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                  const contract_t& k_strides)
      : m_tensor(tensor),
        m_nocontract_strides(nocontract_strides),
        m_ij_strides(ij_strides),
        m_contract_strides(contract_strides),
        m_k_strides(k_strides) {}

  // True when the underlying loader can have offsets folded into its base pointer.
  enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };

  // Shifts the loader's base pointer by `offset` elements. Only meaningful when
  // DirectOffsets is true; the generic CoeffLoader asserts here.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  // No-op prefetch hook.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void prefetch(Index /*i*/) {}

  // Single-index access: reads element (row, 0).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  // Reads the tensor coefficient at matrix position (row, col).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  // Maps matrix position (row, col) to a linear tensor index. Dimensions with
  // array index > 0 are peeled off by division, from the last one downwards; the
  // remainder is the coordinate along dimension 0. When the inner dimension is
  // contiguous (nocontract for Lhs, contract for Rhs), its stride is asserted to
  // be 1 and the multiplication is skipped.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    // Lhs: row indexes the non-contracting dims; Rhs: col does.
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    EIGEN_UNROLL_LOOP
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    // Only add the dimension-0 term if the tensor has any non-contracting dimensions.
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    // Lhs: col indexes the contracting dims; Rhs: row does.
    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  // Computes two linear indices in one interleaved pass: .first is
  // computeIndex(row, col) and .second is computeIndex(row + distance, col)
  // (for the Lhs the offset lands in the nocontract value, for the Rhs in the
  // contract value).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col,
                                                                          const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  // Returns 0 (the first element may be treated as aligned) or `size` (no aligned
  // element is claimed).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when we're
    // dealing with the lhs with inner_dim_contiguous). This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  // Linear distance between consecutive columns. It is only known for a
  // contiguous Lhs with at least one contracting dimension (the stride of the
  // innermost contracting dimension); otherwise 1 is reported.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

  // Read-only accessors for the loader and the stride arrays.
  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const { return m_tensor; }

  const nocontract_t& nocontract_strides() const { return m_nocontract_strides; }
  const nocontract_t& ij_strides() const { return m_ij_strides; }
  const contract_t& contract_strides() const { return m_contract_strides; }
  const contract_t& k_strides() const { return m_k_strides; }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
234 
// Adds packet loads on top of SimpleTensorContractionMapper. A packet spans
// consecutive rows (i, i + 1, ...) of one column j. The packet is read directly
// from memory when those rows are known or detected to be contiguous; otherwise
// the packet is gathered coefficient by coefficient.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_>
class BaseTensorContractionMapper
    : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                           inner_dim_contiguous, Alignment, MakePointer_> {
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                        inner_dim_contiguous, Alignment, MakePointer_>
      ParentMapper;

  EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                const contract_t& k_strides)
      : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}

  // Loads rows i .. i + packet_size - 1 of column j when PacketT has the mapper's
  // native packet size.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size == packet_size, PacketT>
      load(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    // Fast path: a contiguous, non-reordered inner dimension means consecutive rows
    // are consecutive in memory (checked by the assertion in debug builds).
    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i + packet_size - 1, j) == index + packet_size - 1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    // Linear indices of the first and last rows of the packet.
    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess && (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (lastIdx - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<AlignmentType>(first);
    }

    // Slow path: gather into an aligned stack buffer, two interior rows per
    // computeIndexPair call, with both endpoints taken from indexPair above.
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    EIGEN_UNROLL_LOOP
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // Loads a packet whose size differs from packet_size. It always gathers
  // coefficient by coefficient and never attempts a direct packet read.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size != packet_size, PacketT>
      load(Index i, Index j) const {
    const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
    EIGEN_ALIGN_MAX Scalar data[requested_packet_size];

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // For odd sizes the last pair overlaps the final slot, which is then rewritten
    // with the same coefficient from lastIdx.
    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < requested_packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // Alias for load(), dispatching on PacketT's size through the overloads above.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    return this->load<PacketT, AlignmentType>(i, j);
  }
};
321 
322 template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
323  bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
324 class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
325  inner_dim_reordered, Alignment, MakePointer_>
326  : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
327  inner_dim_contiguous, Alignment, MakePointer_> {
328  public:
329  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
330  Alignment, MakePointer_>
331  ParentMapper;
332 
333  EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
334  const nocontract_t& ij_strides, const contract_t& contract_strides,
335  const contract_t& k_strides)
336  : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
337 
338  template <typename PacketT, int>
339  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
340  EIGEN_ALIGN_MAX Scalar data[1];
341  data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
342  return pload<PacketT>(data);
343  }
344  template <typename PacketT, int>
345  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
346  EIGEN_ALIGN_MAX Scalar data[1];
347  data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
348  return pload<PacketT>(data);
349  }
350 };
351 
352 template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
353  int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
354  template <class> class MakePointer_ = MakePointer>
355 class TensorContractionSubMapper {
356  public:
357  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
358  inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
359  ParentMapper;
360  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
361  inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
362  Self;
363  typedef Self LinearMapper;
364  typedef Self SubMapper;
365 
366  enum {
367  // We can use direct offsets iff the parent mapper supports then and we can compute the strides.
368  // TODO: we should also enable direct offsets for the Rhs case.
369  UseDirectOffsets =
370  ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
371  };
372 
373  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
374  : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
375  // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
376  // this offset every time we attempt to access a coefficient.
377  if (UseDirectOffsets) {
378  Index stride = m_base_mapper.stride();
379  m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
380  }
381  }
382 
383  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
384  if (UseDirectOffsets) {
385  return m_base_mapper(i, 0);
386  }
387  return m_base_mapper(i + m_vert_offset, m_horiz_offset);
388  }
389  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
390  if (UseDirectOffsets) {
391  return m_base_mapper(i, j);
392  }
393  return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
394  }
395 
396  template <typename PacketT>
397  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
398  if (UseDirectOffsets) {
399  return m_base_mapper.template loadPacket<PacketT, Alignment>(i, 0);
400  }
401  return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, m_horiz_offset);
402  }
403 
404  template <typename PacketT>
405  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
406  if (UseDirectOffsets) {
407  return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
408  }
409  return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
410  }
411 
412  template <typename PacketT>
413  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
414  if (UseDirectOffsets) {
415  return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
416  }
417  return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
418  }
419 
420  template <typename PacketT, int AlignmentType>
421  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
422  if (UseDirectOffsets) {
423  return m_base_mapper.template load<PacketT, AlignmentType>(i, j);
424  }
425  return m_base_mapper.template loadPacket<PacketT, AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
426  }
427 
428  template <typename PacketT>
429  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
430  if (UseDirectOffsets) {
431  m_base_mapper.storePacket(i, 0, p);
432  }
433  m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
434  }
435 
436  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
437  if (UseDirectOffsets) {
438  return LinearMapper(m_base_mapper, i, j);
439  }
440  return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
441  }
442 
443  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
444  if (UseDirectOffsets) {
445  return SubMapper(m_base_mapper, i, j);
446  }
447  return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
448  }
449 
450  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const Index stride() const { return m_base_mapper.stride(); }
451 
452  template <typename PacketT, int AlignmentType>
453  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
454  EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
455  const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
456  if (UseDirectOffsets) {
457  return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);
458  }
459  return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i + m_vert_offset, m_horiz_offset);
460  }
461 
462  template <typename PacketT>
463  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
464  return false;
465  }
466 
467  const ParentMapper& base_mapper() const { return m_base_mapper; }
468  Index vert_offset() const { return m_vert_offset; }
469  Index horiz_offset() const { return m_horiz_offset; }
470 
471  private:
472  ParentMapper m_base_mapper;
473  const Index m_vert_offset;
474  const Index m_horiz_offset;
475 };
476 
477 template <typename Scalar_, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
478  int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
479  template <class> class MakePointer_ = MakePointer>
480 class TensorContractionInputMapper
481  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size,
482  inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
483  public:
484  typedef Scalar_ Scalar;
485  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
486  inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
487  Base;
488  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
489  inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
490  SubMapper;
491  typedef SubMapper VectorMapper;
492  typedef SubMapper LinearMapper;
493 
494  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
495  const nocontract_t& ij_strides, const contract_t& contract_strides,
496  const contract_t& k_strides)
497  : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
498 
499  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
500  return SubMapper(*this, i, j);
501  }
502 
503  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
504  return LinearMapper(*this, i, j);
505  }
506 
507  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
508  return VectorMapper(*this, i, j);
509  }
510 
511  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
512  return Base::m_tensor;
513  }
514 };
515 
516 template <typename T>
517 struct TensorContractionInputMapperTrait;
518 
519 template <typename Scalar_, typename Index_, int side_, typename Tensor_, typename nocontract_t_, typename contract_t_,
520  int packet_size_, bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_,
521  template <class> class MakePointer_>
522 struct TensorContractionInputMapperTrait<
523  TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_,
524  inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_> > {
525  typedef Tensor_ XprType;
526  static const bool inner_dim_contiguous = inner_dim_contiguous_;
527  static const bool inner_dim_reordered = inner_dim_reordered_;
528 };
529 
530 } // end namespace internal
531 } // end namespace Eigen
532 
533 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
Definition: TensorContractionMapper.h:28
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
AlignmentType
The tensor class.
Definition: Tensor.h:68