16#ifndef EIGEN_BFLOAT16_H
17#define EIGEN_BFLOAT16_H
20#include "../../InternalHeaderCheck.h"
22#if defined(EIGEN_HAS_HIP_BF16)
29#pragma push_macro("EIGEN_CONSTEXPR")
31#define EIGEN_CONSTEXPR
34#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
36 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 METHOD<PACKET_BF16>( \
37 const PACKET_BF16& _x) { \
38 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
42#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
43#define EIGEN_USE_HIP_BF16
52EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src);
55EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src);
57namespace bfloat16_impl {
59#if defined(EIGEN_USE_HIP_BF16)
61struct __bfloat16_raw :
public hip_bfloat16 {
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
63 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
64 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : hip_bfloat16(raw) {}
70struct __bfloat16_raw {
71#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
72 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
74 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
76 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : value(raw) {}
82EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(
unsigned short value);
83template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
84EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(
float ff);
88EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff);
90EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff);
91EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h);
93struct bfloat16_base :
public __bfloat16_raw {
94 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
95 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(
const __bfloat16_raw& h) : __bfloat16_raw(h) {}
101struct bfloat16 :
public bfloat16_impl::bfloat16_base {
102 typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
104 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
106 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
108 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
bool b)
109 : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
112 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
113 : bfloat16_impl::bfloat16_base(
114 bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
116 explicit EIGEN_DEVICE_FUNC bfloat16(
float f)
117 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
121 template <
typename RealScalar>
122 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const std::complex<RealScalar>& val)
123 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.
real()))) {}
125 EIGEN_DEVICE_FUNC
operator float()
const {
126 return bfloat16_impl::bfloat16_to_float(*
this);
132namespace bfloat16_impl {
133template <
typename =
void>
134struct numeric_limits_bfloat16_impl {
135 static EIGEN_CONSTEXPR
const bool is_specialized =
true;
136 static EIGEN_CONSTEXPR
const bool is_signed =
true;
137 static EIGEN_CONSTEXPR
const bool is_integer =
false;
138 static EIGEN_CONSTEXPR
const bool is_exact =
false;
139 static EIGEN_CONSTEXPR
const bool has_infinity =
true;
140 static EIGEN_CONSTEXPR
const bool has_quiet_NaN =
true;
141 static EIGEN_CONSTEXPR
const bool has_signaling_NaN =
true;
142 EIGEN_DIAGNOSTICS(push)
143 EIGEN_DISABLE_DEPRECATED_WARNING
144 static EIGEN_CONSTEXPR
const std::float_denorm_style has_denorm = std::denorm_present;
145 static EIGEN_CONSTEXPR
const bool has_denorm_loss =
false;
146 EIGEN_DIAGNOSTICS(pop)
147 static EIGEN_CONSTEXPR
const std::float_round_style round_style = std::numeric_limits<float>::round_style;
148 static EIGEN_CONSTEXPR
const bool is_iec559 =
true;
151 static EIGEN_CONSTEXPR
const bool is_bounded =
true;
152 static EIGEN_CONSTEXPR
const bool is_modulo =
false;
153 static EIGEN_CONSTEXPR
const int digits = 8;
154 static EIGEN_CONSTEXPR
const int digits10 = 2;
155 static EIGEN_CONSTEXPR
const int max_digits10 = 4;
156 static EIGEN_CONSTEXPR
const int radix = std::numeric_limits<float>::radix;
157 static EIGEN_CONSTEXPR
const int min_exponent = std::numeric_limits<float>::min_exponent;
158 static EIGEN_CONSTEXPR
const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
159 static EIGEN_CONSTEXPR
const int max_exponent = std::numeric_limits<float>::max_exponent;
160 static EIGEN_CONSTEXPR
const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
161 static EIGEN_CONSTEXPR
const bool traps = std::numeric_limits<float>::traps;
164 static EIGEN_CONSTEXPR
const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
166 static EIGEN_CONSTEXPR Eigen::bfloat16(min)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
167 static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
168 static EIGEN_CONSTEXPR Eigen::bfloat16(max)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
169 static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
170 static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
171 static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
172 static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
173 static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
174 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
176 static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
180EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_specialized;
182EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_signed;
184EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_integer;
186EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_exact;
188EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_infinity;
190EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
192EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
193EIGEN_DIAGNOSTICS(push)
194EIGEN_DISABLE_DEPRECATED_WARNING
196EIGEN_CONSTEXPR
const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
198EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
199EIGEN_DIAGNOSTICS(pop)
201EIGEN_CONSTEXPR
const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
203EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_iec559;
205EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_bounded;
207EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_modulo;
209EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits;
211EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits10;
213EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_digits10;
215EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::radix;
217EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent;
219EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent10;
221EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent;
223EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent10;
225EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::traps;
227EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
237class numeric_limits<Eigen::bfloat16> :
public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
239class numeric_limits<const Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
241class numeric_limits<volatile Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
243class numeric_limits<const volatile Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
248namespace bfloat16_impl {
253#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
255#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
257#pragma push_macro("EIGEN_DEVICE_FUNC")
258#undef EIGEN_DEVICE_FUNC
259#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
260#define EIGEN_DEVICE_FUNC __host__
262#define EIGEN_DEVICE_FUNC __host__ __device__
269EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const bfloat16& b) {
270 return bfloat16(
float(a) +
float(b));
272EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const int& b) {
273 return bfloat16(
float(a) +
static_cast<float>(b));
275EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const int& a,
const bfloat16& b) {
276 return bfloat16(
static_cast<float>(a) +
float(b));
278EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(
const bfloat16& a,
const bfloat16& b) {
279 return bfloat16(
float(a) *
float(b));
281EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a,
const bfloat16& b) {
282 return bfloat16(
float(a) -
float(b));
284EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a,
const bfloat16& b) {
285 return bfloat16(
float(a) /
float(b));
287EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a) {
288 numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
289 return numext::bit_cast<bfloat16>(x);
291EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a,
const bfloat16& b) {
292 a = bfloat16(
float(a) +
float(b));
295EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a,
const bfloat16& b) {
296 a = bfloat16(
float(a) *
float(b));
299EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a,
const bfloat16& b) {
300 a = bfloat16(
float(a) -
float(b));
303EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a,
const bfloat16& b) {
304 a = bfloat16(
float(a) /
float(b));
307EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
311EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
315EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a,
int) {
316 bfloat16 original_value = a;
318 return original_value;
320EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a,
int) {
321 bfloat16 original_value = a;
323 return original_value;
325EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator==(
const bfloat16& a,
const bfloat16& b) {
326 return numext::equal_strict(
float(a),
float(b));
328EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator!=(
const bfloat16& a,
const bfloat16& b) {
329 return numext::not_equal_strict(
float(a),
float(b));
331EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<(
const bfloat16& a,
const bfloat16& b) {
332 return float(a) < float(b);
334EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<=(
const bfloat16& a,
const bfloat16& b) {
335 return float(a) <= float(b);
337EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>(
const bfloat16& a,
const bfloat16& b) {
338 return float(a) > float(b);
340EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>=(
const bfloat16& a,
const bfloat16& b) {
341 return float(a) >= float(b);
344#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
345#pragma pop_macro("EIGEN_DEVICE_FUNC")
351EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a,
Index b) {
352 return bfloat16(
static_cast<float>(a) /
static_cast<float>(b));
355EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(
const float v) {
356#if defined(EIGEN_USE_HIP_BF16)
357 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
359 __bfloat16_raw output;
360 if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
361 output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
364 output.value =
static_cast<numext::uint16_t
>(numext::bit_cast<numext::uint32_t>(v) >> 16);
369EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
370#if defined(EIGEN_USE_HIP_BF16)
375 return __bfloat16_raw(value);
379EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
380 const __bfloat16_raw& bf) {
381#if defined(EIGEN_USE_HIP_BF16)
391EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff) {
392#if defined(EIGEN_USE_HIP_BF16)
393 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
395 __bfloat16_raw output;
397 if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
403 output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
554 output = float_to_bfloat16_rtne<true>(ff);
565EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff) {
566#if defined(EIGEN_USE_HIP_BF16)
567 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
569 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
570 __bfloat16_raw output;
573 numext::uint32_t lsb = (input >> 16) & 1;
574 numext::uint32_t rounding_bias = 0x7fff + lsb;
575 input += rounding_bias;
576 output.value =
static_cast<numext::uint16_t
>(input >> 16);
581EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h) {
582#if defined(EIGEN_USE_HIP_BF16)
583 return static_cast<float>(h);
585 return numext::bit_cast<float>(
static_cast<numext::uint32_t
>(h.value) << 16);
591EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(
const bfloat16& a) {
592 EIGEN_USING_STD(isinf);
593#if defined(EIGEN_USE_HIP_BF16)
596 return (isinf)(float(a));
599EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(
const bfloat16& a) {
600 EIGEN_USING_STD(isnan);
601#if defined(EIGEN_USE_HIP_BF16)
604 return (isnan)(float(a));
607EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(
const bfloat16& a) {
608 return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
611EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(
const bfloat16& a) {
612 numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
613 return numext::bit_cast<bfloat16>(x);
615EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(
const bfloat16& a) {
return bfloat16(::expf(
float(a))); }
616EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp2(
const bfloat16& a) {
return bfloat16(::exp2f(
float(a))); }
617EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(
const bfloat16& a) {
return bfloat16(numext::expm1(
float(a))); }
618EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(
const bfloat16& a) {
return bfloat16(::logf(
float(a))); }
619EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(
const bfloat16& a) {
return bfloat16(numext::log1p(
float(a))); }
620EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(
const bfloat16& a) {
return bfloat16(::log10f(
float(a))); }
621EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(
const bfloat16& a) {
622 return bfloat16(
static_cast<float>(EIGEN_LOG2E) * ::logf(
float(a)));
624EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(
const bfloat16& a) {
return bfloat16(::sqrtf(
float(a))); }
625EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(
const bfloat16& a,
const bfloat16& b) {
626 return bfloat16(::powf(
float(a),
float(b)));
628EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(
const bfloat16& a,
const bfloat16& b) {
629 return bfloat16(::atan2f(
float(a),
float(b)));
631EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(
const bfloat16& a) {
return bfloat16(::sinf(
float(a))); }
632EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(
const bfloat16& a) {
return bfloat16(::cosf(
float(a))); }
633EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(
const bfloat16& a) {
return bfloat16(::tanf(
float(a))); }
634EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(
const bfloat16& a) {
return bfloat16(::asinf(
float(a))); }
635EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(
const bfloat16& a) {
return bfloat16(::acosf(
float(a))); }
636EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(
const bfloat16& a) {
return bfloat16(::atanf(
float(a))); }
637EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(
const bfloat16& a) {
return bfloat16(::sinhf(
float(a))); }
638EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(
const bfloat16& a) {
return bfloat16(::coshf(
float(a))); }
639EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(
const bfloat16& a) {
return bfloat16(::tanhf(
float(a))); }
640EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(
const bfloat16& a) {
return bfloat16(::asinhf(
float(a))); }
641EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(
const bfloat16& a) {
return bfloat16(::acoshf(
float(a))); }
642EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(
const bfloat16& a) {
return bfloat16(::atanhf(
float(a))); }
643EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(
const bfloat16& a) {
return bfloat16(::floorf(
float(a))); }
644EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(
const bfloat16& a) {
return bfloat16(::ceilf(
float(a))); }
645EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(
const bfloat16& a) {
return bfloat16(::rintf(
float(a))); }
646EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(
const bfloat16& a) {
return bfloat16(::roundf(
float(a))); }
647EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(
const bfloat16& a) {
return bfloat16(::truncf(
float(a))); }
648EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(
const bfloat16& a,
const bfloat16& b) {
649 return bfloat16(::fmodf(
float(a),
float(b)));
652EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(
const bfloat16& a,
const bfloat16& b) {
653 const float f1 =
static_cast<float>(a);
654 const float f2 =
static_cast<float>(b);
655 return f2 < f1 ? b : a;
658EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(
const bfloat16& a,
const bfloat16& b) {
659 const float f1 =
static_cast<float>(a);
660 const float f2 =
static_cast<float>(b);
661 return f1 < f2 ? b : a;
664EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(
const bfloat16& a,
const bfloat16& b) {
665 const float f1 =
static_cast<float>(a);
666 const float f2 =
static_cast<float>(b);
667 return bfloat16(::fminf(f1, f2));
670EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(
const bfloat16& a,
const bfloat16& b) {
671 const float f1 =
static_cast<float>(a);
672 const float f2 =
static_cast<float>(b);
673 return bfloat16(::fmaxf(f1, f2));
676EIGEN_DEVICE_FUNC
inline bfloat16 fma(
const bfloat16& a,
const bfloat16& b,
const bfloat16& c) {
678 return bfloat16(numext::fma(
static_cast<float>(a),
static_cast<float>(b),
static_cast<float>(c)));
682EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os,
const bfloat16& v) {
683 os << static_cast<float>(v);
693struct is_arithmetic<bfloat16> {
694 enum { value =
true };
698struct random_impl<bfloat16> {
699 enum :
int { MantissaBits = 7 };
700 using Impl = random_impl<float>;
701 static EIGEN_DEVICE_FUNC
inline bfloat16 run(
const bfloat16& x,
const bfloat16& y) {
702 float result = Impl::run(x, y, MantissaBits);
703 return bfloat16(result);
705 static EIGEN_DEVICE_FUNC
inline bfloat16 run() {
706 float result = Impl::run(MantissaBits);
707 return bfloat16(result);
714struct NumTraits<Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
715 enum { IsSigned =
true, IsInteger =
false, IsComplex =
false, RequireInitialization =
false };
717 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
718 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
720 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
721 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
723 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
724 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
726 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
727 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
729 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
730 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
732 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
733 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
739#if defined(EIGEN_HAS_HIP_BF16)
740#pragma pop_macro("EIGEN_CONSTEXPR")
747EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
const Eigen::bfloat16& h) {
748 return (bfloat16_impl::isnan)(h);
752EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
const Eigen::bfloat16& h) {
753 return (bfloat16_impl::isinf)(h);
757EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
const Eigen::bfloat16& h) {
758 return (bfloat16_impl::isfinite)(h);
762EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src) {
763 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
767EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src) {
768 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
771EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(
const bfloat16& from,
const bfloat16& to) {
772 if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
775 if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
781 uint16_t from_bits = numext::bit_cast<uint16_t>(from);
782 bool from_sign = from_bits >> 15;
784 bool toward_inf = (to > from) == !from_sign;
787 }
else if ((from_bits & 0x7fff) == 0) {
793 return numext::bit_cast<bfloat16>(from_bits);
798EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(
const Eigen::bfloat16& x,
const Eigen::bfloat16& y,
const Eigen::bfloat16& z) {
799 return Eigen::bfloat16(
static_cast<float>(x) *
static_cast<float>(y) +
static_cast<float>(z));
805#if EIGEN_HAS_STD_HASH
808struct hash<Eigen::bfloat16> {
809 EIGEN_STRONG_INLINE std::size_t operator()(
const Eigen::bfloat16& a)
const {
810 return static_cast<std::size_t
>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
827#if defined(EIGEN_HIPCC)
829#if defined(EIGEN_HAS_HIP_BF16)
831__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var,
int srcLane,
int width = warpSize) {
832 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
833 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl(ivar, srcLane, width)));
836__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var,
unsigned int delta,
837 int width = warpSize) {
838 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
839 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl_up(ivar, delta, width)));
842__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var,
unsigned int delta,
843 int width = warpSize) {
844 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
845 return Eigen::numext::bit_cast<Eigen::bfloat16>(
846 static_cast<Eigen::numext::uint16_t
>(__shfl_down(ivar, delta, width)));
849__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var,
int laneMask,
int width = warpSize) {
850 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
851 return Eigen::numext::bit_cast<Eigen::bfloat16>(
852 static_cast<Eigen::numext::uint16_t
>(__shfl_xor(ivar, laneMask, width)));
859#if defined(EIGEN_HIPCC)
860EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(
const Eigen::bfloat16* ptr) {
861 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
862 __ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition NumTraits.h:232