Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TypeCasting.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TYPE_CASTING_AVX512_H
11#define EIGEN_TYPE_CASTING_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <>
21struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
22template <>
23struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
24
25template <>
26struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
27template <>
28struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
29
30template <>
31struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
32template <>
33struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
34
35template <>
36struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
37template <>
38struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};
39
40template <>
41struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
42template <>
43struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
44
45template <>
46struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
47template <>
48struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
49
50template <>
51struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
52template <>
53struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
54
55template <>
56EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
57 __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
58 return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
59}
60
61template <>
62EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
63 return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
64}
65
66template <>
67EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
68 return _mm512_cvttps_epi32(a);
69}
70
71template <>
72EIGEN_STRONG_INLINE Packet8d pcast<Packet16f, Packet8d>(const Packet16f& a) {
73 return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
74}
75
76template <>
77EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
78 return _mm512_cvtps_pd(a);
79}
80
81template <>
82EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
83#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
84 return _mm512_cvttpd_epi64(a);
85#else
86 constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
87 kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
88
89 const __m512i cst_one = _mm512_set1_epi64(1);
90 const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
91 const __m512i cst_bias = _mm512_set1_epi64(kBias);
92
93 __m512i a_bits = _mm512_castpd_si512(a);
94 // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
95 __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
96 __m512i e = _mm512_sub_epi64(biased_e, cst_bias);
97
98 // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
99 __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
100 // shift to the right by kTotalBits - e to convert the significand to an integer
101 __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
102
103 // add the implied bit
104 __m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
105 // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
106 __m512i result = _mm512_add_epi64(result_significand, result_exponent);
107 // handle negative arguments
108 __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
109 result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
110 return result;
111#endif
112}
113
114template <>
115EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
116 return _mm512_cvtepi32_ps(a);
117}
118
119template <>
120EIGEN_STRONG_INLINE Packet8d pcast<Packet16i, Packet8d>(const Packet16i& a) {
121 return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
122}
123
124template <>
125EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
126 return _mm512_cvtepi32_pd(a);
127}
128
129template <>
130EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
131#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
132 return _mm512_cvtepi64_pd(a);
133#else
134 EIGEN_ALIGN64 int64_t aux[8];
135 pstore(aux, a);
136 return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
137 static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
138 static_cast<double>(aux[1]), static_cast<double>(aux[0]));
139#endif
140}
141
142template <>
143EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
144 return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
145}
146
147template <>
148EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
149 return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
150}
151
152template <>
153EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
154 return _mm512_cvtpd_epi32(a);
155}
156template <>
157EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
158 return _mm512_cvtpd_ps(a);
159}
160
161template <>
162EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
163 return _mm512_castps_si512(a);
164}
165
166template <>
167EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
168 return _mm512_castsi512_ps(a);
169}
170
171template <>
172EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
173 return _mm512_castps_pd(a);
174}
175
176template <>
177EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
178 return _mm512_castsi512_pd(a);
179}
180
181template <>
182EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
183 return _mm512_castpd_si512(a);
184}
185
186template <>
187EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
188 return _mm512_castpd_ps(a);
189}
190
191template <>
192EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
193 return _mm512_castps512_ps256(a);
194}
195
196template <>
197EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
198 return _mm512_castps512_ps128(a);
199}
200
201template <>
202EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
203 return _mm512_castpd512_pd256(a);
204}
205
206template <>
207EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
208 return _mm512_castpd512_pd128(a);
209}
210
211template <>
212EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
213 return _mm512_castps256_ps512(a);
214}
215
216template <>
217EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
218 return _mm512_castps128_ps512(a);
219}
220
221template <>
222EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
223 return _mm512_castpd256_pd512(a);
224}
225
226template <>
227EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
228 return _mm512_castpd128_pd512(a);
229}
230
231template <>
232EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
233 return _mm512_castsi512_si256(a);
234}
235template <>
236EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
237 return _mm512_castsi512_si128(a);
238}
239
240#ifndef EIGEN_VECTORIZE_AVX512FP16
241template <>
242EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
243 return _mm256_castsi256_si128(a);
244}
245
246template <>
247EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
248 return half2float(a);
249}
250
251template <>
252EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
253 return float2half(a);
254}
255
256#endif
257
258template <>
259EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
260 return _mm256_castsi256_si128(a);
261}
262
263template <>
264EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
265 return Bf16ToF32(a);
266}
267
268template <>
269EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
270 return F32ToBf16(a);
271}
272
273} // end namespace internal
274
275} // end namespace Eigen
276
277#endif // EIGEN_TYPE_CASTING_AVX512_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1