$darkmode
Eigen-unsupported  5.0.1-dev
TensorAssign.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 template <typename LhsXprType, typename RhsXprType>
20 struct traits<TensorAssignOp<LhsXprType, RhsXprType> > {
21  typedef typename LhsXprType::Scalar Scalar;
22  typedef typename traits<LhsXprType>::StorageKind StorageKind;
23  typedef
24  typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
25  typedef typename LhsXprType::Nested LhsNested;
26  typedef typename RhsXprType::Nested RhsNested;
27  typedef std::remove_reference_t<LhsNested> LhsNested_;
28  typedef std::remove_reference_t<RhsNested> RhsNested_;
29  static constexpr std::size_t NumDimensions = internal::traits<LhsXprType>::NumDimensions;
30  static constexpr int Layout = internal::traits<LhsXprType>::Layout;
31  typedef typename traits<LhsXprType>::PointerType PointerType;
32 
33  enum { Flags = 0 };
34 };
35 
36 template <typename LhsXprType, typename RhsXprType>
37 struct eval<TensorAssignOp<LhsXprType, RhsXprType>, Eigen::Dense> {
38  typedef const TensorAssignOp<LhsXprType, RhsXprType>& type;
39 };
40 
41 template <typename LhsXprType, typename RhsXprType>
42 struct nested<TensorAssignOp<LhsXprType, RhsXprType>, 1, typename eval<TensorAssignOp<LhsXprType, RhsXprType> >::type> {
43  typedef TensorAssignOp<LhsXprType, RhsXprType> type;
44 };
45 
46 } // end namespace internal
47 
54 template <typename LhsXprType, typename RhsXprType>
55 class TensorAssignOp : public TensorBase<TensorAssignOp<LhsXprType, RhsXprType> > {
56  public:
57  typedef typename Eigen::internal::traits<TensorAssignOp>::Scalar Scalar;
58  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
59  typedef typename LhsXprType::CoeffReturnType CoeffReturnType;
60  typedef typename Eigen::internal::nested<TensorAssignOp>::type Nested;
61  typedef typename Eigen::internal::traits<TensorAssignOp>::StorageKind StorageKind;
62  typedef typename Eigen::internal::traits<TensorAssignOp>::Index Index;
63 
64  static constexpr int NumDims = Eigen::internal::traits<TensorAssignOp>::NumDimensions;
65 
66  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorAssignOp(LhsXprType& lhs, const RhsXprType& rhs)
67  : m_lhs_xpr(lhs), m_rhs_xpr(rhs) {}
68 
70  EIGEN_DEVICE_FUNC internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
71  return *((internal::remove_all_t<typename LhsXprType::Nested>*)&m_lhs_xpr);
72  }
73 
74  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
75  return m_rhs_xpr;
76  }
77 
78  protected:
79  internal::remove_all_t<typename LhsXprType::Nested>& m_lhs_xpr;
80  const internal::remove_all_t<typename RhsXprType::Nested>& m_rhs_xpr;
81 };
82 
83 template <typename LeftArgType, typename RightArgType, typename Device>
84 struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> {
85  typedef TensorAssignOp<LeftArgType, RightArgType> XprType;
86  typedef typename XprType::Index Index;
87  typedef typename XprType::Scalar Scalar;
88  typedef typename XprType::CoeffReturnType CoeffReturnType;
89  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90  typedef typename TensorEvaluator<RightArgType, Device>::Dimensions Dimensions;
91  typedef StorageMemory<CoeffReturnType, Device> Storage;
92  typedef typename Storage::Type EvaluatorPointerType;
93 
94  static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
95  static constexpr int NumDims = XprType::NumDims;
96  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
97 
98  enum {
99  IsAligned =
100  int(TensorEvaluator<LeftArgType, Device>::IsAligned) & int(TensorEvaluator<RightArgType, Device>::IsAligned),
101  PacketAccess = int(TensorEvaluator<LeftArgType, Device>::PacketAccess) &
102  int(TensorEvaluator<RightArgType, Device>::PacketAccess),
103  BlockAccess = int(TensorEvaluator<LeftArgType, Device>::BlockAccess) &
104  int(TensorEvaluator<RightArgType, Device>::BlockAccess),
105  PreferBlockAccess = int(TensorEvaluator<LeftArgType, Device>::PreferBlockAccess) |
106  int(TensorEvaluator<RightArgType, Device>::PreferBlockAccess),
107  RawAccess = TensorEvaluator<LeftArgType, Device>::RawAccess
108  };
109 
110  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
111  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
112  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
113 
114  typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlock RightTensorBlock;
115  //===--------------------------------------------------------------------===//
116 
117  TensorEvaluator(const XprType& op, const Device& device)
118  : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) {
119  EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
120  static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
121  YOU_MADE_A_PROGRAMMING_MISTAKE);
122  }
123 
124  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const {
125  // The dimensions of the lhs and the rhs tensors should be equal to prevent
126  // overflows and ensure the result is fully initialized.
127  // TODO: use left impl instead if right impl dimensions are known at compile time.
128  return m_rightImpl.dimensions();
129  }
130 
131  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
132  eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
133  m_leftImpl.evalSubExprsIfNeeded(NULL);
134  // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a non
135  // null value), attempt to evaluate the rhs expression in place. Returns true iff in place
136  // evaluation isn't supported and the caller still needs to manually assign the values generated
137  // by the rhs to the lhs.
138  return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data());
139  }
140 
141 #ifdef EIGEN_USE_THREADS
142  template <typename EvalSubExprsCallback>
143  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
144  m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done](bool) {
145  m_rightImpl.evalSubExprsIfNeededAsync(m_leftImpl.data(), [done](bool need_assign) { done(need_assign); });
146  });
147  }
148 #endif // EIGEN_USE_THREADS
149 
150  EIGEN_STRONG_INLINE void cleanup() {
151  m_leftImpl.cleanup();
152  m_rightImpl.cleanup();
153  }
154 
155  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalScalar(Index i) const {
156  m_leftImpl.coeffRef(i) = m_rightImpl.coeff(i);
157  }
158  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalPacket(Index i) const {
159  const int LhsStoreMode = TensorEvaluator<LeftArgType, Device>::IsAligned ? Aligned : Unaligned;
160  const int RhsLoadMode = TensorEvaluator<RightArgType, Device>::IsAligned ? Aligned : Unaligned;
161  m_leftImpl.template writePacket<LhsStoreMode>(i, m_rightImpl.template packet<RhsLoadMode>(i));
162  }
163  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const { return m_leftImpl.coeff(index); }
164  template <int LoadMode>
165  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
166  return m_leftImpl.template packet<LoadMode>(index);
167  }
168 
169  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
170  // We assume that evalPacket or evalScalar is called to perform the
171  // assignment and account for the cost of the write here, but reduce left
172  // cost by one load because we are using m_leftImpl.coeffRef.
173  TensorOpCost left = m_leftImpl.costPerCoeff(vectorized);
174  return m_rightImpl.costPerCoeff(vectorized) +
175  TensorOpCost(numext::maxi(0.0, left.bytes_loaded() - sizeof(CoeffReturnType)), left.bytes_stored(),
176  left.compute_cycles()) +
177  TensorOpCost(0, sizeof(CoeffReturnType), 0, vectorized, PacketSize);
178  }
179 
180  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements() const {
181  return internal::TensorBlockResourceRequirements::merge(m_leftImpl.getResourceRequirements(),
182  m_rightImpl.getResourceRequirements());
183  }
184 
185  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlock(TensorBlockDesc& desc, TensorBlockScratch& scratch) {
186  if (TensorEvaluator<LeftArgType, Device>::RawAccess && m_leftImpl.data() != NULL) {
187  // If destination has raw data access, we pass it as a potential
188  // destination for a block descriptor evaluation.
189  desc.template AddDestinationBuffer<Layout>(
190  /*dst_base=*/m_leftImpl.data() + desc.offset(),
191  /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()));
192  }
193 
194  RightTensorBlock block = m_rightImpl.block(desc, scratch, /*root_of_expr_ast=*/true);
195  // If block was evaluated into a destination, there is no need to do assignment.
196  if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
197  m_leftImpl.writeBlock(desc, block);
198  }
199  block.cleanup();
200  }
201 
202  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_leftImpl.data(); }
203 
204  private:
205  TensorEvaluator<LeftArgType, Device> m_leftImpl;
206  TensorEvaluator<RightArgType, Device> m_rightImpl;
207 };
208 
209 } // namespace Eigen
210 
211 #endif // EIGEN_CXX11_TENSOR_TENSOR_ASSIGN_H
Namespace containing all symbols from the Eigen library.
Definition: TensorAssign.h:55
internal::remove_all_t< typename LhsXprType::Nested > & lhsExpression() const
Definition: TensorAssign.h:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor base class.
Definition: TensorForwardDeclarations.h:68