// BFloat16.h — Eigen bfloat16 scalar type.
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef EIGEN_BFLOAT16_H
17#define EIGEN_BFLOAT16_H
18
19// IWYU pragma: private
20#include "../../InternalHeaderCheck.h"
21
22#if defined(EIGEN_HAS_HIP_BF16)
23// When compiling with GPU support, the "hip_bfloat16" base class as well as
24// some other routines are defined in the GPU compiler header files
25// (hip_bfloat16.h), and they are not tagged constexpr
26// As a consequence, we get compile failures when compiling Eigen with
27// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
28// Eigen with GPU support
29#pragma push_macro("EIGEN_CONSTEXPR")
30#undef EIGEN_CONSTEXPR
31#define EIGEN_CONSTEXPR
32#endif
33
// Helper macro: defines the bfloat16 packet overload of METHOD by widening
// the packet to float (Bf16ToF32), applying the float implementation, and
// rounding the result back to bfloat16 (F32ToBf16).
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)                                          \
  template <>                                                                                        \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 METHOD<PACKET_BF16>(  \
      const PACKET_BF16& _x) {                                                                       \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x)));                                               \
  }
40
41// Only use HIP GPU bf16 in kernels
42#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
43#define EIGEN_USE_HIP_BF16
44#endif
45
46namespace Eigen {
47
48struct bfloat16;
49
50namespace numext {
51template <>
52EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src);
53
54template <>
55EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src);
56} // namespace numext
57namespace bfloat16_impl {
58
59#if defined(EIGEN_USE_HIP_BF16)
60
// Raw storage type when HIP's native hip_bfloat16 is used in device code:
// a thin wrapper so the rest of this file can refer to one name for the raw
// representation in both the HIP and the emulated configurations.
struct __bfloat16_raw : public hip_bfloat16 {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
  // Wrap an existing hip_bfloat16 value.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
  // Construct from a raw 16-bit pattern.  NOTE(review): presumably
  // hip_bfloat16(unsigned short) stores the bits verbatim — confirm against
  // hip_bfloat16.h.
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : hip_bfloat16(raw) {}
};
66
67#else
68
69// Make our own __bfloat16_raw definition.
struct __bfloat16_raw {
#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
  // Host pass of a HIP build: leave `value` uninitialized.  NOTE(review):
  // presumably kept trivial to stay layout/constexpr-compatible with the
  // hip_bfloat16-based type used on the device pass — confirm.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
#else
  // Default constructor zero-initializes (bit pattern of +0.0bf).
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
#endif
  // Reinterpret a raw 16-bit pattern; no numeric conversion is performed.
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
  // Storage: 1 sign bit, 8 exponent bits, 7 fraction bits.
  unsigned short value;
};
79
80#endif // defined(EIGEN_USE_HIP_BF16)
81
82EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
83template <bool AssumeArgumentIsNormalOrInfinityOrZero>
84EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
85// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
86// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
87template <>
88EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
89template <>
90EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
91EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
92
// Thin base separating the raw-bits representation from the public bfloat16
// type; allows implicit construction of bfloat16 from a __bfloat16_raw.
struct bfloat16_base : public __bfloat16_raw {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
};
97
98} // namespace bfloat16_impl
99
100// Class definition.
struct bfloat16 : public bfloat16_impl::bfloat16_base {
  typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;

  // Default constructor; initialization follows __bfloat16_raw's default
  // (zero, except on HIP host builds where it is left uninitialized).
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}

  // Construct from a raw 16-bit pattern; no numeric conversion.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}

  // true -> 1.0bf (0x3f80), false -> +0.0bf.
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}

  // Generic numeric conversion through float.  For integral T the rounding
  // code may assume the intermediate float is normal/inf/zero, enabling the
  // faster float_to_bfloat16_rtne<true> path.
  template <class T>
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
      : bfloat16_impl::bfloat16_base(
            bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}

  // Round-to-nearest-even conversion from float (with full NaN handling).
  explicit EIGEN_DEVICE_FUNC bfloat16(float f)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}

  // Following the convention of numpy, converting between complex and
  // float will lead to loss of imag value.
  template <typename RealScalar>
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}

  EIGEN_DEVICE_FUNC operator float() const {  // NOLINT: Allow implicit conversion to float, because it is lossless.
    return bfloat16_impl::bfloat16_to_float(*this);
  }
};
129
130// TODO(majnemer): Get rid of this once we can rely on C++17 inline variables do
131// solve the ODR issue.
132namespace bfloat16_impl {
// numeric_limits implementation for bfloat16.  Templated on a dummy type so
// the static data members can be defined in this header without ODR issues
// (see the TODO above about C++17 inline variables).
template <typename = void>
struct numeric_limits_bfloat16_impl {
  static EIGEN_CONSTEXPR const bool is_specialized = true;
  static EIGEN_CONSTEXPR const bool is_signed = true;
  static EIGEN_CONSTEXPR const bool is_integer = false;
  static EIGEN_CONSTEXPR const bool is_exact = false;
  static EIGEN_CONSTEXPR const bool has_infinity = true;
  static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
  static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
  // std::float_denorm_style is deprecated (C++23); suppress the warning
  // around its use.
  EIGEN_DIAGNOSTICS(push)
  EIGEN_DISABLE_DEPRECATED_WARNING
  static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
  static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
  EIGEN_DIAGNOSTICS(pop)
  static EIGEN_CONSTEXPR const std::float_round_style round_style = std::numeric_limits<float>::round_style;
  static EIGEN_CONSTEXPR const bool is_iec559 = true;
  // The C++ standard defines this as "true if the set of values representable
  // by the type is finite." BFloat16 has finite precision.
  static EIGEN_CONSTEXPR const bool is_bounded = true;
  static EIGEN_CONSTEXPR const bool is_modulo = false;
  // 7 explicit fraction bits plus the implicit leading bit.
  static EIGEN_CONSTEXPR const int digits = 8;
  static EIGEN_CONSTEXPR const int digits10 = 2;
  static EIGEN_CONSTEXPR const int max_digits10 = 4;
  static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
  // bfloat16 shares binary32's 8-bit exponent, so the exponent ranges match.
  static EIGEN_CONSTEXPR const int min_exponent = std::numeric_limits<float>::min_exponent;
  static EIGEN_CONSTEXPR const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
  static EIGEN_CONSTEXPR const int max_exponent = std::numeric_limits<float>::max_exponent;
  static EIGEN_CONSTEXPR const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
  static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
  // IEEE754: "The implementer shall choose how tininess is detected, but shall
  // detect tininess in the same way for all operations in radix two"
  static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;

  // Raw bit patterns below: 1 sign / 8 exponent / 7 fraction bits.
  static EIGEN_CONSTEXPR Eigen::bfloat16(min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
  static EIGEN_CONSTEXPR Eigen::bfloat16(max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
    return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
  }
  static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
};
178
// Out-of-line definitions of the static data members declared above; these
// are required for ODR-used constants prior to C++17 inline variables (see
// the TODO near the top of this namespace).
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_specialized;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_signed;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_integer;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_exact;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_infinity;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
template <typename T>
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
EIGEN_DIAGNOSTICS(pop)
template <typename T>
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_iec559;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_bounded;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_modulo;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::radix;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent10;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::traps;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
228} // end namespace bfloat16_impl
229} // end namespace Eigen
230
231namespace std {
232// If std::numeric_limits<T> is specialized, should also specialize
233// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
234// std::numeric_limits<const volatile T>
235// https://stackoverflow.com/a/16519653/
236template <>
237class numeric_limits<Eigen::bfloat16> : public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
238template <>
239class numeric_limits<const Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
240template <>
241class numeric_limits<volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
242template <>
243class numeric_limits<const volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
244} // end namespace std
245
246namespace Eigen {
247
248namespace bfloat16_impl {
249
250// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
251// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
252// of the functions, while the latter can only deal with one of them.
253#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
254
255#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
256// We need to provide emulated *host-side* BF16 operators for clang.
257#pragma push_macro("EIGEN_DEVICE_FUNC")
258#undef EIGEN_DEVICE_FUNC
259#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
260#define EIGEN_DEVICE_FUNC __host__
261#else // both host and device need emulated ops.
262#define EIGEN_DEVICE_FUNC __host__ __device__
263#endif
264#endif
265
266// Definitions for CPUs, mostly working through conversion
267// to/from fp32.
268
// Binary arithmetic, emulated by widening both operands to float, computing
// in float, and rounding the result back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) + float(b));
}
// Mixed bfloat16/int overloads.  NOTE(review): presumably provided to
// resolve overload ambiguity via the implicit conversion to float — confirm.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const int& b) {
  return bfloat16(float(a) + static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const int& a, const bfloat16& b) {
  return bfloat16(static_cast<float>(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) * float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) - float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) / float(b));
}
// Negation flips the sign bit directly: exact, and preserves NaN payloads.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a) {
  numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
  return numext::bit_cast<bfloat16>(x);
}
291EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a, const bfloat16& b) {
292 a = bfloat16(float(a) + float(b));
293 return a;
294}
295EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a, const bfloat16& b) {
296 a = bfloat16(float(a) * float(b));
297 return a;
298}
299EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a, const bfloat16& b) {
300 a = bfloat16(float(a) - float(b));
301 return a;
302}
303EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a, const bfloat16& b) {
304 a = bfloat16(float(a) / float(b));
305 return a;
306}
307EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
308 a += bfloat16(1);
309 return a;
310}
311EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
312 a -= bfloat16(1);
313 return a;
314}
315EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
316 bfloat16 original_value = a;
317 ++a;
318 return original_value;
319}
320EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
321 bfloat16 original_value = a;
322 --a;
323 return original_value;
324}
// Comparisons, evaluated in float precision via the lossless widening
// conversion (so NaN operands behave as in IEEE-754 float comparisons).
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16& a, const bfloat16& b) {
  return numext::equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const bfloat16& a, const bfloat16& b) {
  return numext::not_equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const bfloat16& a, const bfloat16& b) {
  return float(a) < float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16& a, const bfloat16& b) {
  return float(a) <= float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16& a, const bfloat16& b) {
  return float(a) > float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const bfloat16& a, const bfloat16& b) {
  return float(a) >= float(b);
}
343
344#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
345#pragma pop_macro("EIGEN_DEVICE_FUNC")
346#endif
347#endif // Emulate support for bfloat16 floats
348
// Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to bfloat16.
// (Index is Eigen's signed index type.)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, Index b) {
  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
354
// Convert float -> bfloat16 by truncation (round toward zero): keep the high
// 16 bits of the binary32 representation.  NaNs are first canonicalized to a
// quiet NaN (0x7FC0 / 0xFFC0, sign preserved) so that dropping the low
// fraction bits cannot turn a NaN into an infinity.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
#if defined(EIGEN_USE_HIP_BF16)
  // HIP's conversion routine supports an explicit truncation mode.
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
#else
  __bfloat16_raw output;
  if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
    output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
    return output;
  }
  output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
  return output;
#endif
}
368
// Reinterpret a raw 16-bit pattern as a bfloat16 raw value; no numeric
// conversion is performed.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
#if defined(EIGEN_USE_HIP_BF16)
  // hip_bfloat16 exposes its bits through the `data` member.
  __bfloat16_raw bf;
  bf.data = value;
  return bf;
#else
  return __bfloat16_raw(value);
#endif
}
378
// Inverse of raw_uint16_to_bfloat16: expose the raw 16-bit pattern.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
    const __bfloat16_raw& bf) {
#if defined(EIGEN_USE_HIP_BF16)
  return bf.data;
#else
  return bf.value;
#endif
}
387
388// float_to_bfloat16_rtne template specialization that does not make any
389// assumption about the value of its function argument (ff).
390template <>
391EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
392#if defined(EIGEN_USE_HIP_BF16)
393 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
394#else
395 __bfloat16_raw output;
396
397 if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
398 // If the value is a NaN, squash it to a qNaN with msb of fraction set,
399 // this makes sure after truncation we don't end up with an inf.
400 //
401 // qNaN magic: All exponent bits set + most significant bit of fraction
402 // set.
403 output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
404 } else {
405 // Fast rounding algorithm that rounds a half value to nearest even. This
406 // reduces expected error when we convert a large number of floats. Here
407 // is how it works:
408 //
409 // Definitions:
410 // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
411 // with the following tags:
412 //
413 // Sign | Exp (8 bits) | Frac (23 bits)
414 // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
415 //
416 // S: Sign bit.
417 // E: Exponent bits.
418 // F: First 6 bits of fraction.
419 // L: Least significant bit of resulting bfloat16 if we truncate away the
420 // rest of the float32. This is also the 7th bit of fraction
421 // R: Rounding bit, 8th bit of fraction.
422 // T: Sticky bits, rest of fraction, 15 bits.
423 //
424 // To round half to nearest even, there are 3 cases where we want to round
425 // down (simply truncate the result of the bits away, which consists of
426 // rounding bit and sticky bits) and two cases where we want to round up
427 // (truncate then add one to the result).
428 //
429 // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
430 // 1s) as the rounding bias, adds the rounding bias to the input, then
431 // truncates the last 16 bits away.
432 //
433 // To understand how it works, we can analyze this algorithm case by case:
434 //
435 // 1. L = 0, R = 0:
436 // Expect: round down, this is less than half value.
437 //
438 // Algorithm:
439 // - Rounding bias: 0x7fff + 0 = 0x7fff
440 // - Adding rounding bias to input may create any carry, depending on
441 // whether there is any value set to 1 in T bits.
442 // - R may be set to 1 if there is a carry.
443 // - L remains 0.
444 // - Note that this case also handles Inf and -Inf, where all fraction
445 // bits, including L, R and Ts are all 0. The output remains Inf after
446 // this algorithm.
447 //
448 // 2. L = 1, R = 0:
449 // Expect: round down, this is less than half value.
450 //
451 // Algorithm:
452 // - Rounding bias: 0x7fff + 1 = 0x8000
453 // - Adding rounding bias to input doesn't change sticky bits but
454 // adds 1 to rounding bit.
455 // - L remains 1.
456 //
457 // 3. L = 0, R = 1, all of T are 0:
458 // Expect: round down, this is exactly at half, the result is already
459 // even (L=0).
460 //
461 // Algorithm:
462 // - Rounding bias: 0x7fff + 0 = 0x7fff
463 // - Adding rounding bias to input sets all sticky bits to 1, but
464 // doesn't create a carry.
465 // - R remains 1.
466 // - L remains 0.
467 //
468 // 4. L = 1, R = 1:
469 // Expect: round up, this is exactly at half, the result needs to be
470 // round to the next even number.
471 //
472 // Algorithm:
473 // - Rounding bias: 0x7fff + 1 = 0x8000
474 // - Adding rounding bias to input doesn't change sticky bits, but
475 // creates a carry from rounding bit.
476 // - The carry sets L to 0, creates another carry bit and propagate
477 // forward to F bits.
478 // - If all the F bits are 1, a carry then propagates to the exponent
479 // bits, which then creates the minimum value with the next exponent
480 // value. Note that we won't have the case where exponents are all 1,
481 // since that's either a NaN (handled in the other if condition) or inf
482 // (handled in case 1).
483 //
484 // 5. L = 0, R = 1, any of T is 1:
485 // Expect: round up, this is greater than half.
486 //
487 // Algorithm:
488 // - Rounding bias: 0x7fff + 0 = 0x7fff
489 // - Adding rounding bias to input creates a carry from sticky bits,
490 // sets rounding bit to 0, then create another carry.
491 // - The second carry sets L to 1.
492 //
493 // Examples:
494 //
495 // Exact half value that is already even:
496 // Input:
497 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
498 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
499 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
500 //
501 // This falls into case 3. We truncate the rest of 16 bits and no
502 // carry is created into F and L:
503 //
504 // Output:
505 // Sign | Exp (8 bit) | Frac (first 7 bit)
506 // S E E E E E E E E F F F F F F L
507 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
508 //
509 // Exact half value, round to next even number:
510 // Input:
511 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
512 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
513 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
514 //
515 // This falls into case 4. We create a carry from R and T,
516 // which then propagates into L and F:
517 //
518 // Output:
519 // Sign | Exp (8 bit) | Frac (first 7 bit)
520 // S E E E E E E E E F F F F F F L
521 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
522 //
523 //
524 // Max denormal value round to min normal value:
525 // Input:
526 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
527 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
528 // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
529 //
530 // This falls into case 4. We create a carry from R and T,
531 // propagate into L and F, which then propagates into exponent
532 // bits:
533 //
534 // Output:
535 // Sign | Exp (8 bit) | Frac (first 7 bit)
536 // S E E E E E E E E F F F F F F L
537 // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
538 //
539 // Max normal value round to Inf:
540 // Input:
541 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
542 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
543 // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
544 //
545 // This falls into case 4. We create a carry from R and T,
546 // propagate into L and F, which then propagates into exponent
547 // bits:
548 //
549 // Sign | Exp (8 bit) | Frac (first 7 bit)
550 // S E E E E E E E E F F F F F F L
551 // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
552
553 // At this point, ff must be either a normal float, or +/-infinity.
554 output = float_to_bfloat16_rtne<true>(ff);
555 }
556 return output;
557#endif
558}
559
// float_to_bfloat16_rtne template specialization that assumes that its function
// argument (ff) is either a normal floating point number, or +/-infinity, or
// zero. Used to improve the runtime performance of conversion from an integer
// type to bfloat16.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
#if defined(EIGEN_USE_HIP_BF16)
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
#else
  numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
  __bfloat16_raw output;

  // Least significant bit of resulting bfloat.
  numext::uint32_t lsb = (input >> 16) & 1;
  // Adding 0x7fff rounds up whenever the discarded low 16 bits exceed the
  // halfway point; adding the lsb on top turns an exact tie into
  // round-to-even.  Any carry propagates naturally into the kept bits
  // (fraction, then exponent).  See the case analysis in the <false>
  // specialization above.
  numext::uint32_t rounding_bias = 0x7fff + lsb;
  input += rounding_bias;
  output.value = static_cast<numext::uint16_t>(input >> 16);
  return output;
#endif
}
580
// Widen a bfloat16 to float.  This is exact: bfloat16 is the upper half of
// the binary32 bit pattern, so shifting left by 16 reconstructs the value.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
#if defined(EIGEN_USE_HIP_BF16)
  return static_cast<float>(h);
#else
  return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
#endif
}
588
589// --- standard functions ---
590
// Classification functions.  The parenthesized names prevent accidental
// macro expansion on platforms where isinf/isnan are defined as macros.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const bfloat16& a) {
  EIGEN_USING_STD(isinf);
#if defined(EIGEN_USE_HIP_BF16)
  return (isinf)(a);  // Uses HIP hip_bfloat16 isinf operator
#else
  return (isinf)(float(a));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const bfloat16& a) {
  EIGEN_USING_STD(isnan);
#if defined(EIGEN_USE_HIP_BF16)
  return (isnan)(a);  // Uses HIP hip_bfloat16 isnan operator
#else
  return (isnan)(float(a));
#endif
}
// Finite <=> neither infinite nor NaN.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const bfloat16& a) {
  return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
}
610
// Elementary math functions: widen to float, call the float routine, and
// round the result back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
  // Clearing the sign bit is exact (and works for NaN and inf as well).
  numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
  return numext::bit_cast<bfloat16>(x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { return bfloat16(::expf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp2(const bfloat16& a) { return bfloat16(::exp2f(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { return bfloat16(numext::expm1(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { return bfloat16(::logf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { return bfloat16(numext::log1p(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { return bfloat16(::log10f(float(a))); }
// log2(x) computed as log2(e) * ln(x).
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
  return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { return bfloat16(::sqrtf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::powf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::atan2f(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { return bfloat16(::sinf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { return bfloat16(::cosf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { return bfloat16(::tanf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { return bfloat16(::asinf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) { return bfloat16(::acosf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { return bfloat16(::atanf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { return bfloat16(::sinhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { return bfloat16(::coshf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { return bfloat16(::tanhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { return bfloat16(::asinhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { return bfloat16(::acoshf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { return bfloat16(::atanhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { return bfloat16(::floorf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { return bfloat16(::ceilf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { return bfloat16(::rintf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { return bfloat16(::roundf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(const bfloat16& a) { return bfloat16(::truncf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::fmodf(float(a), float(b)));
}
651
652EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(const bfloat16& a, const bfloat16& b) {
653 const float f1 = static_cast<float>(a);
654 const float f2 = static_cast<float>(b);
655 return f2 < f1 ? b : a;
656}
657
658EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(const bfloat16& a, const bfloat16& b) {
659 const float f1 = static_cast<float>(a);
660 const float f2 = static_cast<float>(b);
661 return f1 < f2 ? b : a;
662}
663
664EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
665 const float f1 = static_cast<float>(a);
666 const float f2 = static_cast<float>(b);
667 return bfloat16(::fminf(f1, f2));
668}
669
670EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
671 const float f1 = static_cast<float>(a);
672 const float f2 = static_cast<float>(b);
673 return bfloat16(::fmaxf(f1, f2));
674}
675
// Fused multiply-add a*b + c, emulated by calling numext::fma in float and
// rounding the float result back to bfloat16.
EIGEN_DEVICE_FUNC inline bfloat16 fma(const bfloat16& a, const bfloat16& b, const bfloat16& c) {
  // Emulate FMA via float.
  return bfloat16(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
}
680
681#ifndef EIGEN_NO_IO
682EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) {
683 os << static_cast<float>(v);
684 return os;
685}
686#endif
687
688} // namespace bfloat16_impl
689
690namespace internal {
691
// Mark bfloat16 as an arithmetic scalar for Eigen's internal type traits.
template <>
struct is_arithmetic<bfloat16> {
  enum { value = true };
};
696
// Random generation for bfloat16: delegate to the float implementation,
// restricted to bfloat16's 7 explicit mantissa bits so every generated value
// is exactly representable.
template <>
struct random_impl<bfloat16> {
  enum : int { MantissaBits = 7 };
  using Impl = random_impl<float>;
  // Value in the range [x, y].  NOTE(review): exact range semantics follow
  // random_impl<float>::run — confirm inclusivity there.
  static EIGEN_DEVICE_FUNC inline bfloat16 run(const bfloat16& x, const bfloat16& y) {
    float result = Impl::run(x, y, MantissaBits);
    return bfloat16(result);
  }
  // Default-range variant.
  static EIGEN_DEVICE_FUNC inline bfloat16 run() {
    float result = Impl::run(MantissaBits);
    return bfloat16(result);
  }
};
710
711} // namespace internal
712
// Eigen scalar traits for bfloat16.  The raw bit patterns below follow the
// 1-sign / 8-exponent / 7-fraction layout.
template <>
struct NumTraits<Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
  enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };

  // Machine epsilon: 2^-7 = 0.0078125.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);  // bfloat16(5e-2f);
  }
  // Largest finite value.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
  }
  // Most negative finite value (sign-flipped highest()).
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
  }
};
736
737} // namespace Eigen
738
739#if defined(EIGEN_HAS_HIP_BF16)
740#pragma pop_macro("EIGEN_CONSTEXPR")
741#endif
742
743namespace Eigen {
744namespace numext {
745
// numext::isnan for bfloat16: forwards to the bfloat16_impl overload.
// The parentheses around the name prevent expansion of a same-named macro.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isnan)(h);
}
750
// numext::isinf for bfloat16: forwards to the bfloat16_impl overload.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isinf)(h);
}
755
// numext::isfinite for bfloat16: forwards to the bfloat16_impl overload.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isfinite)(h);
}
760
// bit_cast uint16_t -> bfloat16: reinterprets the 16 raw bits as a bfloat16
// (no numeric conversion). Defines the specialization declared at the top of
// this header.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
}
765
// bit_cast bfloat16 -> uint16_t: exposes the raw 16-bit representation
// (no numeric conversion).
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
  return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
}
770
// Returns the next representable bfloat16 after `from` in the direction of
// `to`, by stepping the raw 16-bit pattern by one. NaN inputs and from == to
// are passed through, mirroring std::nextafter.
//
// NOTE(review): when `from` is a zero and `to` lies on the other side, this
// returns the opposite-signed zero rather than the smallest subnormal that
// strict IEEE-754 nextafter would give (the sign-toggle below is deliberate
// per the original comment); a second call then reaches the subnormal.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, const bfloat16& to) {
  if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
    return from;
  }
  if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
    return to;
  }
  if (from == to) {
    // Equal values (including +0 == -0): return `to` so its sign is kept.
    return to;
  }
  uint16_t from_bits = numext::bit_cast<uint16_t>(from);
  bool from_sign = from_bits >> 15;
  // Whether we are adjusting toward the infinity with the same sign as from.
  bool toward_inf = (to > from) == !from_sign;
  if (toward_inf) {
    // Growing in magnitude: incrementing the bit pattern moves one ulp away
    // from zero (works across subnormal/normal and exponent boundaries).
    ++from_bits;
  } else if ((from_bits & 0x7fff) == 0) {
    // Adjusting away from inf, but from is zero, so just toggle the sign.
    from_bits ^= 0x8000;
  } else {
    // Shrinking in magnitude: decrementing moves one ulp toward zero.
    --from_bits;
  }
  return numext::bit_cast<bfloat16>(from_bits);
}
795
796// Specialize multiply-add to match packet operations and reduce conversions to/from float.
797template<>
798EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
799 return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
800}
801
802} // namespace numext
803} // namespace Eigen
804
805#if EIGEN_HAS_STD_HASH
806namespace std {
807template <>
808struct hash<Eigen::bfloat16> {
809 EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
810 return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
811 }
812};
813} // namespace std
814#endif
815
816// Add the missing shfl* intrinsics.
817// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
818// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
819//
820// HIP and CUDA prior to SDK 9.0 define
821// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
822// CUDA since 9.0 deprecates those and instead defines
823// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
824// with native support for __half and __nv_bfloat16
825//
826// Note that the following are __device__ - only functions.
827#if defined(EIGEN_HIPCC)
828
829#if defined(EIGEN_HAS_HIP_BF16)
830
831__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var, int srcLane, int width = warpSize) {
832 const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
833 return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
834}
835
836__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var, unsigned int delta,
837 int width = warpSize) {
838 const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
839 return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
840}
841
842__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var, unsigned int delta,
843 int width = warpSize) {
844 const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
845 return Eigen::numext::bit_cast<Eigen::bfloat16>(
846 static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
847}
848
849__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var, int laneMask, int width = warpSize) {
850 const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
851 return Eigen::numext::bit_cast<Eigen::bfloat16>(
852 static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
853}
854
855#endif // HIP
856
857#endif // __shfl*
858
859#if defined(EIGEN_HIPCC)
// Read-only cached load for bfloat16: loads the raw 16-bit payload through
// the uint16_t overload of __ldg and rebuilds the bfloat16 from it.
EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(const Eigen::bfloat16* ptr) {
  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
      __ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
}
864#endif // __ldg
865
866#endif // EIGEN_BFLOAT16_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition NumTraits.h:232