$darkmode
Eigen-unsupported  5.0.1-dev
TensorArgMax.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5 // Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12 #define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
13 
14 // IWYU pragma: private
15 #include "./InternalHeaderCheck.h"
16 
17 namespace Eigen {
18 namespace internal {
19 
20 template <typename XprType>
21 struct traits<TensorIndexPairOp<XprType>> : public traits<XprType> {
22  typedef traits<XprType> XprTraits;
23  typedef typename XprTraits::StorageKind StorageKind;
24  typedef typename XprTraits::Index Index;
25  typedef Pair<Index, typename XprTraits::Scalar> Scalar;
26  typedef typename XprType::Nested Nested;
27  typedef std::remove_reference_t<Nested> Nested_;
28  static constexpr int NumDimensions = XprTraits::NumDimensions;
29  static constexpr int Layout = XprTraits::Layout;
30 };
31 
32 template <typename XprType>
33 struct eval<TensorIndexPairOp<XprType>, Eigen::Dense> {
34  typedef const TensorIndexPairOp<XprType> EIGEN_DEVICE_REF type;
35 };
36 
37 template <typename XprType>
38 struct nested<TensorIndexPairOp<XprType>, 1, typename eval<TensorIndexPairOp<XprType>>::type> {
39  typedef TensorIndexPairOp<XprType> type;
40 };
41 
42 } // end namespace internal
43 
49 template <typename XprType>
50 class TensorIndexPairOp : public TensorBase<TensorIndexPairOp<XprType>, ReadOnlyAccessors> {
51  public:
52  typedef typename Eigen::internal::traits<TensorIndexPairOp>::Scalar Scalar;
53  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
54  typedef typename Eigen::internal::nested<TensorIndexPairOp>::type Nested;
55  typedef typename Eigen::internal::traits<TensorIndexPairOp>::StorageKind StorageKind;
56  typedef typename Eigen::internal::traits<TensorIndexPairOp>::Index Index;
57  typedef Pair<Index, typename XprType::CoeffReturnType> CoeffReturnType;
58 
59  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexPairOp(const XprType& expr) : m_xpr(expr) {}
60 
61  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }
62 
63  protected:
64  typename XprType::Nested m_xpr;
65 };
66 
67 // Eval as rvalue
68 template <typename ArgType, typename Device>
69 struct TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> {
70  typedef TensorIndexPairOp<ArgType> XprType;
71  typedef typename XprType::Index Index;
72  typedef typename XprType::Scalar Scalar;
73  typedef typename XprType::CoeffReturnType CoeffReturnType;
74 
75  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
76  static constexpr int NumDims = internal::array_size<Dimensions>::value;
77  typedef StorageMemory<CoeffReturnType, Device> Storage;
78  typedef typename Storage::Type EvaluatorPointerType;
79 
80  enum {
81  IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
82  PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
83  BlockAccess = false,
85  CoordAccess = false, // to be implemented
86  RawAccess = false
87  };
88  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
89 
90  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
91  typedef internal::TensorBlockNotImplemented TensorBlock;
92  //===--------------------------------------------------------------------===//
93 
94  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_impl(op.expression(), device) {}
95 
96  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); }
97 
98  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
99  m_impl.evalSubExprsIfNeeded(NULL);
100  return true;
101  }
102  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
103 
104  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
105  return CoeffReturnType(index, m_impl.coeff(index));
106  }
107 
108  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
109  return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
110  }
111 
112  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
113 
114  protected:
115  TensorEvaluator<ArgType, Device> m_impl;
116 };
117 
118 namespace internal {
119 
126 template <typename ReduceOp, typename Dims, typename XprType>
127 struct traits<TensorPairReducerOp<ReduceOp, Dims, XprType>> : public traits<XprType> {
128  typedef traits<XprType> XprTraits;
129  typedef typename XprTraits::StorageKind StorageKind;
130  typedef typename XprTraits::Index Index;
131  typedef Index Scalar;
132  typedef typename XprType::Nested Nested;
133  typedef std::remove_reference_t<Nested> Nested_;
134  static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
135  static constexpr int Layout = XprTraits::Layout;
136 };
137 
138 template <typename ReduceOp, typename Dims, typename XprType>
139 struct eval<TensorPairReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense> {
140  typedef const TensorPairReducerOp<ReduceOp, Dims, XprType> EIGEN_DEVICE_REF type;
141 };
142 
143 template <typename ReduceOp, typename Dims, typename XprType>
144 struct nested<TensorPairReducerOp<ReduceOp, Dims, XprType>, 1,
145  typename eval<TensorPairReducerOp<ReduceOp, Dims, XprType>>::type> {
146  typedef TensorPairReducerOp<ReduceOp, Dims, XprType> type;
147 };
148 
149 } // end namespace internal
150 
151 template <typename ReduceOp, typename Dims, typename XprType>
152 class TensorPairReducerOp : public TensorBase<TensorPairReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors> {
153  public:
154  typedef typename Eigen::internal::traits<TensorPairReducerOp>::Scalar Scalar;
155  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
156  typedef typename Eigen::internal::nested<TensorPairReducerOp>::type Nested;
157  typedef typename Eigen::internal::traits<TensorPairReducerOp>::StorageKind StorageKind;
158  typedef typename Eigen::internal::traits<TensorPairReducerOp>::Index Index;
159  typedef Index CoeffReturnType;
160 
161  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPairReducerOp(const XprType& expr, const ReduceOp& reduce_op,
162  const Index return_dim, const Dims& reduce_dims)
163  : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
164 
165  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }
166 
167  EIGEN_DEVICE_FUNC const ReduceOp& reduce_op() const { return m_reduce_op; }
168 
169  EIGEN_DEVICE_FUNC const Dims& reduce_dims() const { return m_reduce_dims; }
170 
171  EIGEN_DEVICE_FUNC Index return_dim() const { return m_return_dim; }
172 
173  protected:
174  typename XprType::Nested m_xpr;
175  const ReduceOp m_reduce_op;
176  const Index m_return_dim;
177  const Dims m_reduce_dims;
178 };
179 
180 // Eval as rvalue
181 template <typename ReduceOp, typename Dims, typename ArgType, typename Device>
182 struct TensorEvaluator<const TensorPairReducerOp<ReduceOp, Dims, ArgType>, Device> {
183  typedef TensorPairReducerOp<ReduceOp, Dims, ArgType> XprType;
184  typedef typename XprType::Index Index;
185  typedef typename XprType::Scalar Scalar;
186  typedef typename XprType::CoeffReturnType CoeffReturnType;
187  typedef typename TensorIndexPairOp<ArgType>::CoeffReturnType PairType;
188  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>,
189  Device>::Dimensions Dimensions;
190  typedef typename TensorEvaluator<const TensorIndexPairOp<ArgType>, Device>::Dimensions InputDimensions;
191  static constexpr int NumDims = internal::array_size<InputDimensions>::value;
192  typedef array<Index, NumDims> StrideDims;
193  typedef StorageMemory<CoeffReturnType, Device> Storage;
194  typedef typename Storage::Type EvaluatorPointerType;
195  typedef StorageMemory<PairType, Device> PairStorageMem;
196 
197  enum {
198  IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
199  PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
200  BlockAccess = false,
201  PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
202  CoordAccess = false, // to be implemented
203  RawAccess = false
204  };
205  static constexpr int Layout =
206  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device>::Layout;
207  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
208  typedef internal::TensorBlockNotImplemented TensorBlock;
209  //===--------------------------------------------------------------------===//
210 
211  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
212  : m_orig_impl(op.expression(), device),
213  m_impl(op.expression().index_pairs().reduce(op.reduce_dims(), op.reduce_op()), device),
214  m_return_dim(op.return_dim()) {
215  gen_strides(m_orig_impl.dimensions(), m_strides);
216  if (Layout == static_cast<int>(ColMajor)) {
217  const Index total_size = internal::array_prod(m_orig_impl.dimensions());
218  m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
219  } else {
220  const Index total_size = internal::array_prod(m_orig_impl.dimensions());
221  m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
222  }
223  // If m_return_dim is not a valid index, returns 1 or this can crash on Windows.
224  m_stride_div =
225  ((m_return_dim >= 0) && (m_return_dim < static_cast<Index>(m_strides.size()))) ? m_strides[m_return_dim] : 1;
226  }
227 
228  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); }
229 
230  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
231  m_impl.evalSubExprsIfNeeded(NULL);
232  return true;
233  }
234  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
235 
236  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
237  const PairType v = m_impl.coeff(index);
238  return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
239  }
240 
241  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
242 
243  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
244  const double compute_cost =
245  1.0 + (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
246  return m_orig_impl.costPerCoeff(vectorized) + m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
247  }
248 
249  private:
250  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
251  if (m_return_dim < 0) {
252  return; // Won't be using the strides.
253  }
254  eigen_assert(m_return_dim < NumDims && "Asking to convert index to a dimension outside of the rank");
255 
256  // Calculate m_stride_div and m_stride_mod, which are used to
257  // calculate the value of an index w.r.t. the m_return_dim.
258  if (Layout == static_cast<int>(ColMajor)) {
259  strides[0] = 1;
260  for (int i = 1; i < NumDims; ++i) {
261  strides[i] = strides[i - 1] * dims[i - 1];
262  }
263  } else {
264  strides[NumDims - 1] = 1;
265  for (int i = NumDims - 2; i >= 0; --i) {
266  strides[i] = strides[i + 1] * dims[i + 1];
267  }
268  }
269  }
270 
271  protected:
272  TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> m_orig_impl;
273  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device> m_impl;
274  const Index m_return_dim;
275  StrideDims m_strides;
276  Index m_stride_mod;
277  Index m_stride_div;
278 };
279 
280 } // end namespace Eigen
281 
282 #endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
Tensor + Index Pair class.
Definition: TensorArgMax.h:50
Namespace containing all symbols from the Eigen library.
The tensor evaluator class.
Definition: TensorEvaluator.h:30
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor base class.
Definition: TensorForwardDeclarations.h:68