$darkmode
Eigen  5.0.1-dev
BFloat16.h
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef EIGEN_BFLOAT16_H
17 #define EIGEN_BFLOAT16_H
18 
19 // IWYU pragma: private
20 #include "../../InternalHeaderCheck.h"
21 
#ifdef EIGEN_HAS_HIP_BF16
// With HIP bf16 support, the hip_bfloat16 base type and several helper
// routines come from the GPU compiler's own header (hip_bfloat16.h), and
// those are not declared constexpr. Keeping Eigen's wrappers constexpr would
// therefore fail to compile, so EIGEN_CONSTEXPR is temporarily defined as
// empty for this header. The saved definition is restored with pop_macro
// further down in this file.
#pragma push_macro("EIGEN_CONSTEXPR")
#undef EIGEN_CONSTEXPR
#define EIGEN_CONSTEXPR
#endif  // EIGEN_HAS_HIP_BF16
33 
// Defines the specialization METHOD<PACKET_BF16> in terms of the float packet
// version: the bf16 packet is widened with Bf16ToF32, METHOD<PACKET_F> is
// applied, and the result is narrowed back with F32ToBf16.
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)                    \
  template <>                                                                  \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 \
  METHOD<PACKET_BF16>(const PACKET_BF16& bf16_input) {                         \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(bf16_input)));                 \
  }
40 
// HIP's native bf16 type is only used while compiling device code (kernels);
// otherwise this header supplies its own __bfloat16_raw representation.
#ifdef EIGEN_HAS_HIP_BF16
#ifdef EIGEN_GPU_COMPILE_PHASE
#define EIGEN_USE_HIP_BF16
#endif  // EIGEN_GPU_COMPILE_PHASE
#endif  // EIGEN_HAS_HIP_BF16
45 
46 namespace Eigen {
47 
// Forward declaration; the full definition of bfloat16 follows below.
struct bfloat16;

namespace numext {
// Forward declarations of the bit_cast specializations between bfloat16 and
// its raw 16-bit pattern. They are used by bit-level operations in this header
// (e.g. negation and abs) before their definitions appear later in the file.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src);

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src);
}  // namespace numext
namespace bfloat16_impl {

#if defined(EIGEN_USE_HIP_BF16)

// HIP device compilation: reuse the compiler-provided hip_bfloat16 as the raw
// storage type (its bit pattern is accessed through the `data` member).
struct __bfloat16_raw : public hip_bfloat16 {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : hip_bfloat16(raw) {}
};

#else

// Make our own __bfloat16_raw definition: a plain holder for the 16-bit
// pattern of a bfloat16, stored in `value`.
struct __bfloat16_raw {
#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
  // Host-side pass of a HIP build: the default constructor leaves `value`
  // uninitialized.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
#else
  // Default construction yields the bit pattern 0, i.e. +0.0.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
#endif
  // Wraps an existing bit pattern; no numeric conversion is performed.
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
  unsigned short value;
};

#endif  // defined(EIGEN_USE_HIP_BF16)

// Forward declarations of the conversion helpers defined later in this file.

// Reinterprets a 16-bit pattern as a bfloat16 raw value.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
// Rounds a float to bfloat16 (round-to-nearest-even). The `true`
// specialization skips the NaN check and is used when the caller knows the
// argument cannot be NaN (the bfloat16 constructor selects it for integral
// source types).
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
// Widens a bfloat16 raw value to float.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);

// Base class of bfloat16: lets bfloat16 be constructed from a __bfloat16_raw.
struct bfloat16_base : public __bfloat16_raw {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
};

}  // namespace bfloat16_impl
99 
100 // Class definition.
101 struct bfloat16 : public bfloat16_impl::bfloat16_base {
102  typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
103 
104  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
105 
106  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
107 
108  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
109  : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
110 
111  template <class T>
112  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
113  : bfloat16_impl::bfloat16_base(
114  bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
115 
116  explicit EIGEN_DEVICE_FUNC bfloat16(float f)
117  : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
118 
119  // Following the convention of numpy, converting between complex and
120  // float will lead to loss of imag value.
121  template <typename RealScalar>
122  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
123  : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
124 
125  EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
126  return bfloat16_impl::bfloat16_to_float(*this);
127  }
128 };
129 
// TODO(majnemer): Get rid of this once we can rely on C++17 inline variables to
// solve the ODR issue.
namespace bfloat16_impl {
// std::numeric_limits-style description of bfloat16; the
// std::numeric_limits<Eigen::bfloat16> specialization derives from it. The
// unused template parameter makes this a class template, so the out-of-class
// static member definitions below can live in a header without producing
// multiple definitions (the ODR issue mentioned in the TODO above).
template <typename = void>
struct numeric_limits_bfloat16_impl {
  static EIGEN_CONSTEXPR const bool is_specialized = true;
  static EIGEN_CONSTEXPR const bool is_signed = true;
  static EIGEN_CONSTEXPR const bool is_integer = false;
  static EIGEN_CONSTEXPR const bool is_exact = false;
  static EIGEN_CONSTEXPR const bool has_infinity = true;
  static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
  static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
  // has_denorm / has_denorm_loss are deprecated (C++23); silence the
  // deprecation warning around their declarations.
  EIGEN_DIAGNOSTICS(push)
  EIGEN_DISABLE_DEPRECATED_WARNING
  static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
  static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
  EIGEN_DIAGNOSTICS(pop)
  static EIGEN_CONSTEXPR const std::float_round_style round_style = std::numeric_limits<float>::round_style;
  static EIGEN_CONSTEXPR const bool is_iec559 = true;
  // The C++ standard defines this as "true if the set of values representable
  // by the type is finite." BFloat16 has finite precision.
  static EIGEN_CONSTEXPR const bool is_bounded = true;
  static EIGEN_CONSTEXPR const bool is_modulo = false;
  // 8 significand bits: 7 stored fraction bits plus the implicit leading bit.
  static EIGEN_CONSTEXPR const int digits = 8;
  static EIGEN_CONSTEXPR const int digits10 = 2;
  static EIGEN_CONSTEXPR const int max_digits10 = 4;
  // bfloat16 shares float's 8-bit exponent, so radix and exponent range are
  // taken from float.
  static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
  static EIGEN_CONSTEXPR const int min_exponent = std::numeric_limits<float>::min_exponent;
  static EIGEN_CONSTEXPR const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
  static EIGEN_CONSTEXPR const int max_exponent = std::numeric_limits<float>::max_exponent;
  static EIGEN_CONSTEXPR const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
  static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
  // IEEE754: "The implementer shall choose how tininess is detected, but shall
  // detect tininess in the same way for all operations in radix two"
  static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;

  // The special values below are given as raw bit patterns.
  // Smallest positive normal value: exponent field 1, fraction 0 (2^-126).
  static EIGEN_CONSTEXPR Eigen::bfloat16(min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
  // Most negative finite value (the negation of max()).
  static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
  // Largest finite value: exponent field 0xfe, all fraction bits set.
  static EIGEN_CONSTEXPR Eigen::bfloat16(max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
  // 2^-7: the gap between 1.0 and the next representable value.
  static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
  // 0.5
  static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
  // +infinity: all exponent bits set, fraction 0.
  static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
  // Quiet NaN: all exponent bits set, most significant fraction bit set.
  static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
  // Signaling NaN: all exponent bits set, most significant fraction bit clear,
  // another fraction bit set.
  static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
    return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
  }
  // Smallest positive subnormal value (2^-133).
  static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
};

// Out-of-class definitions of the static data members above. Before C++17
// these are needed whenever a member is odr-used.
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_specialized;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_signed;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_integer;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_exact;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_infinity;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
template <typename T>
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
EIGEN_DIAGNOSTICS(pop)
template <typename T>
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_iec559;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_bounded;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_modulo;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::radix;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent10;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::traps;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
}  // end namespace bfloat16_impl
229 } // end namespace Eigen
230 
namespace std {
// If std::numeric_limits<T> is specialized, std::numeric_limits<const T>,
// std::numeric_limits<volatile T>, and std::numeric_limits<const volatile T>
// should be specialized as well so that they report the same values.
// Here each cv-qualified variant simply derives from the unqualified one.
// https://stackoverflow.com/a/16519653/
template <>
class numeric_limits<Eigen::bfloat16> : public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
template <>
class numeric_limits<const Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
template <>
class numeric_limits<volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
template <>
class numeric_limits<const volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
}  // end namespace std
245 
246 namespace Eigen {
247 
248 namespace bfloat16_impl {
249 
250 // We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
251 // invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
252 // of the functions, while the latter can only deal with one of them.
253 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
254 
255 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
256 // We need to provide emulated *host-side* BF16 operators for clang.
257 #pragma push_macro("EIGEN_DEVICE_FUNC")
258 #undef EIGEN_DEVICE_FUNC
259 #if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
260 #define EIGEN_DEVICE_FUNC __host__
261 #else // both host and device need emulated ops.
262 #define EIGEN_DEVICE_FUNC __host__ __device__
263 #endif
264 #endif
265 
266 // Definitions for CPUs, mostly working through conversion
267 // to/from fp32.
268 
269 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
270  return bfloat16(float(a) + float(b));
271 }
272 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const int& b) {
273  return bfloat16(float(a) + static_cast<float>(b));
274 }
275 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const int& a, const bfloat16& b) {
276  return bfloat16(static_cast<float>(a) + float(b));
277 }
278 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
279  return bfloat16(float(a) * float(b));
280 }
281 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
282  return bfloat16(float(a) - float(b));
283 }
284 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
285  return bfloat16(float(a) / float(b));
286 }
287 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a) {
288  numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
289  return numext::bit_cast<bfloat16>(x);
290 }
291 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a, const bfloat16& b) {
292  a = bfloat16(float(a) + float(b));
293  return a;
294 }
295 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a, const bfloat16& b) {
296  a = bfloat16(float(a) * float(b));
297  return a;
298 }
299 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a, const bfloat16& b) {
300  a = bfloat16(float(a) - float(b));
301  return a;
302 }
303 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a, const bfloat16& b) {
304  a = bfloat16(float(a) / float(b));
305  return a;
306 }
307 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
308  a += bfloat16(1);
309  return a;
310 }
311 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
312  a -= bfloat16(1);
313  return a;
314 }
315 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
316  bfloat16 original_value = a;
317  ++a;
318  return original_value;
319 }
320 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
321  bfloat16 original_value = a;
322  --a;
323  return original_value;
324 }
325 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16& a, const bfloat16& b) {
326  return numext::equal_strict(float(a), float(b));
327 }
328 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const bfloat16& a, const bfloat16& b) {
329  return numext::not_equal_strict(float(a), float(b));
330 }
331 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const bfloat16& a, const bfloat16& b) {
332  return float(a) < float(b);
333 }
334 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16& a, const bfloat16& b) {
335  return float(a) <= float(b);
336 }
337 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16& a, const bfloat16& b) {
338  return float(a) > float(b);
339 }
340 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const bfloat16& a, const bfloat16& b) {
341  return float(a) >= float(b);
342 }
343 
344 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
345 #pragma pop_macro("EIGEN_DEVICE_FUNC")
346 #endif
347 #endif // Emulate support for bfloat16 floats
348 
349 // Division by an index. Do it in full float precision to avoid accuracy
350 // issues in converting the denominator to bfloat16.
351 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, Index b) {
352  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
353 }
354 
355 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
356 #if defined(EIGEN_USE_HIP_BF16)
357  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
358 #else
359  __bfloat16_raw output;
360  if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
361  output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
362  return output;
363  }
364  output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
365  return output;
366 #endif
367 }
368 
369 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
370 #if defined(EIGEN_USE_HIP_BF16)
371  __bfloat16_raw bf;
372  bf.data = value;
373  return bf;
374 #else
375  return __bfloat16_raw(value);
376 #endif
377 }
378 
379 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
380  const __bfloat16_raw& bf) {
381 #if defined(EIGEN_USE_HIP_BF16)
382  return bf.data;
383 #else
384  return bf.value;
385 #endif
386 }
387 
// float_to_bfloat16_rtne template specialization that does not make any
// assumption about the value of its function argument (ff).
//
// Rounds ff to the nearest bfloat16, breaking ties to even. A NaN input yields
// the canonical quiet NaN 0x7FC0 (0xFFC0 if ff's sign bit is set); every other
// input is rounded by float_to_bfloat16_rtne<true>.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
#if defined(EIGEN_USE_HIP_BF16)
  // HIP device compilation: use the rounding routine that __bfloat16_raw
  // inherits from hip_bfloat16.
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
#else
  __bfloat16_raw output;

  if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
    // If the value is a NaN, squash it to a qNaN with the msb of the fraction
    // set. This makes sure that truncation cannot turn it into an inf (a NaN
    // whose payload lives only in the low 16 bits would otherwise end up with
    // an all-zero fraction).
    //
    // qNaN magic: All exponent bits set + most significant bit of fraction
    // set.
    output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
  } else {
    // Fast rounding algorithm that rounds a half value to nearest even. This
    // reduces expected error when we convert a large number of floats. Here
    // is how it works:
    //
    // Definitions:
    // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
    // with the following tags:
    //
    // Sign |  Exp (8 bits) | Frac (23 bits)
    //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
    //
    //  S: Sign bit.
    //  E: Exponent bits.
    //  F: First 6 bits of fraction.
    //  L: Least significant bit of resulting bfloat16 if we truncate away the
    //     rest of the float32. This is also the 7th bit of fraction.
    //  R: Rounding bit, 8th bit of fraction.
    //  T: Sticky bits, rest of fraction, 15 bits.
    //
    // To round half to nearest even, there are 3 cases where we want to round
    // down (simply truncate away the rounding bit and the sticky bits) and two
    // cases where we want to round up (truncate, then add one to the result).
    //
    // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
    // 1s) as the rounding bias, adds the rounding bias to the input, then
    // truncates the last 16 bits away.
    //
    // To understand how it works, we can analyze this algorithm case by case:
    //
    // 1. L = 0, R = 0:
    //   Expect: round down, this is less than half value.
    //
    //   Algorithm:
    //   - Rounding bias: 0x7fff + 0 = 0x7fff
    //   - Adding the rounding bias to the input carries into R if and only
    //     if at least one of the T bits is 1.
    //   - R may therefore become 1, but no carry reaches L.
    //   - L remains 0.
    //   - Note that this case also handles Inf and -Inf, where all fraction
    //     bits, including L, R and Ts are all 0. The output remains Inf after
    //     this algorithm.
    //
    // 2. L = 1, R = 0:
    //   Expect: round down, this is less than half value.
    //
    //   Algorithm:
    //   - Rounding bias: 0x7fff + 1 = 0x8000
    //   - Adding rounding bias to input doesn't change sticky bits but
    //     adds 1 to rounding bit.
    //   - L remains 1.
    //
    // 3. L = 0, R = 1, all of T are 0:
    //   Expect: round down, this is exactly at half, the result is already
    //     even (L=0).
    //
    //   Algorithm:
    //   - Rounding bias: 0x7fff + 0 = 0x7fff
    //   - Adding rounding bias to input sets all sticky bits to 1, but
    //     doesn't create a carry.
    //   - R remains 1.
    //   - L remains 0.
    //
    // 4. L = 1, R = 1:
    //   Expect: round up. If all T bits are 0 the value is exactly at half
    //     and the truncated result is odd (L=1), so rounding up makes it
    //     even; otherwise the value is more than half.
    //
    //   Algorithm:
    //   - Rounding bias: 0x7fff + 1 = 0x8000
    //   - Adding rounding bias to input doesn't change sticky bits, but
    //     creates a carry from rounding bit.
    //   - The carry sets L to 0, creates another carry bit and propagate
    //     forward to F bits.
    //   - If all the F bits are 1, a carry then propagates to the exponent
    //     bits, which then creates the minimum value with the next exponent
    //     value. Note that we won't have the case where exponents are all 1,
    //     since that's either a NaN (handled in the other if condition) or inf
    //     (handled in case 1).
    //
    // 5. L = 0, R = 1, any of T is 1:
    //   Expect: round up, this is greater than half.
    //
    //   Algorithm:
    //   - Rounding bias: 0x7fff + 0 = 0x7fff
    //   - Adding rounding bias to input creates a carry from sticky bits,
    //     sets rounding bit to 0, then create another carry.
    //   - The second carry sets L to 1.
    //
    // Examples:
    //
    //  Exact half value that is already even:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
    //
    //     This falls into case 3. We truncate the rest of 16 bits and no
    //     carry is created into F and L:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    //
    //  Exact half value, round to next even number:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
    //
    //     This falls into case 4. We create a carry from R and T,
    //     which then propagates into L and F:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    //
    //
    //  Max denormal value round to min normal value:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
    //
    //     This falls into case 4. We create a carry from R and T,
    //     propagate into L and F, which then propagates into exponent
    //     bits:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
    //
    //  Max normal value round to Inf:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
    //
    //     This falls into case 4. We create a carry from R and T,
    //     propagate into L and F, which then propagates into exponent
    //     bits:
    //
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0

    // At this point ff is not a NaN: it may be +/-zero, a subnormal, a normal
    // float, or +/-infinity. The bias-and-truncate rounding done by
    // float_to_bfloat16_rtne<true> is correct for all of these (see the "max
    // denormal" example above), even though that specialization's template
    // parameter name only mentions normal values, infinity and zero.
    output = float_to_bfloat16_rtne<true>(ff);
  }
  return output;
#endif
}
559 
560 // float_to_bfloat16_rtne template specialization that assumes that its function
561 // argument (ff) is either a normal floating point number, or +/-infinity, or
562 // zero. Used to improve the runtime performance of conversion from an integer
563 // type to bfloat16.
564 template <>
565 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
566 #if defined(EIGEN_USE_HIP_BF16)
567  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
568 #else
569  numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
570  __bfloat16_raw output;
571 
572  // Least significant bit of resulting bfloat.
573  numext::uint32_t lsb = (input >> 16) & 1;
574  numext::uint32_t rounding_bias = 0x7fff + lsb;
575  input += rounding_bias;
576  output.value = static_cast<numext::uint16_t>(input >> 16);
577  return output;
578 #endif
579 }
580 
581 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
582 #if defined(EIGEN_USE_HIP_BF16)
583  return static_cast<float>(h);
584 #else
585  return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
586 #endif
587 }
588 
589 // --- standard functions ---
590 
591 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const bfloat16& a) {
592  EIGEN_USING_STD(isinf);
593 #if defined(EIGEN_USE_HIP_BF16)
594  return (isinf)(a); // Uses HIP hip_bfloat16 isinf operator
595 #else
596  return (isinf)(float(a));
597 #endif
598 }
599 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const bfloat16& a) {
600  EIGEN_USING_STD(isnan);
601 #if defined(EIGEN_USE_HIP_BF16)
602  return (isnan)(a); // Uses HIP hip_bfloat16 isnan operator
603 #else
604  return (isnan)(float(a));
605 #endif
606 }
607 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const bfloat16& a) {
608  return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
609 }
610 
611 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
612  numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
613  return numext::bit_cast<bfloat16>(x);
614 }
615 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { return bfloat16(::expf(float(a))); }
616 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp2(const bfloat16& a) { return bfloat16(::exp2f(float(a))); }
617 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { return bfloat16(numext::expm1(float(a))); }
618 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { return bfloat16(::logf(float(a))); }
619 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { return bfloat16(numext::log1p(float(a))); }
620 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { return bfloat16(::log10f(float(a))); }
621 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
622  return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
623 }
624 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { return bfloat16(::sqrtf(float(a))); }
625 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
626  return bfloat16(::powf(float(a), float(b)));
627 }
628 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
629  return bfloat16(::atan2f(float(a), float(b)));
630 }
631 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { return bfloat16(::sinf(float(a))); }
632 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { return bfloat16(::cosf(float(a))); }
633 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { return bfloat16(::tanf(float(a))); }
634 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { return bfloat16(::asinf(float(a))); }
635 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) { return bfloat16(::acosf(float(a))); }
636 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { return bfloat16(::atanf(float(a))); }
637 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { return bfloat16(::sinhf(float(a))); }
638 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { return bfloat16(::coshf(float(a))); }
639 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { return bfloat16(::tanhf(float(a))); }
640 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { return bfloat16(::asinhf(float(a))); }
641 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { return bfloat16(::acoshf(float(a))); }
642 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { return bfloat16(::atanhf(float(a))); }
643 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { return bfloat16(::floorf(float(a))); }
644 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { return bfloat16(::ceilf(float(a))); }
645 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { return bfloat16(::rintf(float(a))); }
646 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { return bfloat16(::roundf(float(a))); }
647 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(const bfloat16& a) { return bfloat16(::truncf(float(a))); }
648 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
649  return bfloat16(::fmodf(float(a), float(b)));
650 }
651 
652 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(const bfloat16& a, const bfloat16& b) {
653  const float f1 = static_cast<float>(a);
654  const float f2 = static_cast<float>(b);
655  return f2 < f1 ? b : a;
656 }
657 
658 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(const bfloat16& a, const bfloat16& b) {
659  const float f1 = static_cast<float>(a);
660  const float f2 = static_cast<float>(b);
661  return f1 < f2 ? b : a;
662 }
663 
664 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
665  const float f1 = static_cast<float>(a);
666  const float f2 = static_cast<float>(b);
667  return bfloat16(::fminf(f1, f2));
668 }
669 
670 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
671  const float f1 = static_cast<float>(a);
672  const float f2 = static_cast<float>(b);
673  return bfloat16(::fmaxf(f1, f2));
674 }
675 
676 EIGEN_DEVICE_FUNC inline bfloat16 fma(const bfloat16& a, const bfloat16& b, const bfloat16& c) {
677  // Emulate FMA via float.
678  return bfloat16(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
679 }
680 
681 #ifndef EIGEN_NO_IO
682 EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) {
683  os << static_cast<float>(v);
684  return os;
685 }
686 #endif
687 
688 } // namespace bfloat16_impl
689 
690 namespace internal {
691 
692 template <>
693 struct is_arithmetic<bfloat16> {
694  enum { value = true };
695 };
696 
697 template <>
698 struct random_impl<bfloat16> {
699  enum : int { MantissaBits = 7 };
700  using Impl = random_impl<float>;
701  static EIGEN_DEVICE_FUNC inline bfloat16 run(const bfloat16& x, const bfloat16& y) {
702  float result = Impl::run(x, y, MantissaBits);
703  return bfloat16(result);
704  }
705  static EIGEN_DEVICE_FUNC inline bfloat16 run() {
706  float result = Impl::run(MantissaBits);
707  return bfloat16(result);
708  }
709 };
710 
711 } // namespace internal
712 
713 template <>
714 struct NumTraits<Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
715  enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };
716 
717  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
718  return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
719  }
720  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
721  return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
722  }
723  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
724  return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
725  }
726  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
727  return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
728  }
729  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
730  return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
731  }
732  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
733  return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
734  }
735 };
736 
737 } // namespace Eigen
738 
739 #if defined(EIGEN_HAS_HIP_BF16)
740 #pragma pop_macro("EIGEN_CONSTEXPR")
741 #endif
742 
743 namespace Eigen {
744 namespace numext {
745 
746 template <>
747 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::bfloat16& h) {
748  return (bfloat16_impl::isnan)(h);
749 }
750 
751 template <>
752 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::bfloat16& h) {
753  return (bfloat16_impl::isinf)(h);
754 }
755 
756 template <>
757 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::bfloat16& h) {
758  return (bfloat16_impl::isfinite)(h);
759 }
760 
761 template <>
762 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
763  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
764 }
765 
766 template <>
767 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
768  return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
769 }
770 
771 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, const bfloat16& to) {
772  if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
773  return from;
774  }
775  if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
776  return to;
777  }
778  if (from == to) {
779  return to;
780  }
781  uint16_t from_bits = numext::bit_cast<uint16_t>(from);
782  bool from_sign = from_bits >> 15;
783  // Whether we are adjusting toward the infinity with the same sign as from.
784  bool toward_inf = (to > from) == !from_sign;
785  if (toward_inf) {
786  ++from_bits;
787  } else if ((from_bits & 0x7fff) == 0) {
788  // Adjusting away from inf, but from is zero, so just toggle the sign.
789  from_bits ^= 0x8000;
790  } else {
791  --from_bits;
792  }
793  return numext::bit_cast<bfloat16>(from_bits);
794 }
795 
796 // Specialize multiply-add to match packet operations and reduce conversions to/from float.
797 template<>
798 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
799  return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
800 }
801 
802 } // namespace numext
803 } // namespace Eigen
804 
#if EIGEN_HAS_STD_HASH
namespace std {
// Hash of a bfloat16, derived from its raw 16-bit pattern.
//
// Unordered containers require equal keys to hash equal. +0 (0x0000) and
// -0 (0x8000) compare equal but have different bit patterns, so both zeros
// are folded onto the hash of +0, as libstdc++'s std::hash<float> does.
// All other values keep their raw-bits hash. NaNs never compare equal, so
// they need no canonicalization.
template <>
struct hash<Eigen::bfloat16> {
  EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
    const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a);
    const bool is_zero = (bits & 0x7fff) == 0;
    return is_zero ? std::size_t(0) : static_cast<std::size_t>(bits);
  }
};
}  // namespace std
#endif
815 
// Add the missing __shfl* intrinsics for Eigen::bfloat16.
// The __shfl* functions are only valid on HIP or __CUDA_ARCH__ >= 300.
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__)).
//
// HIP and CUDA prior to SDK 9.0 define __shfl, __shfl_up, __shfl_down and
// __shfl_xor for int and float. CUDA since 9.0 deprecates those and instead
// defines __shfl_sync, __shfl_up_sync, __shfl_down_sync and __shfl_xor_sync,
// with native support for __half and __nv_bfloat16.
//
// Only HIP overloads are provided below. Each one sends the bfloat16's 16-bit
// pattern through the int overload of the intrinsic, then reinterprets the
// low 16 bits of the returned int as a bfloat16.
// Note that these are __device__-only functions.
#if defined(EIGEN_HIPCC)

#if defined(EIGEN_HAS_HIP_BF16)

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var, int srcLane, int width = warpSize) {
  const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var);
  const int shuffled = __shfl(static_cast<int>(bits), srcLane, width);
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var, unsigned int delta,
                                                         int width = warpSize) {
  const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var);
  const int shuffled = __shfl_up(static_cast<int>(bits), delta, width);
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var, unsigned int delta,
                                                           int width = warpSize) {
  const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var);
  const int shuffled = __shfl_down(static_cast<int>(bits), delta, width);
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var, int laneMask, int width = warpSize) {
  const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var);
  const int shuffled = __shfl_xor(static_cast<int>(bits), laneMask, width);
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

#endif  // EIGEN_HAS_HIP_BF16

#endif  // EIGEN_HIPCC (__shfl*)
858 
#if defined(EIGEN_HIPCC)
// __ldg overload for bfloat16. It views the pointer as a pointer to the
// value's 16-bit storage, loads through the uint16_t __ldg overload, and
// rebuilds the bfloat16 from the raw bits.
EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(const Eigen::bfloat16* ptr) {
  const Eigen::numext::uint16_t* raw_ptr = Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr);
  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(__ldg(raw_ptr));
}
#endif  // __ldg
865 
866 #endif // EIGEN_BFLOAT16_H
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_tanh_op< typename Derived::Scalar >, const Derived > tanh(const Eigen::ArrayBase< Derived > &x)
const Product< MatrixDerived, PermutationDerived, DefaultProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:474
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sinh_op< typename Derived::Scalar >, const Derived > sinh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isfinite_op< typename Derived::Scalar >, const Derived > isfinite(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sqrt_op< typename Derived::Scalar >, const Derived > sqrt(const Eigen::ArrayBase< Derived > &x)
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Definition: BFloat16.h:231
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ceil_op< typename Derived::Scalar >, const Derived > ceil(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_asin_op< typename Derived::Scalar >, const Derived > asin(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_acos_op< typename Derived::Scalar >, const Derived > acos(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_exp2_op< typename Derived::Scalar >, const Derived > exp2(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isnan_op< typename Derived::Scalar >, const Derived > isnan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_cos_op< typename Derived::Scalar >, const Derived > cos(const Eigen::ArrayBase< Derived > &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_round_op< typename Derived::Scalar >, const Derived > round(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_rint_op< typename Derived::Scalar >, const Derived > rint(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_floor_op< typename Derived::Scalar >, const Derived > floor(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log1p_op< typename Derived::Scalar >, const Derived > log1p(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isinf_op< typename Derived::Scalar >, const Derived > isinf(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_abs_op< typename Derived::Scalar >, const Derived > abs(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_cosh_op< typename Derived::Scalar >, const Derived > cosh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log_op< typename Derived::Scalar >, const Derived > log(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_tan_op< typename Derived::Scalar >, const Derived > tan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_expm1_op< typename Derived::Scalar >, const Derived > expm1(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_atanh_op< typename Derived::Scalar >, const Derived > atanh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log2_op< typename Derived::Scalar >, const Derived > log2(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_atan_op< typename Derived::Scalar >, const Derived > atan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sin_op< typename Derived::Scalar >, const Derived > sin(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_exp_op< typename Derived::Scalar >, const Derived > exp(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log10_op< typename Derived::Scalar >, const Derived > log10(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_acosh_op< typename Derived::Scalar >, const Derived > acosh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_asinh_op< typename Derived::Scalar >, const Derived > asinh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_trunc_op< typename Derived::Scalar >, const Derived > trunc(const Eigen::ArrayBase< Derived > &x)