Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
Reductions.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_REDUCTIONS_AVX_H
11#define EIGEN_REDUCTIONS_AVX_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8i -- -- -- -- -- -- -- -- -- -- -- -- */
21
22template <>
23EIGEN_STRONG_INLINE int predux(const Packet8i& a) {
24 Packet4i lo = _mm256_castsi256_si128(a);
25 Packet4i hi = _mm256_extractf128_si256(a, 1);
26 return predux(padd(lo, hi));
27}
28
29template <>
30EIGEN_STRONG_INLINE int predux_mul(const Packet8i& a) {
31 Packet4i lo = _mm256_castsi256_si128(a);
32 Packet4i hi = _mm256_extractf128_si256(a, 1);
33 return predux_mul(pmul(lo, hi));
34}
35
36template <>
37EIGEN_STRONG_INLINE int predux_min(const Packet8i& a) {
38 Packet4i lo = _mm256_castsi256_si128(a);
39 Packet4i hi = _mm256_extractf128_si256(a, 1);
40 return predux_min(pmin(lo, hi));
41}
42
43template <>
44EIGEN_STRONG_INLINE int predux_max(const Packet8i& a) {
45 Packet4i lo = _mm256_castsi256_si128(a);
46 Packet4i hi = _mm256_extractf128_si256(a, 1);
47 return predux_max(pmax(lo, hi));
48}
49
50template <>
51EIGEN_STRONG_INLINE bool predux_any(const Packet8i& a) {
52#ifdef EIGEN_VECTORIZE_AVX2
53 return _mm256_movemask_epi8(a) != 0x0;
54#else
55 return _mm256_movemask_ps(_mm256_castsi256_ps(a)) != 0x0;
56#endif
57}
58
59/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8ui -- -- -- -- -- -- -- -- -- -- -- -- */
60
61template <>
62EIGEN_STRONG_INLINE uint32_t predux(const Packet8ui& a) {
63 Packet4ui lo = _mm256_castsi256_si128(a);
64 Packet4ui hi = _mm256_extractf128_si256(a, 1);
65 return predux(padd(lo, hi));
66}
67
68template <>
69EIGEN_STRONG_INLINE uint32_t predux_mul(const Packet8ui& a) {
70 Packet4ui lo = _mm256_castsi256_si128(a);
71 Packet4ui hi = _mm256_extractf128_si256(a, 1);
72 return predux_mul(pmul(lo, hi));
73}
74
75template <>
76EIGEN_STRONG_INLINE uint32_t predux_min(const Packet8ui& a) {
77 Packet4ui lo = _mm256_castsi256_si128(a);
78 Packet4ui hi = _mm256_extractf128_si256(a, 1);
79 return predux_min(pmin(lo, hi));
80}
81
82template <>
83EIGEN_STRONG_INLINE uint32_t predux_max(const Packet8ui& a) {
84 Packet4ui lo = _mm256_castsi256_si128(a);
85 Packet4ui hi = _mm256_extractf128_si256(a, 1);
86 return predux_max(pmax(lo, hi));
87}
88
89template <>
90EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& a) {
91#ifdef EIGEN_VECTORIZE_AVX2
92 return _mm256_movemask_epi8(a) != 0x0;
93#else
94 return _mm256_movemask_ps(_mm256_castsi256_ps(a)) != 0x0;
95#endif
96}
97
98#ifdef EIGEN_VECTORIZE_AVX2
99
100/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4l -- -- -- -- -- -- -- -- -- -- -- -- */
101
102template <>
103EIGEN_STRONG_INLINE int64_t predux(const Packet4l& a) {
104 Packet2l lo = _mm256_castsi256_si128(a);
105 Packet2l hi = _mm256_extractf128_si256(a, 1);
106 return predux(padd(lo, hi));
107}
108
109template <>
110EIGEN_STRONG_INLINE bool predux_any(const Packet4l& a) {
111 return _mm256_movemask_pd(_mm256_castsi256_pd(a)) != 0x0;
112}
113
114/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4ul -- -- -- -- -- -- -- -- -- -- -- -- */
115
116template <>
117EIGEN_STRONG_INLINE uint64_t predux(const Packet4ul& a) {
118 return static_cast<uint64_t>(predux(Packet4l(a)));
119}
120
121template <>
122EIGEN_STRONG_INLINE bool predux_any(const Packet4ul& a) {
123 return _mm256_movemask_pd(_mm256_castsi256_pd(a)) != 0x0;
124}
125
126#endif
127
128/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8f -- -- -- -- -- -- -- -- -- -- -- -- */
129
130template <>
131EIGEN_STRONG_INLINE float predux(const Packet8f& a) {
132 Packet4f lo = _mm256_castps256_ps128(a);
133 Packet4f hi = _mm256_extractf128_ps(a, 1);
134 return predux(padd(lo, hi));
135}
136
137template <>
138EIGEN_STRONG_INLINE float predux_mul(const Packet8f& a) {
139 Packet4f lo = _mm256_castps256_ps128(a);
140 Packet4f hi = _mm256_extractf128_ps(a, 1);
141 return predux_mul(pmul(lo, hi));
142}
143
144template <>
145EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) {
146 Packet4f lo = _mm256_castps256_ps128(a);
147 Packet4f hi = _mm256_extractf128_ps(a, 1);
148 return predux_min(pmin(lo, hi));
149}
150
151template <>
152EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet8f& a) {
153 Packet4f lo = _mm256_castps256_ps128(a);
154 Packet4f hi = _mm256_extractf128_ps(a, 1);
155 return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lo, hi));
156}
157
158template <>
159EIGEN_STRONG_INLINE float predux_min<PropagateNaN>(const Packet8f& a) {
160 Packet4f lo = _mm256_castps256_ps128(a);
161 Packet4f hi = _mm256_extractf128_ps(a, 1);
162 return predux_min<PropagateNaN>(pmin<PropagateNaN>(lo, hi));
163}
164
165template <>
166EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) {
167 Packet4f lo = _mm256_castps256_ps128(a);
168 Packet4f hi = _mm256_extractf128_ps(a, 1);
169 return predux_max(pmax(lo, hi));
170}
171
172template <>
173EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet8f& a) {
174 Packet4f lo = _mm256_castps256_ps128(a);
175 Packet4f hi = _mm256_extractf128_ps(a, 1);
176 return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lo, hi));
177}
178
179template <>
180EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet8f& a) {
181 Packet4f lo = _mm256_castps256_ps128(a);
182 Packet4f hi = _mm256_extractf128_ps(a, 1);
183 return predux_max<PropagateNaN>(pmax<PropagateNaN>(lo, hi));
184}
185
186template <>
187EIGEN_STRONG_INLINE bool predux_any(const Packet8f& a) {
188 return _mm256_movemask_ps(a) != 0x0;
189}
190
191/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4d -- -- -- -- -- -- -- -- -- -- -- -- */
192
193template <>
194EIGEN_STRONG_INLINE double predux(const Packet4d& a) {
195 Packet2d lo = _mm256_castpd256_pd128(a);
196 Packet2d hi = _mm256_extractf128_pd(a, 1);
197 return predux(padd(lo, hi));
198}
199
200template <>
201EIGEN_STRONG_INLINE double predux_mul(const Packet4d& a) {
202 Packet2d lo = _mm256_castpd256_pd128(a);
203 Packet2d hi = _mm256_extractf128_pd(a, 1);
204 return predux_mul(pmul(lo, hi));
205}
206
207template <>
208EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) {
209 Packet2d lo = _mm256_castpd256_pd128(a);
210 Packet2d hi = _mm256_extractf128_pd(a, 1);
211 return predux_min(pmin(lo, hi));
212}
213
214template <>
215EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet4d& a) {
216 Packet2d lo = _mm256_castpd256_pd128(a);
217 Packet2d hi = _mm256_extractf128_pd(a, 1);
218 return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lo, hi));
219}
220
221template <>
222EIGEN_STRONG_INLINE double predux_min<PropagateNaN>(const Packet4d& a) {
223 Packet2d lo = _mm256_castpd256_pd128(a);
224 Packet2d hi = _mm256_extractf128_pd(a, 1);
225 return predux_min<PropagateNaN>(pmin<PropagateNaN>(lo, hi));
226}
227
228template <>
229EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) {
230 Packet2d lo = _mm256_castpd256_pd128(a);
231 Packet2d hi = _mm256_extractf128_pd(a, 1);
232 return predux_max(pmax(lo, hi));
233}
234
235template <>
236EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet4d& a) {
237 Packet2d lo = _mm256_castpd256_pd128(a);
238 Packet2d hi = _mm256_extractf128_pd(a, 1);
239 return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lo, hi));
240}
241
242template <>
243EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet4d& a) {
244 Packet2d lo = _mm256_castpd256_pd128(a);
245 Packet2d hi = _mm256_extractf128_pd(a, 1);
246 return predux_max<PropagateNaN>(pmax<PropagateNaN>(lo, hi));
247}
248
249template <>
250EIGEN_STRONG_INLINE bool predux_any(const Packet4d& a) {
251 return _mm256_movemask_pd(a) != 0x0;
252}
253
254/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8h -- -- -- -- -- -- -- -- -- -- -- -- */
255#ifndef EIGEN_VECTORIZE_AVX512FP16
256
257template <>
258EIGEN_STRONG_INLINE half predux(const Packet8h& a) {
259 return static_cast<half>(predux(half2float(a)));
260}
261
262template <>
263EIGEN_STRONG_INLINE half predux_mul(const Packet8h& a) {
264 return static_cast<half>(predux_mul(half2float(a)));
265}
266
267template <>
268EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) {
269 return static_cast<half>(predux_min(half2float(a)));
270}
271
272template <>
273EIGEN_STRONG_INLINE half predux_min<PropagateNumbers>(const Packet8h& a) {
274 return static_cast<half>(predux_min<PropagateNumbers>(half2float(a)));
275}
276
277template <>
278EIGEN_STRONG_INLINE half predux_min<PropagateNaN>(const Packet8h& a) {
279 return static_cast<half>(predux_min<PropagateNaN>(half2float(a)));
280}
281
282template <>
283EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) {
284 return static_cast<half>(predux_max(half2float(a)));
285}
286
287template <>
288EIGEN_STRONG_INLINE half predux_max<PropagateNumbers>(const Packet8h& a) {
289 return static_cast<half>(predux_max<PropagateNumbers>(half2float(a)));
290}
291
292template <>
293EIGEN_STRONG_INLINE half predux_max<PropagateNaN>(const Packet8h& a) {
294 return static_cast<half>(predux_max<PropagateNaN>(half2float(a)));
295}
296
297template <>
298EIGEN_STRONG_INLINE bool predux_any(const Packet8h& a) {
299 return _mm_movemask_epi8(a) != 0;
300}
301#endif // EIGEN_VECTORIZE_AVX512FP16
302
303/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8bf -- -- -- -- -- -- -- -- -- -- -- -- */
304
305template <>
306EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) {
307 return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a)));
308}
309
310template <>
311EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) {
312 return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a)));
313}
314
315template <>
316EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) {
317 return static_cast<bfloat16>(predux_min(Bf16ToF32(a)));
318}
319
320template <>
321EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNumbers>(const Packet8bf& a) {
322 return static_cast<bfloat16>(predux_min<PropagateNumbers>(Bf16ToF32(a)));
323}
324
325template <>
326EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNaN>(const Packet8bf& a) {
327 return static_cast<bfloat16>(predux_min<PropagateNaN>(Bf16ToF32(a)));
328}
329
330template <>
331EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) {
332 return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
333}
334
335template <>
336EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNumbers>(const Packet8bf& a) {
337 return static_cast<bfloat16>(predux_max<PropagateNumbers>(Bf16ToF32(a)));
338}
339
340template <>
341EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNaN>(const Packet8bf& a) {
342 return static_cast<bfloat16>(predux_max<PropagateNaN>(Bf16ToF32(a)));
343}
344
345template <>
346EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& a) {
347 return _mm_movemask_epi8(a) != 0;
348}
349
350} // end namespace internal
351} // end namespace Eigen
352
353#endif // EIGEN_REDUCTIONS_AVX_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1