Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
PacketMathFP16.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2025 The Eigen Authors.
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
11#define EIGEN_PACKET_MATH_FP16_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Packet types for AVX512-FP16: 32, 16 and 8 half-precision lanes respectively.
typedef __m512h Packet32h;
typedef __m256h Packet16h;
typedef __m128h Packet8h;

// Packet8h wraps a plain vector register type, so mark it as arithmetic.
template <>
struct is_arithmetic<Packet8h> {
  enum { value = true };
};

// Vectorization traits for Eigen::half with native FP16 arithmetic:
// the preferred packet is the full 512-bit Packet32h, with Packet16h
// as its half-sized counterpart.
template <>
struct packet_traits<half> : default_packet_traits {
  typedef Packet32h type;
  typedef Packet16h half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 32,

    HasCmp = 1,
    HasAdd = 1,
    HasSub = 1,
    HasMul = 1,
    HasDiv = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasAbs2 = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasLog = 1,
    HasLog1p = 1,
    HasExp = 1,
    HasExpm1 = 1,
    HasSqrt = 1,
    HasRsqrt = 1,
    // These ones should be implemented in future
    HasBessel = 0,
    HasNdtri = 0,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = 0,  // EIGEN_FAST_MATH,
    HasBlend = 0
  };
};

// Reverse-mapping traits: scalar type, half-sized packet, matching integer
// packet, lane count and required load/store alignment for each FP16 packet.
template <>
struct unpacket_traits<Packet32h> {
  typedef Eigen::half type;
  typedef Packet16h half;
  typedef Packet32s integer_packet;
  enum {
    size = 32,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
struct unpacket_traits<Packet16h> {
  typedef Eigen::half type;
  typedef Packet8h half;
  typedef Packet16s integer_packet;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

// Note: the "half" packet of Packet8h is Packet8h itself (smallest packet).
template <>
struct unpacket_traits<Packet8h> {
  typedef Eigen::half type;
  typedef Packet8h half;
  typedef Packet8s integer_packet;
  enum {
    size = 8,
    alignment = Aligned16,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
108
// Conversions

// Widen half packets to same-lane-count float packets, and narrow back,
// using the native extended conversion intrinsics (VCVTPH2PSX / VCVTPS2PHX).
EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtxph_ps(a); }

EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { return _mm256_cvtxph_ps(a); }

EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { return _mm512_cvtxps_ph(a); }

EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { return _mm256_cvtxps_ph(a); }
118
// Memory functions

// pset1

// Broadcast a single half value to every lane of the packet.
template <>
EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
  return _mm512_set1_ph(from.x);
}

template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
  return _mm256_set1_ph(from.x);
}

template <>
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
  return _mm_set1_ph(from.x);
}

// Return an all-zero packet; the (unused) argument only selects the overload.
template <>
EIGEN_STRONG_INLINE Packet32h pzero(const Packet32h& /*a*/) {
  return _mm512_setzero_ph();
}

template <>
EIGEN_STRONG_INLINE Packet16h pzero(const Packet16h& /*a*/) {
  return _mm256_setzero_ph();
}

template <>
EIGEN_STRONG_INLINE Packet8h pzero(const Packet8h& /*a*/) {
  return _mm_setzero_ph();
}

// pset1frombits
// Broadcast a raw 16-bit pattern, reinterpreted as a half, to every lane.
template <>
EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
  return _mm512_castsi512_ph(_mm512_set1_epi16(from));
}

template <>
EIGEN_STRONG_INLINE Packet16h pset1frombits<Packet16h>(unsigned short from) {
  return _mm256_castsi256_ph(_mm256_set1_epi16(from));
}

template <>
EIGEN_STRONG_INLINE Packet8h pset1frombits<Packet8h>(unsigned short from) {
  return _mm_castsi128_ph(_mm_set1_epi16(from));
}

// pfirst

// Extract lane 0 of the packet as an Eigen::half.
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
  return Eigen::half(_mm512_cvtsh_h(from));
}

template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
  return Eigen::half(_mm256_cvtsh_h(from));
}

template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
  return Eigen::half(_mm_cvtsh_h(from));
}
185
// pload

// Aligned loads: `from` must satisfy the packet's alignment requirement
// (Aligned64 / Aligned32 / Aligned16, see unpacket_traits above).
template <>
EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
}

template <>
EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ph(from);
}

template <>
EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ph(from);
}

// ploadu

// Unaligned loads: no alignment requirement on `from`.
template <>
EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
}

template <>
EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ph(from);
}

template <>
EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_ph(from);
}

// pstore

// Aligned stores: `to` must satisfy the packet's alignment requirement.
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ph(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet8h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm_store_ph(to, from);
}

// pstoreu

// Unaligned stores: no alignment requirement on `to`.
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ph(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet8h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ph(to, from);
}
253
// ploaddup
// Load size/2 elements and duplicate each: [a,b,...] -> [a,a,b,b,...].
template <>
EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
  // Load 16 halves into the lower 256 bits; the upper half of `a` is
  // undefined but never referenced by the permute indices (all < 16).
  __m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
  return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6,
                                                5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
                               a);
}

template <>
EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
  // Same trick as above: 8 halves loaded low, duplicated by a cross-lane permute.
  __m256h a = _mm256_castph128_ph256(_mm_loadu_ph(from));
  return _mm256_permutexvar_ph(_mm256_set_epi16(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), a);
}

template <>
EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) {
  // Only 4 scalars needed; a direct set is simplest at this width.
  return _mm_set_ph(from[3].x, from[3].x, from[2].x, from[2].x, from[1].x, from[1].x, from[0].x, from[0].x);
}

// ploadquad
// Load size/4 elements and replicate each 4 times: [a,b,...] -> [a,a,a,a,b,b,b,b,...].
template <>
EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
  // Load 8 halves into the lower 128 bits; permute indices stay below 8.
  __m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
  return _mm512_permutexvar_ph(
      _mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0),
      a);
}

template <>
EIGEN_STRONG_INLINE Packet16h ploadquad<Packet16h>(const Eigen::half* from) {
  return _mm256_set_ph(from[3].x, from[3].x, from[3].x, from[3].x, from[2].x, from[2].x, from[2].x, from[2].x,
                       from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}

template <>
EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) {
  return _mm_set_ph(from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}
293
// pabs

// Lane-wise absolute value (clears the sign bit).
template <>
EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
  return _mm512_abs_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet16h pabs<Packet16h>(const Packet16h& a) {
  return _mm256_abs_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet8h pabs<Packet8h>(const Packet8h& a) {
  return _mm_abs_ph(a);
}

// psignbit

// Per lane: all-ones if the sign bit is set, all-zeros otherwise.
// Implemented as an arithmetic right shift by 15, which smears the
// sign bit across the whole 16-bit lane.
template <>
EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
  return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
}

template <>
EIGEN_STRONG_INLINE Packet16h psignbit<Packet16h>(const Packet16h& a) {
  return _mm256_castsi256_ph(_mm256_srai_epi16(_mm256_castph_si256(a), 15));
}

template <>
EIGEN_STRONG_INLINE Packet8h psignbit<Packet8h>(const Packet8h& a) {
  return _mm_castsi128_ph(_mm_srai_epi16(_mm_castph_si128(a), 15));
}
327
// pmin

// Lane-wise minimum via the hardware min instruction.
template <>
EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_min_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_min_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_min_ph(a, b);
}

// pmax

// Lane-wise maximum via the hardware max instruction.
template <>
EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_max_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_max_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_max_ph(a, b);
}

// plset
// Returns {a, a+1, a+2, ...}: broadcast `a` and add the lane-index ramp.
template <>
EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
  return _mm512_add_ph(pset1<Packet32h>(a), _mm512_set_ph(31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
                                                          16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}

template <>
EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
  return _mm256_add_ph(pset1<Packet16h>(a), _mm256_set_ph(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}

template <>
EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
  return _mm_add_ph(pset1<Packet8h>(a), _mm_set_ph(7, 6, 5, 4, 3, 2, 1, 0));
}
378
// por

// Bitwise OR of the raw half bit patterns (round-trips through the
// integer domain since there are no FP16 logical intrinsics).
template <>
EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_or_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}

template <>
EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_or_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pxor

// Bitwise XOR of the raw half bit patterns.
template <>
EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pand

// Bitwise AND of the raw half bit patterns.
template <>
EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_and_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pandnot

// Computes a & ~b. The andnot intrinsics negate their FIRST operand
// (~x & y), hence b is passed first and a second.
template <>
EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_andnot_si256(_mm256_castph_si256(b), _mm256_castph_si256(a)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_andnot_si128(_mm_castph_si128(b), _mm_castph_si128(a)));
}
446
447// pselect
448
449template <>
450EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
451 __mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
452 return _mm512_mask_blend_ph(mask32, a, b);
453}
454
455template <>
456EIGEN_DEVICE_FUNC inline Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
457 __mmask16 mask16 = _mm256_cmp_epi16_mask(_mm256_castph_si256(mask), _mm256_setzero_si256(), _MM_CMPINT_EQ);
458 return _mm256_mask_blend_ph(mask16, a, b);
459}
460
461template <>
462EIGEN_DEVICE_FUNC inline Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
463 __mmask8 mask8 = _mm_cmp_epi16_mask(_mm_castph_si128(mask), _mm_setzero_si128(), _MM_CMPINT_EQ);
464 return _mm_mask_blend_ph(mask8, a, b);
465}
466
// pcmp_eq

// Lane-wise a == b: true lanes become all-ones, false lanes all-zeros.
// _CMP_*_OQ predicates are ordered-quiet: any NaN operand yields false.
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

// pcmp_le

// Lane-wise a <= b (ordered-quiet, NaN -> false).
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

// pcmp_lt

// Lane-wise a < b (ordered-quiet, NaN -> false).
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
526
527// pcmp_lt_or_nan
528
529template <>
530EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
531 __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
532 return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, static_cast<short>(0xffffu)));
533}
534
535template <>
536EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
537 __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_NGE_UQ);
538 return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
539}
540
541template <>
542EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
543 __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_NGE_UQ);
544 return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
545}
546
// padd

// Lane-wise addition.
template <>
EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_add_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_add_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_add_ph(a, b);
}

// psub

// Lane-wise subtraction (a - b).
template <>
EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_sub_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_sub_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_sub_ph(a, b);
}

// pmul

// Lane-wise multiplication.
template <>
EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_mul_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_mul_ph(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_mul_ph(a, b);
}
597
598// pdiv
599
600template <>
601EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
602 return _mm512_div_ph(a, b);
603}
604
605template <>
606EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
607 return _mm256_div_ph(a, b);
608}
609
610template <>
611EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
612 return _mm_div_ph(a, b);
613 ;
614}
615
// pround

// Round half away from zero (matching std::round), rather than the
// default round-half-to-even: add copysign(0.5 - 1ulp, a) and truncate.
template <>
EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet32h signMask =
      pset1frombits<Packet32h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));

  // a + copysign(prev0dot5, a), then truncate toward zero.
  return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}

template <>
EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet16h signMask =
      pset1frombits<Packet16h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet16h prev0dot5 = pset1frombits<Packet16h>(static_cast<numext::uint16_t>(0x37FFu));

  return _mm256_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}

template <>
EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet8h signMask = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet8h prev0dot5 = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(0x37FFu));

  return _mm_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
655
// print

// Round to integral value using the current MXCSR rounding mode.
template <>
EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}

template <>
EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}

template <>
EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}

// pceil

// Round toward positive infinity.
template <>
EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}

template <>
EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}

template <>
EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}

// pfloor

// Round toward negative infinity.
template <>
EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}

template <>
EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}

template <>
EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}

// ptrunc

// Round toward zero (truncate).
template <>
EIGEN_STRONG_INLINE Packet32h ptrunc<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}

template <>
EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}

template <>
EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
723
// predux
// Horizontal sum of all lanes via the hardware reduction intrinsics.
template <>
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_add_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_add_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_add_ph(a));
}
739
740// predux_half_dowto4
741template <>
742EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
743 const __m512i bits = _mm512_castph_si512(a);
744 Packet16h lo = _mm256_castsi256_ph(_mm512_castsi512_si256(bits));
745 Packet16h hi = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1));
746 return padd(lo, hi);
747}
748
749template <>
750EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
751 Packet8h lo = _mm_castsi128_ph(_mm256_castsi256_si128(_mm256_castph_si256(a)));
752 Packet8h hi = _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(a), 1));
753 return padd(lo, hi);
754}
755
// predux_max

// Horizontal maximum of all lanes.
template <>
EIGEN_STRONG_INLINE half predux_max<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_max_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_max<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_max_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_max<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_max_ph(a));
}

// predux_min

// Horizontal minimum of all lanes.
template <>
EIGEN_STRONG_INLINE half predux_min<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_min_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_min<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_min_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_min<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_min_ph(a));
}

// predux_mul

// Horizontal product of all lanes.
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_mul_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_mul_ph(a));
}

template <>
EIGEN_STRONG_INLINE half predux_mul<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_mul_ph(a));
}
806
#ifdef EIGEN_VECTORIZE_FMA

// pmadd

// Fused multiply-add: a * b + c with a single rounding step.
template <>
EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fmadd_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fmadd_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fmadd_ph(a, b, c);
}

// pmsub

// Fused multiply-subtract: a * b - c.
template <>
EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fmsub_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fmsub_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fmsub_ph(a, b, c);
}

// pnmadd

// Fused negated multiply-add: -(a * b) + c.
template <>
EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fnmadd_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fnmadd_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fnmadd_ph(a, b, c);
}

// pnmsub

// Fused negated multiply-subtract: -(a * b) - c.
template <>
EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fnmsub_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fnmsub_ph(a, b, c);
}

template <>
EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fnmsub_ph(a, b, c);
}

#endif
878
// pnegate

// Negate by XOR-ing the sign bit (0x8000) of every lane; this also flips
// the sign of zeros and NaNs, unlike a subtraction from zero.
template <>
EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
  return _mm512_castsi512_ph(
      _mm512_xor_si512(_mm512_castph_si512(a), _mm512_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}

template <>
EIGEN_STRONG_INLINE Packet16h pnegate<Packet16h>(const Packet16h& a) {
  return _mm256_castsi256_ph(
      _mm256_xor_si256(_mm256_castph_si256(a), _mm256_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}

template <>
EIGEN_STRONG_INLINE Packet8h pnegate<Packet8h>(const Packet8h& a) {
  return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
897
// pconj

// Nothing, packets are real.

// psqrt

// Square root: refine the hardware reciprocal-sqrt estimate with a
// Newton step (generic_sqrt_newton_step also handles the special cases).
template <>
EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
  return generic_sqrt_newton_step<Packet32h>::run(a, _mm512_rsqrt_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet16h psqrt<Packet16h>(const Packet16h& a) {
  return generic_sqrt_newton_step<Packet16h>::run(a, _mm256_rsqrt_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet8h psqrt<Packet8h>(const Packet8h& a) {
  return generic_sqrt_newton_step<Packet8h>::run(a, _mm_rsqrt_ph(a));
}

// prsqrt

// Reciprocal square root: one Newton step on the hardware estimate.
template <>
EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
  return generic_rsqrt_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rsqrt_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet16h prsqrt<Packet16h>(const Packet16h& a) {
  return generic_rsqrt_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rsqrt_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet8h prsqrt<Packet8h>(const Packet8h& a) {
  return generic_rsqrt_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rsqrt_ph(a));
}

// preciprocal

// Reciprocal: one Newton step on the hardware rcp estimate.
template <>
EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
  return generic_reciprocal_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rcp_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet16h preciprocal<Packet16h>(const Packet16h& a) {
  return generic_reciprocal_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rcp_ph(a));
}

template <>
EIGEN_STRONG_INLINE Packet8h preciprocal<Packet8h>(const Packet8h& a) {
  return generic_reciprocal_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rcp_ph(a));
}
952
// ptranspose

// In-place transpose of a 32x32 block of halves held in 32 Packet32h
// registers. Works in the integer domain through four stages: 16-bit
// unpacks, 32-bit unpacks, 64-bit unpacks, then 128-bit lane shuffles.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
  __m512i t[32];

  // Stage 1: interleave 16-bit elements of each adjacent row pair.
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 16; i++) {
    t[2 * i] = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
    t[2 * i + 1] =
        _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
  }

  __m512i p[32];

  // Stage 2: interleave 32-bit pairs across groups of four.
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 8; i++) {
    p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
    p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
    p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
    p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
  }

  __m512i q[32];

  // Stage 3: interleave 64-bit quads across groups of eight.
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 4; i++) {
    q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
    q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
    q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
    q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
    q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
    q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
    q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
    q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
  }

  __m512i f[32];

// Stage 4: move 128-bit lanes into their final positions. Invocation
// (X, Y) copies lane Y of q[X*8 + k] into lane X of f[Y*8 + k] for
// k = 0..7. The 16 invocations below cover all (X, Y) combinations, so
// every lane of every f[] register is written exactly once; reads of the
// not-yet-complete f[] values only feed lanes that are overwritten later.
#define PACKET32H_TRANSPOSE_HELPER(X, Y) \
  do { \
    f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X); \
    f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
    f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
    f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
    f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
    f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
    f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
    f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
  } while (false);

  PACKET32H_TRANSPOSE_HELPER(0, 0);
  PACKET32H_TRANSPOSE_HELPER(1, 1);
  PACKET32H_TRANSPOSE_HELPER(2, 2);
  PACKET32H_TRANSPOSE_HELPER(3, 3);

  PACKET32H_TRANSPOSE_HELPER(1, 0);
  PACKET32H_TRANSPOSE_HELPER(2, 0);
  PACKET32H_TRANSPOSE_HELPER(3, 0);
  PACKET32H_TRANSPOSE_HELPER(2, 1);
  PACKET32H_TRANSPOSE_HELPER(3, 1);
  PACKET32H_TRANSPOSE_HELPER(3, 2);

  PACKET32H_TRANSPOSE_HELPER(0, 1);
  PACKET32H_TRANSPOSE_HELPER(0, 2);
  PACKET32H_TRANSPOSE_HELPER(0, 3);
  PACKET32H_TRANSPOSE_HELPER(1, 2);
  PACKET32H_TRANSPOSE_HELPER(1, 3);
  PACKET32H_TRANSPOSE_HELPER(2, 3);

#undef PACKET32H_TRANSPOSE_HELPER

  // Write the transposed rows back as half packets.
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 32; i++) {
    a.packet[i] = _mm512_castsi512_ph(f[i]);
  }
}
1029
// In-register transpose of a block of 4 packets of 32 halves. Viewing the block as a
// 4x32 matrix (row r = a.packet[r]), the result packs the transpose so that output
// packet i, element 4*k + c equals input packet c, element 8*i + k: packet i holds
// input columns 8*i .. 8*i+7, each column stored as 4 consecutive halves.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
  __m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
  // Step 1: interleave rows pairwise at 16-bit granularity (casts reinterpret the
  // fp16 packets as integers so the integer unpack intrinsics can be used).
  // Within each 128-bit lane l, t0 holds {r0[8l+k], r1[8l+k]} for k = 0..3 and
  // t1 the same for k = 4..7; t2/t3 likewise for rows 2 and 3.
  t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
  t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
  t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
  t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));

  // Step 2: interleave at 32-bit granularity so every 64-bit group becomes one
  // complete 4-element column. Per 128-bit lane l: p0 = columns 8l, 8l+1;
  // p1 = columns 8l+2, 8l+3; p2 = columns 8l+4, 8l+5; p3 = columns 8l+6, 8l+7.
  p0 = _mm512_unpacklo_epi32(t0, t2);
  p1 = _mm512_unpackhi_epi32(t0, t2);
  p2 = _mm512_unpacklo_epi32(t1, t3);
  p3 = _mm512_unpackhi_epi32(t1, t3);

  a0 = p0;
  a1 = p1;
  a2 = p2;
  a3 = p3;

  // Step 3: regroup 128-bit lanes so a_i collects lane i of p0..p3, i.e.
  // a_i = columns 8i .. 8i+7. Each insert overwrites lane X of the destination
  // with lane Y extracted from the matching p register; the pairs below cover
  // every (source, destination) lane combination exactly once.
  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);

  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);

  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);

  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);

  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);

  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);

  // Reinterpret the shuffled integer registers back as fp16 packets.
  a.packet[0] = _mm512_castsi512_ph(a0);
  a.packet[1] = _mm512_castsi512_ph(a1);
  a.packet[2] = _mm512_castsi512_ph(a2);
  a.packet[3] = _mm512_castsi512_ph(a3);
}
1070
// Full 16x16 in-register transpose: after the call, kernel.packet[i] holds what was
// column i of the original block (element j of output packet i = element i of input
// packet j). Built from successively wider unpacks (16 -> 32 -> 64 bit) followed by
// a 128-bit lane permute, the standard AVX2-style transpose pattern.
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
  // Reinterpret the fp16 packets as integer vectors; rows are named a..p.
  __m256i a = _mm256_castph_si256(kernel.packet[0]);
  __m256i b = _mm256_castph_si256(kernel.packet[1]);
  __m256i c = _mm256_castph_si256(kernel.packet[2]);
  __m256i d = _mm256_castph_si256(kernel.packet[3]);
  __m256i e = _mm256_castph_si256(kernel.packet[4]);
  __m256i f = _mm256_castph_si256(kernel.packet[5]);
  __m256i g = _mm256_castph_si256(kernel.packet[6]);
  __m256i h = _mm256_castph_si256(kernel.packet[7]);
  __m256i i = _mm256_castph_si256(kernel.packet[8]);
  __m256i j = _mm256_castph_si256(kernel.packet[9]);
  __m256i k = _mm256_castph_si256(kernel.packet[10]);
  __m256i l = _mm256_castph_si256(kernel.packet[11]);
  __m256i m = _mm256_castph_si256(kernel.packet[12]);
  __m256i n = _mm256_castph_si256(kernel.packet[13]);
  __m256i o = _mm256_castph_si256(kernel.packet[14]);
  __m256i p = _mm256_castph_si256(kernel.packet[15]);

  // Step 1: interleave adjacent row pairs at 16-bit granularity (per 128-bit lane).
  // The _07 results come from unpacklo (low halves of each lane) ...
  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
  __m256i op_07 = _mm256_unpacklo_epi16(o, p);

  // ... and the _8f results from unpackhi (high halves of each lane).
  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
  __m256i op_8f = _mm256_unpackhi_epi16(o, p);

  // Step 2: interleave at 32-bit granularity, combining groups of 4 rows.
  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);

  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);

  // Step 3: interleave at 64-bit granularity, combining groups of 8 rows. Each
  // result now holds two complete 8-element column fragments per 128-bit lane.
  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);

  // Step 4: combine 128-bit lanes across the abcdefgh/ijklmnop halves.
  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  // 0x20 selects (low lane of arg0, low lane of arg1); 0x31 the two high lanes.
  __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
  __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
  __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
  __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
  __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
  __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
  __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
  __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
  __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
  __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
  __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
  __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
  __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
  __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
  __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
  __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);

  // Reinterpret the transposed rows back as fp16 packets.
  kernel.packet[0] = _mm256_castsi256_ph(a_p_0);
  kernel.packet[1] = _mm256_castsi256_ph(a_p_1);
  kernel.packet[2] = _mm256_castsi256_ph(a_p_2);
  kernel.packet[3] = _mm256_castsi256_ph(a_p_3);
  kernel.packet[4] = _mm256_castsi256_ph(a_p_4);
  kernel.packet[5] = _mm256_castsi256_ph(a_p_5);
  kernel.packet[6] = _mm256_castsi256_ph(a_p_6);
  kernel.packet[7] = _mm256_castsi256_ph(a_p_7);
  kernel.packet[8] = _mm256_castsi256_ph(a_p_8);
  kernel.packet[9] = _mm256_castsi256_ph(a_p_9);
  kernel.packet[10] = _mm256_castsi256_ph(a_p_a);
  kernel.packet[11] = _mm256_castsi256_ph(a_p_b);
  kernel.packet[12] = _mm256_castsi256_ph(a_p_c);
  kernel.packet[13] = _mm256_castsi256_ph(a_p_d);
  kernel.packet[14] = _mm256_castsi256_ph(a_p_e);
  kernel.packet[15] = _mm256_castsi256_ph(a_p_f);
}
1177
1178EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
1179 EIGEN_ALIGN64 half in[8][16];
1180 pstore<half>(in[0], kernel.packet[0]);
1181 pstore<half>(in[1], kernel.packet[1]);
1182 pstore<half>(in[2], kernel.packet[2]);
1183 pstore<half>(in[3], kernel.packet[3]);
1184 pstore<half>(in[4], kernel.packet[4]);
1185 pstore<half>(in[5], kernel.packet[5]);
1186 pstore<half>(in[6], kernel.packet[6]);
1187 pstore<half>(in[7], kernel.packet[7]);
1188
1189 EIGEN_ALIGN64 half out[8][16];
1190
1191 for (int i = 0; i < 8; ++i) {
1192 for (int j = 0; j < 8; ++j) {
1193 out[i][j] = in[j][2 * i];
1194 }
1195 for (int j = 0; j < 8; ++j) {
1196 out[i][j + 8] = in[j][2 * i + 1];
1197 }
1198 }
1199
1200 kernel.packet[0] = pload<Packet16h>(out[0]);
1201 kernel.packet[1] = pload<Packet16h>(out[1]);
1202 kernel.packet[2] = pload<Packet16h>(out[2]);
1203 kernel.packet[3] = pload<Packet16h>(out[3]);
1204 kernel.packet[4] = pload<Packet16h>(out[4]);
1205 kernel.packet[5] = pload<Packet16h>(out[5]);
1206 kernel.packet[6] = pload<Packet16h>(out[6]);
1207 kernel.packet[7] = pload<Packet16h>(out[7]);
1208}
1209
1210EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
1211 EIGEN_ALIGN64 half in[4][16];
1212 pstore<half>(in[0], kernel.packet[0]);
1213 pstore<half>(in[1], kernel.packet[1]);
1214 pstore<half>(in[2], kernel.packet[2]);
1215 pstore<half>(in[3], kernel.packet[3]);
1216
1217 EIGEN_ALIGN64 half out[4][16];
1218
1219 for (int i = 0; i < 4; ++i) {
1220 for (int j = 0; j < 4; ++j) {
1221 out[i][j] = in[j][4 * i];
1222 }
1223 for (int j = 0; j < 4; ++j) {
1224 out[i][j + 4] = in[j][4 * i + 1];
1225 }
1226 for (int j = 0; j < 4; ++j) {
1227 out[i][j + 8] = in[j][4 * i + 2];
1228 }
1229 for (int j = 0; j < 4; ++j) {
1230 out[i][j + 12] = in[j][4 * i + 3];
1231 }
1232 }
1233
1234 kernel.packet[0] = pload<Packet16h>(out[0]);
1235 kernel.packet[1] = pload<Packet16h>(out[1]);
1236 kernel.packet[2] = pload<Packet16h>(out[2]);
1237 kernel.packet[3] = pload<Packet16h>(out[3]);
1238}
1239
// Full 8x8 in-register transpose: after the call, kernel.packet[i] holds what was
// column i of the original block. Classic SSE pattern of successively wider unpacks
// (16 -> 32 -> 64 bit); variable names spell out their contents, e.g. a03b03 holds
// elements 0..3 of rows a and b interleaved, a0b0c0d0e0f0g0h0 is full column 0.
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 8>& kernel) {
  // Reinterpret the fp16 packets as integer vectors; rows are named a..h.
  __m128i a = _mm_castph_si128(kernel.packet[0]);
  __m128i b = _mm_castph_si128(kernel.packet[1]);
  __m128i c = _mm_castph_si128(kernel.packet[2]);
  __m128i d = _mm_castph_si128(kernel.packet[3]);
  __m128i e = _mm_castph_si128(kernel.packet[4]);
  __m128i f = _mm_castph_si128(kernel.packet[5]);
  __m128i g = _mm_castph_si128(kernel.packet[6]);
  __m128i h = _mm_castph_si128(kernel.packet[7]);

  // Step 1: interleave row pairs at 16-bit granularity.
  __m128i a03b03 = _mm_unpacklo_epi16(a, b);
  __m128i c03d03 = _mm_unpacklo_epi16(c, d);
  __m128i e03f03 = _mm_unpacklo_epi16(e, f);
  __m128i g03h03 = _mm_unpacklo_epi16(g, h);
  __m128i a47b47 = _mm_unpackhi_epi16(a, b);
  __m128i c47d47 = _mm_unpackhi_epi16(c, d);
  __m128i e47f47 = _mm_unpackhi_epi16(e, f);
  __m128i g47h47 = _mm_unpackhi_epi16(g, h);

  // Step 2: interleave at 32-bit granularity, combining groups of 4 rows.
  __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
  __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
  __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
  __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
  __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
  __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
  __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
  __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);

  // Step 3: interleave at 64-bit granularity; each result is one complete column.
  __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
  __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
  __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
  __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
  __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
  __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
  __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
  __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);

  // Reinterpret the transposed columns back as fp16 packets.
  kernel.packet[0] = _mm_castsi128_ph(a0b0c0d0e0f0g0h0);
  kernel.packet[1] = _mm_castsi128_ph(a1b1c1d1e1f1g1h1);
  kernel.packet[2] = _mm_castsi128_ph(a2b2c2d2e2f2g2h2);
  kernel.packet[3] = _mm_castsi128_ph(a3b3c3d3e3f3g3h3);
  kernel.packet[4] = _mm_castsi128_ph(a4b4c4d4e4f4g4h4);
  kernel.packet[5] = _mm_castsi128_ph(a5b5c5d5e5f5g5h5);
  kernel.packet[6] = _mm_castsi128_ph(a6b6c6d6e6f6g6h6);
  kernel.packet[7] = _mm_castsi128_ph(a7b7c7d7e7f7g7h7);
}
1286
1287EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
1288 EIGEN_ALIGN32 Eigen::half in[4][8];
1289 pstore<Eigen::half>(in[0], kernel.packet[0]);
1290 pstore<Eigen::half>(in[1], kernel.packet[1]);
1291 pstore<Eigen::half>(in[2], kernel.packet[2]);
1292 pstore<Eigen::half>(in[3], kernel.packet[3]);
1293
1294 EIGEN_ALIGN32 Eigen::half out[4][8];
1295
1296 for (int i = 0; i < 4; ++i) {
1297 for (int j = 0; j < 4; ++j) {
1298 out[i][j] = in[j][2 * i];
1299 }
1300 for (int j = 0; j < 4; ++j) {
1301 out[i][j + 4] = in[j][2 * i + 1];
1302 }
1303 }
1304
1305 kernel.packet[0] = pload<Packet8h>(out[0]);
1306 kernel.packet[1] = pload<Packet8h>(out[1]);
1307 kernel.packet[2] = pload<Packet8h>(out[2]);
1308 kernel.packet[3] = pload<Packet8h>(out[3]);
1309}
1310
1311// preverse
1312
1313template <>
1314EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
1315 return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
1316 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
1317 a);
1318}
1319
1320template <>
1321EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
1322 __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
1323 return _mm256_castsi256_ph(_mm256_insertf128_si256(
1324 _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 1), m)),
1325 _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 0), m), 1));
1326}
1327
1328template <>
1329EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) {
1330 __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
1331 return _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a), m));
1332}
1333
1334// pscatter
1335
1336template <>
1337EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
1338 EIGEN_ALIGN64 half aux[32];
1339 pstore(aux, from);
1340
1341 EIGEN_UNROLL_LOOP
1342 for (int i = 0; i < 32; i++) {
1343 to[stride * i] = aux[i];
1344 }
1345}
1346template <>
1347EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
1348 EIGEN_ALIGN64 half aux[16];
1349 pstore(aux, from);
1350 to[stride * 0] = aux[0];
1351 to[stride * 1] = aux[1];
1352 to[stride * 2] = aux[2];
1353 to[stride * 3] = aux[3];
1354 to[stride * 4] = aux[4];
1355 to[stride * 5] = aux[5];
1356 to[stride * 6] = aux[6];
1357 to[stride * 7] = aux[7];
1358 to[stride * 8] = aux[8];
1359 to[stride * 9] = aux[9];
1360 to[stride * 10] = aux[10];
1361 to[stride * 11] = aux[11];
1362 to[stride * 12] = aux[12];
1363 to[stride * 13] = aux[13];
1364 to[stride * 14] = aux[14];
1365 to[stride * 15] = aux[15];
1366}
1367
1368template <>
1369EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) {
1370 EIGEN_ALIGN32 Eigen::half aux[8];
1371 pstore(aux, from);
1372 to[stride * 0] = aux[0];
1373 to[stride * 1] = aux[1];
1374 to[stride * 2] = aux[2];
1375 to[stride * 3] = aux[3];
1376 to[stride * 4] = aux[4];
1377 to[stride * 5] = aux[5];
1378 to[stride * 6] = aux[6];
1379 to[stride * 7] = aux[7];
1380}
1381
1382// pgather
1383
1384template <>
1385EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
1386 return _mm512_set_ph(from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x,
1387 from[27 * stride].x, from[26 * stride].x, from[25 * stride].x, from[24 * stride].x,
1388 from[23 * stride].x, from[22 * stride].x, from[21 * stride].x, from[20 * stride].x,
1389 from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, from[16 * stride].x,
1390 from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
1391 from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
1392 from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
1393 from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
1394}
1395
1396template <>
1397EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
1398 return _mm256_set_ph(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
1399 from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
1400 from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
1401 from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
1402}
1403
1404template <>
1405EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) {
1406 return _mm_set_ph(from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x,
1407 from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
1408}
1409
1410} // end namespace internal
1411} // end namespace Eigen
1412
1413#endif // EIGEN_PACKET_MATH_FP16_AVX512_H
@ Aligned64
Definition Constants.h:239
@ Aligned32
Definition Constants.h:238
@ Aligned16
Definition Constants.h:237
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82