$darkmode
Eigen-unsupported  5.0.1-dev
TensorConcatenation.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 template <typename Axis, typename LhsXprType, typename RhsXprType>
20 struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > {
21  // Type promotion to handle the case where the types of the lhs and the rhs are different.
22  typedef typename promote_storage_type<typename LhsXprType::Scalar, typename RhsXprType::Scalar>::ret Scalar;
23  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
24  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
25  typedef
26  typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
27  typedef typename LhsXprType::Nested LhsNested;
28  typedef typename RhsXprType::Nested RhsNested;
29  typedef std::remove_reference_t<LhsNested> LhsNested_;
30  typedef std::remove_reference_t<RhsNested> RhsNested_;
31  static constexpr int NumDimensions = traits<LhsXprType>::NumDimensions;
32  static constexpr int Layout = traits<LhsXprType>::Layout;
33  enum { Flags = 0 };
34  typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
35  typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>
36  PointerType;
37 };
38 
39 template <typename Axis, typename LhsXprType, typename RhsXprType>
40 struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense> {
41  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
42 };
43 
44 template <typename Axis, typename LhsXprType, typename RhsXprType>
45 struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1,
46  typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type> {
47  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
48 };
49 
50 } // end namespace internal
51 
57 template <typename Axis, typename LhsXprType, typename RhsXprType>
58 class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> {
59  public:
61  typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
62  typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
63  typedef typename internal::traits<TensorConcatenationOp>::Index Index;
64  typedef typename internal::nested<TensorConcatenationOp>::type Nested;
65  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
66  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
67  typedef typename NumTraits<Scalar>::Real RealScalar;
68 
69  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
70  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}
71 
72  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
73  return m_lhs_xpr;
74  }
75 
76  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
77  return m_rhs_xpr;
78  }
79 
80  EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
81 
82  EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)
83  protected:
84  typename LhsXprType::Nested m_lhs_xpr;
85  typename RhsXprType::Nested m_rhs_xpr;
86  const Axis m_axis;
87 };
88 
89 // Eval as rvalue
90 template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
91 struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
93  typedef typename XprType::Index Index;
94  static constexpr int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
95  static constexpr int RightNumDims =
96  internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
97  typedef DSizes<Index, NumDims> Dimensions;
98  typedef typename XprType::Scalar Scalar;
99  typedef typename XprType::CoeffReturnType CoeffReturnType;
100  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
101  typedef StorageMemory<CoeffReturnType, Device> Storage;
102  typedef typename Storage::Type EvaluatorPointerType;
103  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
104  enum {
105  IsAligned = false,
106  PacketAccess =
108  BlockAccess = false,
111  RawAccess = false
112  };
113 
114  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
115  typedef internal::TensorBlockNotImplemented TensorBlock;
116  //===--------------------------------------------------------------------===//
117 
118  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
119  : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) {
120  EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
121  static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) ||
122  NumDims == 1),
123  YOU_MADE_A_PROGRAMMING_MISTAKE);
124  EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
125  EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
126 
127  eigen_assert(0 <= m_axis && m_axis < NumDims);
128  const Dimensions& lhs_dims = m_leftImpl.dimensions();
129  const Dimensions& rhs_dims = m_rightImpl.dimensions();
130  {
131  int i = 0;
132  for (; i < m_axis; ++i) {
133  eigen_assert(lhs_dims[i] > 0);
134  eigen_assert(lhs_dims[i] == rhs_dims[i]);
135  m_dimensions[i] = lhs_dims[i];
136  }
137  eigen_assert(lhs_dims[i] > 0); // Now i == m_axis.
138  eigen_assert(rhs_dims[i] > 0);
139  m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
140  for (++i; i < NumDims; ++i) {
141  eigen_assert(lhs_dims[i] > 0);
142  eigen_assert(lhs_dims[i] == rhs_dims[i]);
143  m_dimensions[i] = lhs_dims[i];
144  }
145  }
146 
147  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
148  m_leftStrides[0] = 1;
149  m_rightStrides[0] = 1;
150  m_outputStrides[0] = 1;
151 
152  for (int j = 1; j < NumDims; ++j) {
153  m_leftStrides[j] = m_leftStrides[j - 1] * lhs_dims[j - 1];
154  m_rightStrides[j] = m_rightStrides[j - 1] * rhs_dims[j - 1];
155  m_outputStrides[j] = m_outputStrides[j - 1] * m_dimensions[j - 1];
156  }
157  } else {
158  m_leftStrides[NumDims - 1] = 1;
159  m_rightStrides[NumDims - 1] = 1;
160  m_outputStrides[NumDims - 1] = 1;
161 
162  for (int j = NumDims - 2; j >= 0; --j) {
163  m_leftStrides[j] = m_leftStrides[j + 1] * lhs_dims[j + 1];
164  m_rightStrides[j] = m_rightStrides[j + 1] * rhs_dims[j + 1];
165  m_outputStrides[j] = m_outputStrides[j + 1] * m_dimensions[j + 1];
166  }
167  }
168  }
169 
170  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
171 
172  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
173  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
174  m_leftImpl.evalSubExprsIfNeeded(NULL);
175  m_rightImpl.evalSubExprsIfNeeded(NULL);
176  return true;
177  }
178 
179  EIGEN_STRONG_INLINE void cleanup() {
180  m_leftImpl.cleanup();
181  m_rightImpl.cleanup();
182  }
183 
184  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
185  // See CL/76180724 comments for more ideas.
186  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
187  // Collect dimension-wise indices (subs).
188  array<Index, NumDims> subs;
189  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
190  for (int i = NumDims - 1; i > 0; --i) {
191  subs[i] = index / m_outputStrides[i];
192  index -= subs[i] * m_outputStrides[i];
193  }
194  subs[0] = index;
195  } else {
196  for (int i = 0; i < NumDims - 1; ++i) {
197  subs[i] = index / m_outputStrides[i];
198  index -= subs[i] * m_outputStrides[i];
199  }
200  subs[NumDims - 1] = index;
201  }
202 
203  const Dimensions& left_dims = m_leftImpl.dimensions();
204  if (subs[m_axis] < left_dims[m_axis]) {
205  Index left_index;
206  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
207  left_index = subs[0];
208  EIGEN_UNROLL_LOOP
209  for (int i = 1; i < NumDims; ++i) {
210  left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
211  }
212  } else {
213  left_index = subs[NumDims - 1];
214  EIGEN_UNROLL_LOOP
215  for (int i = NumDims - 2; i >= 0; --i) {
216  left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
217  }
218  }
219  return m_leftImpl.coeff(left_index);
220  } else {
221  subs[m_axis] -= left_dims[m_axis];
222  const Dimensions& right_dims = m_rightImpl.dimensions();
223  Index right_index;
224  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
225  right_index = subs[0];
226  EIGEN_UNROLL_LOOP
227  for (int i = 1; i < NumDims; ++i) {
228  right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
229  }
230  } else {
231  right_index = subs[NumDims - 1];
232  EIGEN_UNROLL_LOOP
233  for (int i = NumDims - 2; i >= 0; --i) {
234  right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
235  }
236  }
237  return m_rightImpl.coeff(right_index);
238  }
239  }
240 
241  // TODO(phli): Add a real vectorization.
242  template <int LoadMode>
243  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
244  const int packetSize = PacketType<CoeffReturnType, Device>::size;
245  EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
246  eigen_assert(index + packetSize - 1 < dimensions().TotalSize());
247 
248  EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
249  EIGEN_UNROLL_LOOP
250  for (int i = 0; i < packetSize; ++i) {
251  values[i] = coeff(index + i);
252  }
253  PacketReturnType rslt = internal::pload<PacketReturnType>(values);
254  return rslt;
255  }
256 
257  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
258  const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
259  TensorOpCost::DivCost<Index>() + TensorOpCost::ModCost<Index>());
260  const double lhs_size = m_leftImpl.dimensions().TotalSize();
261  const double rhs_size = m_rightImpl.dimensions().TotalSize();
262  return (lhs_size / (lhs_size + rhs_size)) * m_leftImpl.costPerCoeff(vectorized) +
263  (rhs_size / (lhs_size + rhs_size)) * m_rightImpl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
264  }
265 
266  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
267 
268  protected:
269  Dimensions m_dimensions;
270  array<Index, NumDims> m_outputStrides;
271  array<Index, NumDims> m_leftStrides;
272  array<Index, NumDims> m_rightStrides;
273  TensorEvaluator<LeftArgType, Device> m_leftImpl;
274  TensorEvaluator<RightArgType, Device> m_rightImpl;
275  const Axis m_axis;
276 };
277 
278 // Eval as lvalue
279 template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
280 struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
281  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
282  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
283  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
284  typedef typename Base::Dimensions Dimensions;
285  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
286  enum {
287  IsAligned = false,
288  PacketAccess =
289  TensorEvaluator<LeftArgType, Device>::PacketAccess && TensorEvaluator<RightArgType, Device>::PacketAccess,
290  BlockAccess = false,
291  PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
292  TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
293  RawAccess = false
294  };
295 
296  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
297  typedef internal::TensorBlockNotImplemented TensorBlock;
298  //===--------------------------------------------------------------------===//
299 
300  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device) : Base(op, device) {
301  EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
302  }
303 
304  typedef typename XprType::Index Index;
305  typedef typename XprType::Scalar Scalar;
306  typedef typename XprType::CoeffReturnType CoeffReturnType;
307  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
308 
309  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index) const {
310  // Collect dimension-wise indices (subs).
311  array<Index, Base::NumDims> subs;
312  for (int i = Base::NumDims - 1; i > 0; --i) {
313  subs[i] = index / this->m_outputStrides[i];
314  index -= subs[i] * this->m_outputStrides[i];
315  }
316  subs[0] = index;
317 
318  const Dimensions& left_dims = this->m_leftImpl.dimensions();
319  if (subs[this->m_axis] < left_dims[this->m_axis]) {
320  Index left_index = subs[0];
321  for (int i = 1; i < Base::NumDims; ++i) {
322  left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
323  }
324  return this->m_leftImpl.coeffRef(left_index);
325  } else {
326  subs[this->m_axis] -= left_dims[this->m_axis];
327  const Dimensions& right_dims = this->m_rightImpl.dimensions();
328  Index right_index = subs[0];
329  for (int i = 1; i < Base::NumDims; ++i) {
330  right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
331  }
332  return this->m_rightImpl.coeffRef(right_index);
333  }
334  }
335 
336  template <int StoreMode>
337  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index, const PacketReturnType& x) const {
338  const int packetSize = PacketType<CoeffReturnType, Device>::size;
339  EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
340  eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
341 
342  EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
343  internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
344  for (int i = 0; i < packetSize; ++i) {
345  coeffRef(index + i) = values[i];
346  }
347  }
348 };
349 
350 } // end namespace Eigen
351 
352 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
WriteAccessors
Namespace containing all symbols from the Eigen library.
The tensor evaluator class.
Definition: TensorEvaluator.h:30
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor base class.
Definition: TensorForwardDeclarations.h:68
Tensor concatenation class.
Definition: TensorConcatenation.h:58