$darkmode
Eigen-unsupported  5.0.1-dev
TensorDevice.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
29 template <typename ExpressionType, typename DeviceType>
30 class TensorDevice {
31  public:
32  TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
33 
34  EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorDevice)
35 
36  template <typename OtherDerived>
37  EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
39  Assign assign(m_expression, other);
41  return *this;
42  }
43 
44  template <typename OtherDerived>
45  EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
46  typedef typename OtherDerived::Scalar Scalar;
47  typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
48  Sum sum(m_expression, other);
50  Assign assign(m_expression, sum);
52  return *this;
53  }
54 
55  template <typename OtherDerived>
56  EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
57  typedef typename OtherDerived::Scalar Scalar;
58  typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived>
59  Difference;
60  Difference difference(m_expression, other);
62  Assign assign(m_expression, difference);
64  return *this;
65  }
66 
67  protected:
68  const DeviceType& m_device;
69  ExpressionType& m_expression;
70 };
71 
86 template <typename ExpressionType, typename DeviceType, typename DoneCallback>
88  public:
89  TensorAsyncDevice(const DeviceType& device, ExpressionType& expression, DoneCallback done)
90  : m_device(device), m_expression(expression), m_done(std::move(done)) {}
91 
92  template <typename OtherDerived>
93  EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
96 
97  Assign assign(m_expression, other);
98  Executor::run(assign, m_device);
99  m_done();
100 
101  return *this;
102  }
103 
104  protected:
105  const DeviceType& m_device;
106  ExpressionType& m_expression;
107  DoneCallback m_done;
108 };
109 
110 #ifdef EIGEN_USE_THREADS
111 template <typename ExpressionType, typename DoneCallback>
112 class TensorAsyncDevice<ExpressionType, ThreadPoolDevice, DoneCallback> {
113  public:
114  TensorAsyncDevice(const ThreadPoolDevice& device, ExpressionType& expression, DoneCallback done)
115  : m_device(device), m_expression(expression), m_done(std::move(done)) {}
116 
117  template <typename OtherDerived>
118  EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
119  typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
120  typedef internal::TensorAsyncExecutor<const Assign, ThreadPoolDevice, DoneCallback> Executor;
121 
122  // WARNING: After assignment 'm_done' callback will be in undefined state.
123  Assign assign(m_expression, other);
124  Executor::runAsync(assign, m_device, std::move(m_done));
125 
126  return *this;
127  }
128 
129  protected:
130  const ThreadPoolDevice& m_device;
131  ExpressionType& m_expression;
132  DoneCallback m_done;
133 };
134 #endif
135 
136 } // end namespace Eigen
137 
138 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
The tensor executor class.
Definition: TensorExecutor.h:76
Tensor binary expression.
Definition: TensorExpr.h:170
Namespace containing all symbols from the Eigen library.
Pseudo expression providing an operator = that will evaluate its argument on the specified computing device.
Definition: TensorDevice.h:30
Definition: AutoDiffScalar.h:629
Definition: TensorAssign.h:55
Pseudo expression providing an operator = that will evaluate its argument asynchronously on the specified computing device.
Definition: TensorDevice.h:87