$darkmode
Eigen-unsupported  5.0.1-dev
TensorTrace.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
5 // Copyright (C) 2017 Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
12 #define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
13 
14 // IWYU pragma: private
15 #include "./InternalHeaderCheck.h"
16 
17 namespace Eigen {
18 
19 namespace internal {
20 template <typename Dims, typename XprType>
21 struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType> {
22  typedef typename XprType::Scalar Scalar;
23  typedef traits<XprType> XprTraits;
24  typedef typename XprTraits::StorageKind StorageKind;
25  typedef typename XprTraits::Index Index;
26  typedef typename XprType::Nested Nested;
27  typedef std::remove_reference_t<Nested> Nested_;
28  static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
29  static constexpr int Layout = XprTraits::Layout;
30  enum {
31  // Trace is read-only.
32  Flags = traits<XprType>::Flags & ~LvalueBit
33  };
34 };
35 
36 template <typename Dims, typename XprType>
37 struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense> {
38  typedef const TensorTraceOp<Dims, XprType>& type;
39 };
40 
41 template <typename Dims, typename XprType>
42 struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type> {
43  typedef TensorTraceOp<Dims, XprType> type;
44 };
45 
46 } // end namespace internal
47 
53 template <typename Dims, typename XprType>
54 class TensorTraceOp : public TensorBase<TensorTraceOp<Dims, XprType> > {
55  public:
56  typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
57  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58  typedef typename XprType::CoeffReturnType CoeffReturnType;
59  typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
60  typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
61  typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;
62 
63  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType& expr, const Dims& dims)
64  : m_xpr(expr), m_dims(dims) {}
65 
66  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dims& dims() const { return m_dims; }
67 
68  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const internal::remove_all_t<typename XprType::Nested>& expression() const {
69  return m_xpr;
70  }
71 
72  protected:
73  typename XprType::Nested m_xpr;
74  const Dims m_dims;
75 };
76 
77 // Eval as rvalue
78 template <typename Dims, typename ArgType, typename Device>
79 struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device> {
80  typedef TensorTraceOp<Dims, ArgType> XprType;
81  static constexpr int NumInputDims =
82  internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
83  static constexpr int NumReducedDims = internal::array_size<Dims>::value;
84  static constexpr int NumOutputDims = NumInputDims - NumReducedDims;
85  typedef typename XprType::Index Index;
86  typedef DSizes<Index, NumOutputDims> Dimensions;
87  typedef typename XprType::Scalar Scalar;
88  typedef typename XprType::CoeffReturnType CoeffReturnType;
89  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90  static constexpr int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
91  typedef StorageMemory<CoeffReturnType, Device> Storage;
92  typedef typename Storage::Type EvaluatorPointerType;
93 
94  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
95  enum {
96  IsAligned = false,
97  PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
98  BlockAccess = false,
99  PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
100  CoordAccess = false,
101  RawAccess = false
102  };
103 
104  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
105  typedef internal::TensorBlockNotImplemented TensorBlock;
106  //===--------------------------------------------------------------------===//
107 
108  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
109  : m_impl(op.expression(), device), m_traceDim(1), m_device(device) {
110  EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
111  EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)),
112  YOU_MADE_A_PROGRAMMING_MISTAKE);
113 
114  for (int i = 0; i < NumInputDims; ++i) {
115  m_reduced[i] = false;
116  }
117 
118  const Dims& op_dims = op.dims();
119  for (int i = 0; i < NumReducedDims; ++i) {
120  eigen_assert(op_dims[i] >= 0);
121  eigen_assert(op_dims[i] < NumInputDims);
122  m_reduced[op_dims[i]] = true;
123  }
124 
125  // All the dimensions should be distinct to compute the trace
126  int num_distinct_reduce_dims = 0;
127  for (int i = 0; i < NumInputDims; ++i) {
128  if (m_reduced[i]) {
129  ++num_distinct_reduce_dims;
130  }
131  }
132 
133  EIGEN_ONLY_USED_FOR_DEBUG(num_distinct_reduce_dims);
134  eigen_assert(num_distinct_reduce_dims == NumReducedDims);
135 
136  // Compute the dimensions of the result.
137  const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
138 
139  int output_index = 0;
140  int reduced_index = 0;
141  for (int i = 0; i < NumInputDims; ++i) {
142  if (m_reduced[i]) {
143  m_reducedDims[reduced_index] = input_dims[i];
144  if (reduced_index > 0) {
145  // All the trace dimensions must have the same size
146  eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
147  }
148  ++reduced_index;
149  } else {
150  m_dimensions[output_index] = input_dims[i];
151  ++output_index;
152  }
153  }
154 
155  if (NumReducedDims != 0) {
156  m_traceDim = m_reducedDims[0];
157  }
158 
159  // Compute the output strides
160  if (NumOutputDims > 0) {
161  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
162  m_outputStrides[0] = 1;
163  for (int i = 1; i < NumOutputDims; ++i) {
164  m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
165  }
166  } else {
167  m_outputStrides.back() = 1;
168  for (int i = NumOutputDims - 2; i >= 0; --i) {
169  m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
170  }
171  }
172  }
173 
174  // Compute the input strides
175  if (NumInputDims > 0) {
176  array<Index, NumInputDims> input_strides;
177  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
178  input_strides[0] = 1;
179  for (int i = 1; i < NumInputDims; ++i) {
180  input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
181  }
182  } else {
183  input_strides.back() = 1;
184  for (int i = NumInputDims - 2; i >= 0; --i) {
185  input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
186  }
187  }
188 
189  output_index = 0;
190  reduced_index = 0;
191  for (int i = 0; i < NumInputDims; ++i) {
192  if (m_reduced[i]) {
193  m_reducedStrides[reduced_index] = input_strides[i];
194  ++reduced_index;
195  } else {
196  m_preservedStrides[output_index] = input_strides[i];
197  ++output_index;
198  }
199  }
200  }
201  }
202 
203  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
204 
205  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
206  m_impl.evalSubExprsIfNeeded(NULL);
207  return true;
208  }
209 
210  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; }
211 
212  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
213 
214  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
215  // Initialize the result
216  CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
217  Index index_stride = 0;
218  for (int i = 0; i < NumReducedDims; ++i) {
219  index_stride += m_reducedStrides[i];
220  }
221 
222  // If trace is requested along all dimensions, starting index would be 0
223  Index cur_index = 0;
224  if (NumOutputDims != 0) cur_index = firstInput(index);
225  for (Index i = 0; i < m_traceDim; ++i) {
226  result += m_impl.coeff(cur_index);
227  cur_index += index_stride;
228  }
229 
230  return result;
231  }
232 
233  template <int LoadMode>
234  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
235  eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
236 
237  EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
238  for (int i = 0; i < PacketSize; ++i) {
239  values[i] = coeff(index + i);
240  }
241  PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
242  return result;
243  }
244 
245  protected:
246  // Given the output index, finds the first index in the input tensor used to compute the trace
247  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
248  Index startInput = 0;
249  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
250  for (int i = NumOutputDims - 1; i > 0; --i) {
251  const Index idx = index / m_outputStrides[i];
252  startInput += idx * m_preservedStrides[i];
253  index -= idx * m_outputStrides[i];
254  }
255  startInput += index * m_preservedStrides[0];
256  } else {
257  for (int i = 0; i < NumOutputDims - 1; ++i) {
258  const Index idx = index / m_outputStrides[i];
259  startInput += idx * m_preservedStrides[i];
260  index -= idx * m_outputStrides[i];
261  }
262  startInput += index * m_preservedStrides[NumOutputDims - 1];
263  }
264  return startInput;
265  }
266 
267  Dimensions m_dimensions;
268  TensorEvaluator<ArgType, Device> m_impl;
269  // Initialize the size of the trace dimension
270  Index m_traceDim;
271  const Device EIGEN_DEVICE_REF m_device;
272  array<bool, NumInputDims> m_reduced;
273  array<Index, NumReducedDims> m_reducedDims;
274  array<Index, NumOutputDims> m_outputStrides;
275  array<Index, NumReducedDims> m_reducedStrides;
276  array<Index, NumOutputDims> m_preservedStrides;
277 };
278 
279 } // End namespace Eigen
280 
281 #endif // EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
const unsigned int LvalueBit
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index