$darkmode
Eigen  5.0.1-dev
IndexedView.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_INDEXED_VIEW_H
11 #define EIGEN_INDEXED_VIEW_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 template <typename XprType, typename RowIndices, typename ColIndices>
21 struct traits<IndexedView<XprType, RowIndices, ColIndices>> : traits<XprType> {
22  enum {
23  RowsAtCompileTime = int(IndexedViewHelper<RowIndices>::SizeAtCompileTime),
24  ColsAtCompileTime = int(IndexedViewHelper<ColIndices>::SizeAtCompileTime),
25  MaxRowsAtCompileTime = RowsAtCompileTime,
26  MaxColsAtCompileTime = ColsAtCompileTime,
27 
28  XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
29  IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1
30  : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0
31  : XprTypeIsRowMajor,
32 
33  RowIncr = int(IndexedViewHelper<RowIndices>::IncrAtCompileTime),
34  ColIncr = int(IndexedViewHelper<ColIndices>::IncrAtCompileTime),
35  InnerIncr = IsRowMajor ? ColIncr : RowIncr,
36  OuterIncr = IsRowMajor ? RowIncr : ColIncr,
37 
38  HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
39  XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret)
40  : int(outer_stride_at_compile_time<XprType>::ret),
41  XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret)
42  : int(inner_stride_at_compile_time<XprType>::ret),
43 
44  InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
45  IsBlockAlike = InnerIncr == 1 && OuterIncr == 1,
46  IsInnerPannel = HasSameStorageOrderAsXprType &&
47  is_same<AllRange<InnerSize>, std::conditional_t<XprTypeIsRowMajor, ColIndices, RowIndices>>::value,
48 
49  InnerStrideAtCompileTime =
50  InnerIncr < 0 || InnerIncr == DynamicIndex || XprInnerStride == Dynamic || InnerIncr == Undefined
51  ? Dynamic
52  : XprInnerStride * InnerIncr,
53  OuterStrideAtCompileTime =
54  OuterIncr < 0 || OuterIncr == DynamicIndex || XprOuterstride == Dynamic || OuterIncr == Undefined
55  ? Dynamic
56  : XprOuterstride * OuterIncr,
57 
58  ReturnAsScalar = is_single_range<RowIndices>::value && is_single_range<ColIndices>::value,
59  ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
60  ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),
61 
62  // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
63  // but this is too strict regarding negative strides...
64  DirectAccessMask = (int(InnerIncr) != Undefined && int(OuterIncr) != Undefined && InnerIncr >= 0 && OuterIncr >= 0)
66  : 0,
67  FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
68  FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
69  FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
70  Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask)) | FlagsLvalueBit | FlagsRowMajorBit |
71  FlagsLinearAccessBit
72  };
73 
74  typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
75 };
76 
// Forward declaration of the implementation class of IndexedView. The trailing bool parameter
// selects between the generic implementation and the specialization for `true`, which adds
// increment/stride/data() accessors; IndexedView passes `(Flags & DirectAccessBit) != 0` for it.
77 template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
78 class IndexedViewImpl;
79 
80 } // namespace internal
81 
120 template <typename XprType, typename RowIndices, typename ColIndices>
122  : public internal::IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
123  (internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags &
124  DirectAccessBit) != 0> {
125  public:
126  typedef typename internal::IndexedViewImpl<
127  XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
128  (internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0>
129  Base;
130  EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
131  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
132 
133  template <typename T0, typename T1>
134  IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
135 };
136 
137 namespace internal {
138 
139 // Generic API dispatcher
140 template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
141 class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
142  public:
143  typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
144  typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
145  typedef internal::remove_all_t<XprType> NestedExpression;
146  typedef typename XprType::Scalar Scalar;
147 
148  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
149 
150  template <typename T0, typename T1>
151  IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices)
152  : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}
153 
155  Index rows() const { return IndexedViewHelper<RowIndices>::size(m_rowIndices); }
156 
158  Index cols() const { return IndexedViewHelper<ColIndices>::size(m_colIndices); }
159 
161  const internal::remove_all_t<XprType>& nestedExpression() const { return m_xpr; }
162 
164  std::remove_reference_t<XprType>& nestedExpression() { return m_xpr; }
165 
167  const RowIndices& rowIndices() const { return m_rowIndices; }
168 
170  const ColIndices& colIndices() const { return m_colIndices; }
171 
172  constexpr Scalar& coeffRef(Index rowId, Index colId) {
173  return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
174  }
175 
176  constexpr const Scalar& coeffRef(Index rowId, Index colId) const {
177  return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
178  }
179 
180  protected:
181  MatrixTypeNested m_xpr;
182  RowIndices m_rowIndices;
183  ColIndices m_colIndices;
184 };
185 
186 template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
187 class IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, true>
188  : public IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false> {
189  public:
190  using Base = internal::IndexedViewImpl<XprType, RowIndices, ColIndices,
191  typename internal::traits<XprType>::StorageKind, false>;
192  using Derived = IndexedView<XprType, RowIndices, ColIndices>;
193 
194  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
195 
196  template <typename T0, typename T1>
197  IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
198 
199  Index rowIncrement() const {
200  if (traits<Derived>::RowIncr != DynamicIndex && traits<Derived>::RowIncr != Undefined) {
201  return traits<Derived>::RowIncr;
202  }
203  return IndexedViewHelper<RowIndices>::incr(this->rowIndices());
204  }
205  Index colIncrement() const {
206  if (traits<Derived>::ColIncr != DynamicIndex && traits<Derived>::ColIncr != Undefined) {
207  return traits<Derived>::ColIncr;
208  }
209  return IndexedViewHelper<ColIndices>::incr(this->colIndices());
210  }
211 
212  Index innerIncrement() const { return traits<Derived>::IsRowMajor ? colIncrement() : rowIncrement(); }
213 
214  Index outerIncrement() const { return traits<Derived>::IsRowMajor ? rowIncrement() : colIncrement(); }
215 
216  std::decay_t<typename XprType::Scalar>* data() {
217  Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
218  Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
219  return this->nestedExpression().data() + row_offset + col_offset;
220  }
221 
222  const std::decay_t<typename XprType::Scalar>* data() const {
223  Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
224  Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
225  return this->nestedExpression().data() + row_offset + col_offset;
226  }
227 
228  EIGEN_DEVICE_FUNC constexpr Index innerStride() const noexcept {
229  if (traits<Derived>::InnerStrideAtCompileTime != Dynamic) {
230  return traits<Derived>::InnerStrideAtCompileTime;
231  }
232  return innerIncrement() * this->nestedExpression().innerStride();
233  }
234 
235  EIGEN_DEVICE_FUNC constexpr Index outerStride() const noexcept {
236  if (traits<Derived>::OuterStrideAtCompileTime != Dynamic) {
237  return traits<Derived>::OuterStrideAtCompileTime;
238  }
239  return outerIncrement() * this->nestedExpression().outerStride();
240  }
241 };
242 
243 template <typename ArgType, typename RowIndices, typename ColIndices>
244 struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
245  : evaluator_base<IndexedView<ArgType, RowIndices, ColIndices>> {
246  typedef IndexedView<ArgType, RowIndices, ColIndices> XprType;
247 
248  enum {
249  CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,
250 
251  FlagsLinearAccessBit =
252  (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
253 
254  FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,
255 
256  Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) |
257  FlagsLinearAccessBit | FlagsRowMajorBit,
258 
259  Alignment = 0
260  };
261 
262  EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr) {
263  EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
264  }
265 
266  typedef typename XprType::Scalar Scalar;
267  typedef typename XprType::CoeffReturnType CoeffReturnType;
268 
269  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
270  eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
271  m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
272  return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
273  }
274 
275  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
276  eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
277  m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
278  return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
279  }
280 
281  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
282  EIGEN_STATIC_ASSERT_LVALUE(XprType)
283  Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
284  Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
285  eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
286  m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
287  return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
288  }
289 
290  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeffRef(Index index) const {
291  Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
292  Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
293  eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
294  m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
295  return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
296  }
297 
298  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index index) const {
299  Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
300  Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
301  eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
302  m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
303  return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
304  }
305 
306  protected:
307  evaluator<ArgType> m_argImpl;
308  const XprType& m_xpr;
309 };
310 
311 // Catch assignments to an IndexedView.
312 template <typename ArgType, typename RowIndices, typename ColIndices>
313 struct evaluator_assume_aliasing<IndexedView<ArgType, RowIndices, ColIndices>> {
314  static const bool value = true;
315 };
316 
317 } // end namespace internal
318 
319 } // end namespace Eigen
320 
321 #endif // EIGEN_INDEXED_VIEW_H
const unsigned int DirectAccessBit
Definition: Constants.h:159
const unsigned int LvalueBit
Definition: Constants.h:148
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
const int DynamicIndex
Definition: Constants.h:30
Definition: BFloat16.h:231
const unsigned int RowMajorBit
Definition: Constants.h:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const int Undefined
Definition: Constants.h:34
Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices...
Definition: IndexedView.h:121
const int Dynamic
Definition: Constants.h:25
const unsigned int LinearAccessBit
Definition: Constants.h:133