16 #ifndef EIGEN_BFLOAT16_H 17 #define EIGEN_BFLOAT16_H 20 #include "../../InternalHeaderCheck.h" 22 #if defined(EIGEN_HAS_HIP_BF16) 29 #pragma push_macro("EIGEN_CONSTEXPR") 30 #undef EIGEN_CONSTEXPR 31 #define EIGEN_CONSTEXPR 34 #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \ 36 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 METHOD<PACKET_BF16>( \ 37 const PACKET_BF16& _x) { \ 38 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \ 42 #if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE) 43 #define EIGEN_USE_HIP_BF16 52 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src);
55 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src);
57 namespace bfloat16_impl {
59 #if defined(EIGEN_USE_HIP_BF16) 61 struct __bfloat16_raw :
public hip_bfloat16 {
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
63 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
64 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : hip_bfloat16(raw) {}
70 struct __bfloat16_raw {
71 #if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE) 72 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
74 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
76 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : value(raw) {}
80 #endif // defined(EIGEN_USE_HIP_BF16) 82 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(
unsigned short value);
83 template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
84 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(
float ff);
88 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff);
90 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff);
91 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h);
93 struct bfloat16_base :
public __bfloat16_raw {
94 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
95 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(
const __bfloat16_raw& h) : __bfloat16_raw(h) {}
101 struct bfloat16 :
public bfloat16_impl::bfloat16_base {
102 typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
104 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
106 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
108 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
bool b)
109 : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
112 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
113 : bfloat16_impl::bfloat16_base(
114 bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
116 explicit EIGEN_DEVICE_FUNC bfloat16(
float f)
117 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
121 template <
typename RealScalar>
122 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const std::complex<RealScalar>& val)
123 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.
real()))) {}
125 EIGEN_DEVICE_FUNC
operator float()
const {
126 return bfloat16_impl::bfloat16_to_float(*
this);
132 namespace bfloat16_impl {
133 template <
typename =
void>
134 struct numeric_limits_bfloat16_impl {
135 static EIGEN_CONSTEXPR
const bool is_specialized =
true;
136 static EIGEN_CONSTEXPR
const bool is_signed =
true;
137 static EIGEN_CONSTEXPR
const bool is_integer =
false;
138 static EIGEN_CONSTEXPR
const bool is_exact =
false;
139 static EIGEN_CONSTEXPR
const bool has_infinity =
true;
140 static EIGEN_CONSTEXPR
const bool has_quiet_NaN =
true;
141 static EIGEN_CONSTEXPR
const bool has_signaling_NaN =
true;
142 EIGEN_DIAGNOSTICS(push)
143 EIGEN_DISABLE_DEPRECATED_WARNING
144 static EIGEN_CONSTEXPR
const std::float_denorm_style has_denorm = std::denorm_present;
145 static EIGEN_CONSTEXPR
const bool has_denorm_loss =
false;
146 EIGEN_DIAGNOSTICS(pop)
147 static EIGEN_CONSTEXPR
const std::float_round_style round_style = std::numeric_limits<float>::round_style;
148 static EIGEN_CONSTEXPR
const bool is_iec559 =
true;
151 static EIGEN_CONSTEXPR
const bool is_bounded =
true;
152 static EIGEN_CONSTEXPR
const bool is_modulo =
false;
153 static EIGEN_CONSTEXPR
const int digits = 8;
154 static EIGEN_CONSTEXPR
const int digits10 = 2;
155 static EIGEN_CONSTEXPR
const int max_digits10 = 4;
156 static EIGEN_CONSTEXPR
const int radix = std::numeric_limits<float>::radix;
157 static EIGEN_CONSTEXPR
const int min_exponent = std::numeric_limits<float>::min_exponent;
158 static EIGEN_CONSTEXPR
const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
159 static EIGEN_CONSTEXPR
const int max_exponent = std::numeric_limits<float>::max_exponent;
160 static EIGEN_CONSTEXPR
const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
161 static EIGEN_CONSTEXPR
const bool traps = std::numeric_limits<float>::traps;
164 static EIGEN_CONSTEXPR
const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
166 static EIGEN_CONSTEXPR Eigen::bfloat16(min)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
167 static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
168 static EIGEN_CONSTEXPR Eigen::bfloat16(max)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
169 static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
170 static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
171 static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
172 static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
173 static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
174 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
176 static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
179 template <
typename T>
180 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_specialized;
181 template <
typename T>
182 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_signed;
183 template <
typename T>
184 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_integer;
185 template <
typename T>
186 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_exact;
187 template <
typename T>
188 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_infinity;
189 template <
typename T>
190 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
191 template <
typename T>
192 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
193 EIGEN_DIAGNOSTICS(push)
194 EIGEN_DISABLE_DEPRECATED_WARNING
195 template <
typename T>
196 EIGEN_CONSTEXPR
const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
197 template <
typename T>
198 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
199 EIGEN_DIAGNOSTICS(pop)
200 template <
typename T>
201 EIGEN_CONSTEXPR
const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
202 template <
typename T>
203 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_iec559;
204 template <
typename T>
205 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_bounded;
206 template <
typename T>
207 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_modulo;
208 template <
typename T>
209 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits;
210 template <
typename T>
211 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits10;
212 template <
typename T>
213 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_digits10;
214 template <
typename T>
215 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::radix;
216 template <
typename T>
217 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent;
218 template <
typename T>
219 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent10;
220 template <
typename T>
221 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent;
222 template <
typename T>
223 EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent10;
224 template <
typename T>
225 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::traps;
226 template <
typename T>
227 EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
237 class numeric_limits<
Eigen::bfloat16> :
public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
239 class numeric_limits<const
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
241 class numeric_limits<volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
243 class numeric_limits<const volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
248 namespace bfloat16_impl {
253 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats 255 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) 257 #pragma push_macro("EIGEN_DEVICE_FUNC") 258 #undef EIGEN_DEVICE_FUNC 259 #if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16)) 260 #define EIGEN_DEVICE_FUNC __host__ 261 #else // both host and device need emulated ops. 262 #define EIGEN_DEVICE_FUNC __host__ __device__ 269 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const bfloat16& b) {
270 return bfloat16(
float(a) +
float(b));
272 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const int& b) {
273 return bfloat16(
float(a) + static_cast<float>(b));
275 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const int& a,
const bfloat16& b) {
276 return bfloat16(static_cast<float>(a) +
float(b));
278 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
operator*(
const bfloat16& a,
const bfloat16& b) {
279 return bfloat16(
float(a) *
float(b));
281 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a,
const bfloat16& b) {
282 return bfloat16(
float(a) -
float(b));
284 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a,
const bfloat16& b) {
285 return bfloat16(
float(a) /
float(b));
287 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a) {
288 numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
289 return numext::bit_cast<bfloat16>(x);
291 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a,
const bfloat16& b) {
292 a = bfloat16(
float(a) +
float(b));
295 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a,
const bfloat16& b) {
296 a = bfloat16(
float(a) *
float(b));
299 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a,
const bfloat16& b) {
300 a = bfloat16(
float(a) -
float(b));
303 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a,
const bfloat16& b) {
304 a = bfloat16(
float(a) /
float(b));
307 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
311 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
315 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a,
int) {
316 bfloat16 original_value = a;
318 return original_value;
320 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a,
int) {
321 bfloat16 original_value = a;
323 return original_value;
325 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator==(
const bfloat16& a,
const bfloat16& b) {
326 return numext::equal_strict(
float(a),
float(b));
328 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator!=(
const bfloat16& a,
const bfloat16& b) {
329 return numext::not_equal_strict(
float(a),
float(b));
331 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<(
const bfloat16& a,
const bfloat16& b) {
332 return float(a) < float(b);
334 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<=(
const bfloat16& a,
const bfloat16& b) {
335 return float(a) <= float(b);
337 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>(
const bfloat16& a,
const bfloat16& b) {
338 return float(a) > float(b);
340 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>=(
const bfloat16& a,
const bfloat16& b) {
341 return float(a) >= float(b);
344 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) 345 #pragma pop_macro("EIGEN_DEVICE_FUNC") 347 #endif // Emulate support for bfloat16 floats 351 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a,
Index b) {
352 return bfloat16(static_cast<float>(a) / static_cast<float>(b));
355 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(
const float v) {
356 #if defined(EIGEN_USE_HIP_BF16) 357 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
359 __bfloat16_raw output;
360 if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
361 output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
364 output.value =
static_cast<numext::uint16_t
>(numext::bit_cast<numext::uint32_t>(v) >> 16);
369 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
370 #if defined(EIGEN_USE_HIP_BF16) 375 return __bfloat16_raw(value);
379 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
380 const __bfloat16_raw& bf) {
381 #if defined(EIGEN_USE_HIP_BF16) 391 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff) {
392 #if defined(EIGEN_USE_HIP_BF16) 393 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
395 __bfloat16_raw output;
397 if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
403 output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
554 output = float_to_bfloat16_rtne<true>(ff);
565 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff) {
566 #if defined(EIGEN_USE_HIP_BF16) 567 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
569 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
570 __bfloat16_raw output;
573 numext::uint32_t lsb = (input >> 16) & 1;
574 numext::uint32_t rounding_bias = 0x7fff + lsb;
575 input += rounding_bias;
576 output.value =
static_cast<numext::uint16_t
>(input >> 16);
581 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h) {
582 #if defined(EIGEN_USE_HIP_BF16) 583 return static_cast<float>(h);
585 return numext::bit_cast<
float>(
static_cast<numext::uint32_t
>(h.value) << 16);
591 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(
isinf)(
const bfloat16& a) {
592 EIGEN_USING_STD(
isinf);
593 #if defined(EIGEN_USE_HIP_BF16) 596 return (
isinf)(float(a));
599 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(
isnan)(
const bfloat16& a) {
600 EIGEN_USING_STD(
isnan);
601 #if defined(EIGEN_USE_HIP_BF16) 604 return (
isnan)(float(a));
607 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(
isfinite)(
const bfloat16& a) {
608 return !(
isinf EIGEN_NOT_A_MACRO(a)) && !(
isnan EIGEN_NOT_A_MACRO(a));
611 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
abs(
const bfloat16& a) {
612 numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
613 return numext::bit_cast<bfloat16>(x);
615 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
exp(
const bfloat16& a) {
return bfloat16(::expf(
float(a))); }
616 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
exp2(
const bfloat16& a) {
return bfloat16(::exp2f(
float(a))); }
617 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
expm1(
const bfloat16& a) {
return bfloat16(numext::expm1(
float(a))); }
618 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
log(
const bfloat16& a) {
return bfloat16(::logf(
float(a))); }
619 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
log1p(
const bfloat16& a) {
return bfloat16(numext::log1p(
float(a))); }
620 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
log10(
const bfloat16& a) {
return bfloat16(::log10f(
float(a))); }
621 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
log2(
const bfloat16& a) {
622 return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(
float(a)));
624 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
sqrt(
const bfloat16& a) {
return bfloat16(::sqrtf(
float(a))); }
625 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(
const bfloat16& a,
const bfloat16& b) {
626 return bfloat16(::powf(
float(a),
float(b)));
628 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(
const bfloat16& a,
const bfloat16& b) {
629 return bfloat16(::atan2f(
float(a),
float(b)));
631 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
sin(
const bfloat16& a) {
return bfloat16(::sinf(
float(a))); }
632 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
cos(
const bfloat16& a) {
return bfloat16(::cosf(
float(a))); }
633 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
tan(
const bfloat16& a) {
return bfloat16(::tanf(
float(a))); }
634 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
asin(
const bfloat16& a) {
return bfloat16(::asinf(
float(a))); }
635 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
acos(
const bfloat16& a) {
return bfloat16(::acosf(
float(a))); }
636 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
atan(
const bfloat16& a) {
return bfloat16(::atanf(
float(a))); }
637 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
sinh(
const bfloat16& a) {
return bfloat16(::sinhf(
float(a))); }
638 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
cosh(
const bfloat16& a) {
return bfloat16(::coshf(
float(a))); }
639 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
tanh(
const bfloat16& a) {
return bfloat16(::tanhf(
float(a))); }
640 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
asinh(
const bfloat16& a) {
return bfloat16(::asinhf(
float(a))); }
641 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
acosh(
const bfloat16& a) {
return bfloat16(::acoshf(
float(a))); }
642 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
atanh(
const bfloat16& a) {
return bfloat16(::atanhf(
float(a))); }
643 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
floor(
const bfloat16& a) {
return bfloat16(::floorf(
float(a))); }
644 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
ceil(
const bfloat16& a) {
return bfloat16(::ceilf(
float(a))); }
645 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
rint(
const bfloat16& a) {
return bfloat16(::rintf(
float(a))); }
646 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
round(
const bfloat16& a) {
return bfloat16(::roundf(
float(a))); }
647 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16
trunc(
const bfloat16& a) {
return bfloat16(::truncf(
float(a))); }
648 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(
const bfloat16& a,
const bfloat16& b) {
649 return bfloat16(::fmodf(
float(a),
float(b)));
652 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(
const bfloat16& a,
const bfloat16& b) {
653 const float f1 =
static_cast<float>(a);
654 const float f2 =
static_cast<float>(b);
655 return f2 < f1 ? b : a;
658 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(
const bfloat16& a,
const bfloat16& b) {
659 const float f1 =
static_cast<float>(a);
660 const float f2 =
static_cast<float>(b);
661 return f1 < f2 ? b : a;
664 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(
const bfloat16& a,
const bfloat16& b) {
665 const float f1 =
static_cast<float>(a);
666 const float f2 =
static_cast<float>(b);
667 return bfloat16(::fminf(f1, f2));
670 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(
const bfloat16& a,
const bfloat16& b) {
671 const float f1 =
static_cast<float>(a);
672 const float f2 =
static_cast<float>(b);
673 return bfloat16(::fmaxf(f1, f2));
676 EIGEN_DEVICE_FUNC
inline bfloat16 fma(
const bfloat16& a,
const bfloat16& b,
const bfloat16& c) {
678 return bfloat16(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
682 EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os,
const bfloat16& v) {
683 os << static_cast<float>(v);
693 struct is_arithmetic<bfloat16> {
694 enum { value =
true };
698 struct random_impl<bfloat16> {
699 enum :
int { MantissaBits = 7 };
700 using Impl = random_impl<float>;
701 static EIGEN_DEVICE_FUNC
inline bfloat16 run(
const bfloat16& x,
const bfloat16& y) {
702 float result = Impl::run(x, y, MantissaBits);
703 return bfloat16(result);
705 static EIGEN_DEVICE_FUNC
inline bfloat16 run() {
706 float result = Impl::run(MantissaBits);
707 return bfloat16(result);
714 struct NumTraits<
Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
715 enum { IsSigned =
true, IsInteger =
false, IsComplex =
false, RequireInitialization =
false };
717 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
718 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
720 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
721 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
723 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
724 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
726 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
727 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
729 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
730 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
732 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
733 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
739 #if defined(EIGEN_HAS_HIP_BF16) 740 #pragma pop_macro("EIGEN_CONSTEXPR") 747 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(
isnan)(
const Eigen::bfloat16& h) {
748 return (bfloat16_impl::isnan)(h);
752 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(
isinf)(
const Eigen::bfloat16& h) {
753 return (bfloat16_impl::isinf)(h);
757 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(
isfinite)(
const Eigen::bfloat16& h) {
758 return (bfloat16_impl::isfinite)(h);
762 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src) {
763 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
767 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src) {
768 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
771 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(
const bfloat16& from,
const bfloat16& to) {
772 if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
775 if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
781 uint16_t from_bits = numext::bit_cast<uint16_t>(from);
782 bool from_sign = from_bits >> 15;
784 bool toward_inf = (to > from) == !from_sign;
787 }
else if ((from_bits & 0x7fff) == 0) {
793 return numext::bit_cast<bfloat16>(from_bits);
798 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(
const Eigen::bfloat16& x,
const Eigen::bfloat16& y,
const Eigen::bfloat16& z) {
799 return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
805 #if EIGEN_HAS_STD_HASH 808 struct hash<
Eigen::bfloat16> {
809 EIGEN_STRONG_INLINE std::size_t operator()(
const Eigen::bfloat16& a)
const {
810 return static_cast<std::size_t
>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
827 #if defined(EIGEN_HIPCC) 829 #if defined(EIGEN_HAS_HIP_BF16) 831 __device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var,
int srcLane,
int width = warpSize) {
832 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
833 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl(ivar, srcLane, width)));
836 __device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var,
unsigned int delta,
837 int width = warpSize) {
838 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
839 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl_up(ivar, delta, width)));
842 __device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var,
unsigned int delta,
843 int width = warpSize) {
844 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
845 return Eigen::numext::bit_cast<Eigen::bfloat16>(
846 static_cast<Eigen::numext::uint16_t
>(__shfl_down(ivar, delta, width)));
849 __device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var,
int laneMask,
int width = warpSize) {
850 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
851 return Eigen::numext::bit_cast<Eigen::bfloat16>(
852 static_cast<Eigen::numext::uint16_t
>(__shfl_xor(ivar, laneMask, width)));
859 #if defined(EIGEN_HIPCC) 860 EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(
const Eigen::bfloat16* ptr) {
861 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
862 __ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
866 #endif // EIGEN_BFLOAT16_H const Eigen::CwiseUnaryOp< Eigen::internal::scalar_tanh_op< typename Derived::Scalar >, const Derived > tanh(const Eigen::ArrayBase< Derived > &x)
const Product< MatrixDerived, PermutationDerived, DefaultProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:474
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sinh_op< typename Derived::Scalar >, const Derived > sinh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isfinite_op< typename Derived::Scalar >, const Derived > isfinite(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sqrt_op< typename Derived::Scalar >, const Derived > sqrt(const Eigen::ArrayBase< Derived > &x)
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Definition: BFloat16.h:231
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ceil_op< typename Derived::Scalar >, const Derived > ceil(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_asin_op< typename Derived::Scalar >, const Derived > asin(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_acos_op< typename Derived::Scalar >, const Derived > acos(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_exp2_op< typename Derived::Scalar >, const Derived > exp2(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isnan_op< typename Derived::Scalar >, const Derived > isnan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_cos_op< typename Derived::Scalar >, const Derived > cos(const Eigen::ArrayBase< Derived > &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_round_op< typename Derived::Scalar >, const Derived > round(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_rint_op< typename Derived::Scalar >, const Derived > rint(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_floor_op< typename Derived::Scalar >, const Derived > floor(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log1p_op< typename Derived::Scalar >, const Derived > log1p(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isinf_op< typename Derived::Scalar >, const Derived > isinf(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_abs_op< typename Derived::Scalar >, const Derived > abs(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_cosh_op< typename Derived::Scalar >, const Derived > cosh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log_op< typename Derived::Scalar >, const Derived > log(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_tan_op< typename Derived::Scalar >, const Derived > tan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_expm1_op< typename Derived::Scalar >, const Derived > expm1(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_atanh_op< typename Derived::Scalar >, const Derived > atanh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log2_op< typename Derived::Scalar >, const Derived > log2(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_atan_op< typename Derived::Scalar >, const Derived > atan(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sin_op< typename Derived::Scalar >, const Derived > sin(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_exp_op< typename Derived::Scalar >, const Derived > exp(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_log10_op< typename Derived::Scalar >, const Derived > log10(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_acosh_op< typename Derived::Scalar >, const Derived > acosh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_asinh_op< typename Derived::Scalar >, const Derived > asinh(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_trunc_op< typename Derived::Scalar >, const Derived > trunc(const Eigen::ArrayBase< Derived > &x)