Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
Reductions.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_REDUCTIONS_AVX512_H
11#define EIGEN_REDUCTIONS_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20/* -- -- -- -- -- -- -- -- -- -- -- -- Packet16i -- -- -- -- -- -- -- -- -- -- -- -- */
21
22template <>
23EIGEN_STRONG_INLINE int predux(const Packet16i& a) {
24 return _mm512_reduce_add_epi32(a);
25}
26
27template <>
28EIGEN_STRONG_INLINE int predux_mul(const Packet16i& a) {
29 return _mm512_reduce_mul_epi32(a);
30}
31
32template <>
33EIGEN_STRONG_INLINE int predux_min(const Packet16i& a) {
34 return _mm512_reduce_min_epi32(a);
35}
36
37template <>
38EIGEN_STRONG_INLINE int predux_max(const Packet16i& a) {
39 return _mm512_reduce_max_epi32(a);
40}
41
42template <>
43EIGEN_STRONG_INLINE bool predux_any(const Packet16i& a) {
44 return _mm512_reduce_or_epi32(a) != 0;
45}
46
47/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8l -- -- -- -- -- -- -- -- -- -- -- -- */
48
49template <>
50EIGEN_STRONG_INLINE int64_t predux(const Packet8l& a) {
51 return _mm512_reduce_add_epi64(a);
52}
53
54#if EIGEN_COMP_MSVC
55// MSVC's _mm512_reduce_mul_epi64 is borked, at least up to and including 1939.
56// alignas(64) int64_t data[] = { 1,1,-1,-1,1,-1,-1,-1 };
57// int64_t out = _mm512_reduce_mul_epi64(_mm512_load_epi64(data));
58// produces garbage: 4294967295. It seems to happen whenever the output is supposed to be negative.
59// Fall back to a manual approach:
60template <>
61EIGEN_STRONG_INLINE int64_t predux_mul(const Packet8l& a) {
62 Packet4l lane0 = _mm512_extracti64x4_epi64(a, 0);
63 Packet4l lane1 = _mm512_extracti64x4_epi64(a, 1);
64 return predux_mul(pmul(lane0, lane1));
65}
66#else
67template <>
68EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
69 return _mm512_reduce_mul_epi64(a);
70}
71#endif
72
73template <>
74EIGEN_STRONG_INLINE int64_t predux_min(const Packet8l& a) {
75 return _mm512_reduce_min_epi64(a);
76}
77
78template <>
79EIGEN_STRONG_INLINE int64_t predux_max(const Packet8l& a) {
80 return _mm512_reduce_max_epi64(a);
81}
82
83template <>
84EIGEN_STRONG_INLINE bool predux_any(const Packet8l& a) {
85 return _mm512_reduce_or_epi64(a) != 0;
86}
87
88/* -- -- -- -- -- -- -- -- -- -- -- -- Packet16f -- -- -- -- -- -- -- -- -- -- -- -- */
89
90template <>
91EIGEN_STRONG_INLINE float predux(const Packet16f& a) {
92 return _mm512_reduce_add_ps(a);
93}
94
95template <>
96EIGEN_STRONG_INLINE float predux_mul(const Packet16f& a) {
97 return _mm512_reduce_mul_ps(a);
98}
99
100template <>
101EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) {
102 return _mm512_reduce_min_ps(a);
103}
104
105template <>
106EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet16f& a) {
107 Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
108 Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
109 return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lane0, lane1));
110}
111
112template <>
113EIGEN_STRONG_INLINE float predux_min<PropagateNaN>(const Packet16f& a) {
114 Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
115 Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
116 return predux_min<PropagateNaN>(pmin<PropagateNaN>(lane0, lane1));
117}
118
119template <>
120EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) {
121 return _mm512_reduce_max_ps(a);
122}
123
124template <>
125EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet16f& a) {
126 Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
127 Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
128 return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lane0, lane1));
129}
130
131template <>
132EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet16f& a) {
133 Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
134 Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
135 return predux_max<PropagateNaN>(pmax<PropagateNaN>(lane0, lane1));
136}
137
138template <>
139EIGEN_STRONG_INLINE bool predux_any(const Packet16f& a) {
140 return _mm512_reduce_or_epi32(_mm512_castps_si512(a)) != 0;
141}
142
143/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8d -- -- -- -- -- -- -- -- -- -- -- -- */
144
145template <>
146EIGEN_STRONG_INLINE double predux(const Packet8d& a) {
147 return _mm512_reduce_add_pd(a);
148}
149
150template <>
151EIGEN_STRONG_INLINE double predux_mul(const Packet8d& a) {
152 return _mm512_reduce_mul_pd(a);
153}
154
155template <>
156EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) {
157 return _mm512_reduce_min_pd(a);
158}
159
160template <>
161EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet8d& a) {
162 Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
163 Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
164 return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lane0, lane1));
165}
166
167template <>
168EIGEN_STRONG_INLINE double predux_min<PropagateNaN>(const Packet8d& a) {
169 Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
170 Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
171 return predux_min<PropagateNaN>(pmin<PropagateNaN>(lane0, lane1));
172}
173
174template <>
175EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) {
176 return _mm512_reduce_max_pd(a);
177}
178
179template <>
180EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet8d& a) {
181 Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
182 Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
183 return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lane0, lane1));
184}
185
186template <>
187EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet8d& a) {
188 Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
189 Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
190 return predux_max<PropagateNaN>(pmax<PropagateNaN>(lane0, lane1));
191}
192
193template <>
194EIGEN_STRONG_INLINE bool predux_any(const Packet8d& a) {
195 return _mm512_reduce_or_epi64(_mm512_castpd_si512(a)) != 0;
196}
197
198#ifndef EIGEN_VECTORIZE_AVX512FP16
199/* -- -- -- -- -- -- -- -- -- -- -- -- Packet16h -- -- -- -- -- -- -- -- -- -- -- -- */
200
201template <>
202EIGEN_STRONG_INLINE half predux(const Packet16h& from) {
203 return half(predux(half2float(from)));
204}
205
206template <>
207EIGEN_STRONG_INLINE half predux_mul(const Packet16h& from) {
208 return half(predux_mul(half2float(from)));
209}
210
211template <>
212EIGEN_STRONG_INLINE half predux_min(const Packet16h& from) {
213 return half(predux_min(half2float(from)));
214}
215
216template <>
217EIGEN_STRONG_INLINE half predux_min<PropagateNumbers>(const Packet16h& from) {
218 return half(predux_min<PropagateNumbers>(half2float(from)));
219}
220
221template <>
222EIGEN_STRONG_INLINE half predux_min<PropagateNaN>(const Packet16h& from) {
223 return half(predux_min<PropagateNaN>(half2float(from)));
224}
225
226template <>
227EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) {
228 return half(predux_max(half2float(from)));
229}
230
231template <>
232EIGEN_STRONG_INLINE half predux_max<PropagateNumbers>(const Packet16h& from) {
233 return half(predux_max<PropagateNumbers>(half2float(from)));
234}
235
236template <>
237EIGEN_STRONG_INLINE half predux_max<PropagateNaN>(const Packet16h& from) {
238 return half(predux_max<PropagateNaN>(half2float(from)));
239}
240
241template <>
242EIGEN_STRONG_INLINE bool predux_any(const Packet16h& a) {
243 return predux_any<Packet8i>(a.m_val);
244}
245#endif
246
247/* -- -- -- -- -- -- -- -- -- -- -- -- Packet16bf -- -- -- -- -- -- -- -- -- -- -- -- */
248
249template <>
250EIGEN_STRONG_INLINE bfloat16 predux(const Packet16bf& from) {
251 return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(from)));
252}
253
254template <>
255EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet16bf& from) {
256 return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
257}
258
259template <>
260EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) {
261 return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
262}
263
264template <>
265EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNumbers>(const Packet16bf& from) {
266 return static_cast<bfloat16>(predux_min<PropagateNumbers>(Bf16ToF32(from)));
267}
268
269template <>
270EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNaN>(const Packet16bf& from) {
271 return static_cast<bfloat16>(predux_min<PropagateNaN>(Bf16ToF32(from)));
272}
273
274template <>
275EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) {
276 return static_cast<bfloat16>(predux_max(Bf16ToF32(from)));
277}
278
279template <>
280EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNumbers>(const Packet16bf& from) {
281 return static_cast<bfloat16>(predux_max<PropagateNumbers>(Bf16ToF32(from)));
282}
283
284template <>
285EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNaN>(const Packet16bf& from) {
286 return static_cast<bfloat16>(predux_max<PropagateNaN>(Bf16ToF32(from)));
287}
288
289template <>
290EIGEN_STRONG_INLINE bool predux_any(const Packet16bf& a) {
291 return predux_any<Packet8i>(a.m_val);
292}
293
294} // end namespace internal
295} // end namespace Eigen
296
297#endif // EIGEN_REDUCTIONS_AVX512_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1