$darkmode
Eigen-unsupported  5.0.1-dev
TensorExpr.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 template <typename NullaryOp, typename XprType>
20 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > : traits<XprType> {
21  typedef traits<XprType> XprTraits;
22  typedef typename XprType::Scalar Scalar;
23  typedef typename XprType::Nested XprTypeNested;
24  typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
25  static constexpr int NumDimensions = XprTraits::NumDimensions;
26  static constexpr int Layout = XprTraits::Layout;
27  typedef typename XprTraits::PointerType PointerType;
28  enum { Flags = 0 };
29 };
30 
31 } // end namespace internal
32 
41 template <typename NullaryOp, typename XprType>
42 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> {
43  public:
44  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
45  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
46  typedef typename XprType::CoeffReturnType CoeffReturnType;
48  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
49  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
50 
51  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
52  : m_xpr(xpr), m_functor(func) {}
53 
54  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& nestedExpression() const { return m_xpr; }
55 
56  EIGEN_DEVICE_FUNC const NullaryOp& functor() const { return m_functor; }
57 
58  protected:
59  typename XprType::Nested m_xpr;
60  const NullaryOp m_functor;
61 };
62 
63 namespace internal {
64 template <typename UnaryOp, typename XprType>
65 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > : traits<XprType> {
66  // TODO(phli): Add InputScalar, InputPacket. Check references to
67  // current Scalar/Packet to see if the intent is Input or Output.
68  typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
69  typedef traits<XprType> XprTraits;
70  typedef typename XprType::Nested XprTypeNested;
71  typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
72  static constexpr int NumDimensions = XprTraits::NumDimensions;
73  static constexpr int Layout = XprTraits::Layout;
74  typedef typename TypeConversion<Scalar, typename XprTraits::PointerType>::type PointerType;
75 };
76 
77 template <typename UnaryOp, typename XprType>
78 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense> {
79  typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
80 };
81 
82 template <typename UnaryOp, typename XprType>
83 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type> {
84  typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
85 };
86 
87 } // end namespace internal
88 
97 template <typename UnaryOp, typename XprType>
98 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> {
99  public:
100  // TODO(phli): Add InputScalar, InputPacket. Check references to
101  // current Scalar/Packet to see if the intent is Input or Output.
102  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
103  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
104  typedef Scalar CoeffReturnType;
105  typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
106  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
107  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
108 
109  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
110  : m_xpr(xpr), m_functor(func) {}
111 
112  EIGEN_DEVICE_FUNC const UnaryOp& functor() const { return m_functor; }
113 
115  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& nestedExpression() const { return m_xpr; }
116 
117  protected:
118  typename XprType::Nested m_xpr;
119  const UnaryOp m_functor;
120 };
121 
122 namespace internal {
123 template <typename BinaryOp, typename LhsXprType, typename RhsXprType>
124 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > {
125  // Type promotion to handle the case where the types of the lhs and the rhs
126  // are different.
127  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
128  // current Scalar/Packet to see if the intent is Inputs or Output.
129  typedef typename result_of<BinaryOp(typename LhsXprType::Scalar, typename RhsXprType::Scalar)>::type Scalar;
130  typedef traits<LhsXprType> XprTraits;
131  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
132  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
133  typedef
134  typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
135  typedef typename LhsXprType::Nested LhsNested;
136  typedef typename RhsXprType::Nested RhsNested;
137  typedef std::remove_reference_t<LhsNested> LhsNested_;
138  typedef std::remove_reference_t<RhsNested> RhsNested_;
139  static constexpr int NumDimensions = XprTraits::NumDimensions;
140  static constexpr int Layout = XprTraits::Layout;
141  typedef typename TypeConversion<Scalar,
142  std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
143  typename traits<LhsXprType>::PointerType,
144  typename traits<RhsXprType>::PointerType> >::type PointerType;
145  enum { Flags = 0 };
146 };
147 
148 template <typename BinaryOp, typename LhsXprType, typename RhsXprType>
149 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense> {
150  typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
151 };
152 
153 template <typename BinaryOp, typename LhsXprType, typename RhsXprType>
154 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1,
155  typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type> {
156  typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
157 };
158 
159 } // end namespace internal
160 
169 template <typename BinaryOp, typename LhsXprType, typename RhsXprType>
171  : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> {
172  public:
173  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
174  // current Scalar/Packet to see if the intent is Inputs or Output.
175  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
176  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
177  typedef Scalar CoeffReturnType;
178  typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
179  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
180  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
181 
182  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs,
183  const BinaryOp& func = BinaryOp())
184  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
185 
186  EIGEN_DEVICE_FUNC const BinaryOp& functor() const { return m_functor; }
187 
189  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
190  return m_lhs_xpr;
191  }
192 
193  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
194  return m_rhs_xpr;
195  }
196 
197  protected:
198  typename LhsXprType::Nested m_lhs_xpr;
199  typename RhsXprType::Nested m_rhs_xpr;
200  const BinaryOp m_functor;
201 };
202 
203 namespace internal {
204 template <typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
205 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > {
206  // Type promotion to handle the case where the types of the args are different.
207  typedef typename result_of<TernaryOp(typename Arg1XprType::Scalar, typename Arg2XprType::Scalar,
208  typename Arg3XprType::Scalar)>::type Scalar;
209  typedef traits<Arg1XprType> XprTraits;
210  typedef typename traits<Arg1XprType>::StorageKind StorageKind;
211  typedef typename traits<Arg1XprType>::Index Index;
212  typedef typename Arg1XprType::Nested Arg1Nested;
213  typedef typename Arg2XprType::Nested Arg2Nested;
214  typedef typename Arg3XprType::Nested Arg3Nested;
215  typedef std::remove_reference_t<Arg1Nested> Arg1Nested_;
216  typedef std::remove_reference_t<Arg2Nested> Arg2Nested_;
217  typedef std::remove_reference_t<Arg3Nested> Arg3Nested_;
218  static constexpr int NumDimensions = XprTraits::NumDimensions;
219  static constexpr int Layout = XprTraits::Layout;
220  typedef typename TypeConversion<Scalar,
221  std::conditional_t<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
222  typename traits<Arg2XprType>::PointerType,
223  typename traits<Arg3XprType>::PointerType> >::type PointerType;
224  enum { Flags = 0 };
225 };
226 
227 template <typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
228 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> {
229  typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
230 };
231 
232 template <typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
233 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1,
234  typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> {
235  typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
236 };
237 
238 } // end namespace internal
239 
240 template <typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
241 class TensorCwiseTernaryOp
242  : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> {
243  public:
244  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
245  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
246  typedef Scalar CoeffReturnType;
247  typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
248  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
249  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
250 
251  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2,
252  const Arg3XprType& arg3,
253  const TernaryOp& func = TernaryOp())
254  : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
255 
256  EIGEN_DEVICE_FUNC const TernaryOp& functor() const { return m_functor; }
257 
259  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename Arg1XprType::Nested>& arg1Expression() const {
260  return m_arg1_xpr;
261  }
262 
263  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename Arg2XprType::Nested>& arg2Expression() const {
264  return m_arg2_xpr;
265  }
266 
267  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename Arg3XprType::Nested>& arg3Expression() const {
268  return m_arg3_xpr;
269  }
270 
271  protected:
272  typename Arg1XprType::Nested m_arg1_xpr;
273  typename Arg2XprType::Nested m_arg2_xpr;
274  typename Arg3XprType::Nested m_arg3_xpr;
275  const TernaryOp m_functor;
276 };
277 
278 namespace internal {
279 template <typename IfXprType, typename ThenXprType, typename ElseXprType>
280 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > : traits<ThenXprType> {
281  typedef typename traits<ThenXprType>::Scalar Scalar;
282  typedef traits<ThenXprType> XprTraits;
283  typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
284  typename traits<ElseXprType>::StorageKind>::ret StorageKind;
285  typedef
286  typename promote_index_type<typename traits<ElseXprType>::Index, typename traits<ThenXprType>::Index>::type Index;
287  typedef typename IfXprType::Nested IfNested;
288  typedef typename ThenXprType::Nested ThenNested;
289  typedef typename ElseXprType::Nested ElseNested;
290  static constexpr int NumDimensions = XprTraits::NumDimensions;
291  static constexpr int Layout = XprTraits::Layout;
292  typedef std::conditional_t<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
293  typename traits<ThenXprType>::PointerType, typename traits<ElseXprType>::PointerType>
294  PointerType;
295 };
296 
297 template <typename IfXprType, typename ThenXprType, typename ElseXprType>
298 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> {
299  typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
300 };
301 
302 template <typename IfXprType, typename ThenXprType, typename ElseXprType>
303 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1,
304  typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> {
305  typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
306 };
307 
308 } // end namespace internal
309 
310 template <typename IfXprType, typename ThenXprType, typename ElseXprType>
311 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> {
312  public:
313  typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
314  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
315  typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
316  typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
317  typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
318  typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
319  typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
320 
321  EIGEN_DEVICE_FUNC TensorSelectOp(const IfXprType& a_condition, const ThenXprType& a_then, const ElseXprType& a_else)
322  : m_condition(a_condition), m_then(a_then), m_else(a_else) {}
323 
324  EIGEN_DEVICE_FUNC const IfXprType& ifExpression() const { return m_condition; }
325 
326  EIGEN_DEVICE_FUNC const ThenXprType& thenExpression() const { return m_then; }
327 
328  EIGEN_DEVICE_FUNC const ElseXprType& elseExpression() const { return m_else; }
329 
330  protected:
331  typename IfXprType::Nested m_condition;
332  typename ThenXprType::Nested m_then;
333  typename ElseXprType::Nested m_else;
334 };
335 
336 } // end namespace Eigen
337 
338 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
Tensor binary expression.
Definition: TensorExpr.h:170
Tensor unary expression.
Definition: TensorExpr.h:98
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor base class.
Definition: TensorForwardDeclarations.h:68
const internal::remove_all_t< typename LhsXprType::Nested > & lhsExpression() const
Definition: TensorExpr.h:189
Tensor nullary expression.
Definition: TensorExpr.h:42
const internal::remove_all_t< typename XprType::Nested > & nestedExpression() const
Definition: TensorExpr.h:115