$darkmode
Eigen  5.0.1-dev
InnerProduct.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2024 Charlie Schlosser <cs.schlosser@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_INNER_PRODUCT_EVAL_H
11 #define EIGEN_INNER_PRODUCT_EVAL_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 // recursively searches for the largest simd type that does not exceed Size, or the smallest if no such type exists
21 template <typename Scalar, int Size, typename Packet = typename packet_traits<Scalar>::type,
22  bool Stop =
23  (unpacket_traits<Packet>::size <= Size) || is_same<Packet, typename unpacket_traits<Packet>::half>::value>
24 struct find_inner_product_packet_helper;
25 
26 template <typename Scalar, int Size, typename Packet>
27 struct find_inner_product_packet_helper<Scalar, Size, Packet, false> {
28  using type = typename find_inner_product_packet_helper<Scalar, Size, typename unpacket_traits<Packet>::half>::type;
29 };
30 
31 template <typename Scalar, int Size, typename Packet>
32 struct find_inner_product_packet_helper<Scalar, Size, Packet, true> {
33  using type = Packet;
34 };
35 
36 template <typename Scalar, int Size>
37 struct find_inner_product_packet : find_inner_product_packet_helper<Scalar, Size> {};
38 
39 template <typename Scalar>
40 struct find_inner_product_packet<Scalar, Dynamic> {
41  using type = typename packet_traits<Scalar>::type;
42 };
43 
44 template <typename Lhs, typename Rhs>
45 struct inner_product_assert {
46  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs)
47  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs)
48  EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs)
49 #ifndef EIGEN_NO_DEBUG
50  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) {
51  eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size");
52  }
53 #else
54  static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {}
55 #endif
56 };
57 
58 template <typename Func, typename Lhs, typename Rhs>
59 struct inner_product_evaluator {
60  static constexpr int LhsFlags = evaluator<Lhs>::Flags;
61  static constexpr int RhsFlags = evaluator<Rhs>::Flags;
62  static constexpr int SizeAtCompileTime = size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime);
63  static constexpr int MaxSizeAtCompileTime =
64  min_size_prefer_fixed(Lhs::MaxSizeAtCompileTime, Rhs::MaxSizeAtCompileTime);
65  static constexpr int LhsAlignment = evaluator<Lhs>::Alignment;
66  static constexpr int RhsAlignment = evaluator<Rhs>::Alignment;
67 
68  using Scalar = typename Func::result_type;
69  using Packet = typename find_inner_product_packet<Scalar, SizeAtCompileTime>::type;
70 
71  static constexpr bool Vectorize =
72  bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess &&
73  ((MaxSizeAtCompileTime == Dynamic) || (unpacket_traits<Packet>::size <= MaxSizeAtCompileTime));
74 
75  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit inner_product_evaluator(const Lhs& lhs, const Rhs& rhs,
76  Func func = Func())
77  : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) {
78  inner_product_assert<Lhs, Rhs>::run(lhs, rhs);
79  }
80 
81  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_size.value(); }
82 
83  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const {
84  return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index));
85  }
86 
87  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& value, Index index) const {
88  return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index));
89  }
90 
91  template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
92  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const {
93  return m_func.packet(m_lhs.template packet<LhsMode, PacketType>(index),
94  m_rhs.template packet<RhsMode, PacketType>(index));
95  }
96 
97  template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
98  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(const PacketType& value, Index index) const {
99  return m_func.packet(value, m_lhs.template packet<LhsMode, PacketType>(index),
100  m_rhs.template packet<RhsMode, PacketType>(index));
101  }
102 
103  const Func m_func;
104  const evaluator<Lhs> m_lhs;
105  const evaluator<Rhs> m_rhs;
106  const variable_if_dynamic<Index, SizeAtCompileTime> m_size;
107 };
108 
109 template <typename Evaluator, bool Vectorize = Evaluator::Vectorize>
110 struct inner_product_impl;
111 
112 // scalar loop
113 template <typename Evaluator>
114 struct inner_product_impl<Evaluator, false> {
115  using Scalar = typename Evaluator::Scalar;
116  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
117  const Index size = eval.size();
118  if (size == 0) return Scalar(0);
119 
120  Scalar result = eval.coeff(0);
121  for (Index k = 1; k < size; k++) {
122  result = eval.coeff(result, k);
123  }
124 
125  return result;
126  }
127 };
128 
129 // vector loop
130 template <typename Evaluator>
131 struct inner_product_impl<Evaluator, true> {
132  using UnsignedIndex = std::make_unsigned_t<Index>;
133  using Scalar = typename Evaluator::Scalar;
134  using Packet = typename Evaluator::Packet;
135  static constexpr int PacketSize = unpacket_traits<Packet>::size;
136  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
137  const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
138  if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
139 
140  const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
141  const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
142  const UnsignedIndex numPackets = size / PacketSize;
143  const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
144 
145  Packet presult0, presult1, presult2, presult3;
146 
147  presult0 = eval.template packet<Packet>(0 * PacketSize);
148  if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
149  if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
150  if (numPackets >= 4) {
151  presult3 = eval.template packet<Packet>(3 * PacketSize);
152 
153  for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
154  presult0 = eval.packet(presult0, k + 0 * PacketSize);
155  presult1 = eval.packet(presult1, k + 1 * PacketSize);
156  presult2 = eval.packet(presult2, k + 2 * PacketSize);
157  presult3 = eval.packet(presult3, k + 3 * PacketSize);
158  }
159 
160  if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
161  if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
162  if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
163 
164  presult2 = padd(presult2, presult3);
165  }
166 
167  if (numPackets >= 3) presult1 = padd(presult1, presult2);
168  if (numPackets >= 2) presult0 = padd(presult0, presult1);
169 
170  Scalar result = predux(presult0);
171  for (UnsignedIndex k = packetEnd; k < size; k++) {
172  result = eval.coeff(result, k);
173  }
174 
175  return result;
176  }
177 };
178 
179 template <typename Scalar, bool Conj>
180 struct conditional_conj;
181 
182 template <typename Scalar>
183 struct conditional_conj<Scalar, true> {
184  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return numext::conj(a); }
185  template <typename Packet>
186  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
187  return pconj(a);
188  }
189 };
190 
191 template <typename Scalar>
192 struct conditional_conj<Scalar, false> {
193  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return a; }
194  template <typename Packet>
195  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
196  return a;
197  }
198 };
199 
200 template <typename LhsScalar, typename RhsScalar, bool Conj>
201 struct scalar_inner_product_op {
202  using result_type = typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;
203  using conj_helper = conditional_conj<LhsScalar, Conj>;
204  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const {
205  return (conj_helper::coeff(a) * b);
206  }
207  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type& accum, const LhsScalar& a,
208  const RhsScalar& b) const {
209  return (conj_helper::coeff(a) * b) + accum;
210  }
211  static constexpr bool PacketAccess = false;
212 };
213 
214 // Partial specialization for packet access if and only if
215 // LhsScalar == RhsScalar == ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType.
216 template <typename Scalar, bool Conj>
217 struct scalar_inner_product_op<
218  Scalar,
219  typename std::enable_if<internal::is_same<typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType, Scalar>::value,
220  Scalar>::type,
221  Conj> {
222  using result_type = Scalar;
223  using conj_helper = conditional_conj<Scalar, Conj>;
224  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a, const Scalar& b) const {
225  return pmul(conj_helper::coeff(a), b);
226  }
227  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const {
228  return pmadd(conj_helper::coeff(a), b, accum);
229  }
230  template <typename Packet>
231  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a, const Packet& b) const {
232  return pmul(conj_helper::packet(a), b);
233  }
234  template <typename Packet>
235  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const {
236  return pmadd(conj_helper::packet(a), b, accum);
237  }
238  static constexpr bool PacketAccess = packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasAdd;
239 };
240 
241 template <typename Lhs, typename Rhs, bool Conj>
242 struct default_inner_product_impl {
243  using LhsScalar = typename traits<Lhs>::Scalar;
244  using RhsScalar = typename traits<Rhs>::Scalar;
245  using Op = scalar_inner_product_op<LhsScalar, RhsScalar, Conj>;
246  using Evaluator = inner_product_evaluator<Op, Lhs, Rhs>;
247  using result_type = typename Evaluator::Scalar;
248  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase<Lhs>& a, const MatrixBase<Rhs>& b) {
249  Evaluator eval(a.derived(), b.derived(), Op());
250  return inner_product_impl<Evaluator>::run(eval);
251  }
252 };
253 
254 template <typename Lhs, typename Rhs>
255 struct dot_impl : default_inner_product_impl<Lhs, Rhs, true> {};
256 
257 } // namespace internal
258 } // namespace Eigen
259 
260 #endif // EIGEN_INNER_PRODUCT_EVAL_H
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Definition: BFloat16.h:231
const unsigned int PacketAccessBit
Definition: Constants.h:97
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const int Dynamic
Definition: Constants.h:25