Eigen  5.0.1-dev+7c7d8473
 
Loading...
Searching...
No Matches
PacketMath.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2020, Arm Limited and Contributors
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_SVE_H
11#define EIGEN_PACKET_MATH_SVE_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17namespace internal {
18#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
19#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
20#endif
21
22#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
23#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
24#endif
25
26#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
27
// Computes, at compile time, how many `Scalar` lanes fit in an SVE vector of
// `SVEVectorLength` bits (e.g. 512-bit VL with int32 -> 16 lanes).
template <typename Scalar, int SVEVectorLength>
struct sve_packet_size_selector {
  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
};

/********************************* int32 **************************************/
// Fixed-length SVE int32 vector: the arm_sve_vector_bits attribute pins the
// otherwise sizeless svint32_t to EIGEN_ARM64_SVE_VL bits so it behaves as an
// ordinary (sizeof-able, member-storable) packet type.
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
// Capabilities of the int32 SVE packet as advertised to Eigen's vectorizer.
template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
  typedef PacketXi type;
  typedef PacketXi half;  // Half not implemented yet, so half == full packet.
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    // Lane count depends on the configured SVE vector length.
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasReduxp = 0  // Not implemented in SVE
  };
};
59
// Reverse mapping: properties of PacketXi as seen from the packet side.
template <>
struct unpacket_traits<PacketXi> {
  typedef numext::int32_t type;
  typedef PacketXi half;  // Half not yet implemented
  enum {
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    // NOTE(review): alignment is fixed at 64 bytes regardless of VL; SVE ld1/st1
    // have no alignment requirement, so this only steers Eigen's allocator.
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
72
// Prefetches the cache line holding *addr into L1 (keep/temporal hint).
template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr) {
  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}

// Broadcasts a single scalar to every lane of the packet.
template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from) {
  return svdup_n_s32(from);
}
82
83template <>
84EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a) {
85 numext::int32_t c[packet_traits<numext::int32_t>::size];
86 for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
87 return svadd_s32_x(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
88}
89
// Lane-wise addition. The `_x` ("don't care") predication with an all-true
// predicate is the unmasked form used throughout this file.
template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svadd_s32_x(svptrue_b32(), a, b);
}

// Lane-wise subtraction a - b.
template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svsub_s32_x(svptrue_b32(), a, b);
}

// Lane-wise negation.
template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) {
  return svneg_s32_x(svptrue_b32(), a);
}

// Conjugate of a real packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) {
  return a;
}

// Lane-wise multiplication.
template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmul_s32_x(svptrue_b32(), a, b);
}

// Lane-wise signed integer division (truncating, as in C).
template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdiv_s32_x(svptrue_b32(), a, b);
}

// Fused multiply-add: a*b + c (svmla computes c + a*b).
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) {
  return svmla_s32_x(svptrue_b32(), c, a, b);
}

// Lane-wise minimum.
template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmin_s32_x(svptrue_b32(), a, b);
}

// Lane-wise maximum.
template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmax_s32_x(svptrue_b32(), a, b);
}
134
// Integer comparisons: SVE compares yield a predicate (svbool_t); svdup_*_z
// materializes it as a packet with all-ones (0xffffffff) in true lanes and
// zero in false lanes, which is Eigen's boolean-mask convention.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}

// Mask of lanes where a < b.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}

// Mask of lanes where a == b.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}

// All-ones packet (every bit set); the argument only selects the overload.
template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/) {
  return svdup_n_s32_x(svptrue_b32(), 0xffffffffu);
}

// All-zeros packet.
template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/) {
  return svdup_n_s32_x(svptrue_b32(), 0);
}
159
// Bitwise AND.
template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svand_s32_x(svptrue_b32(), a, b);
}

// Bitwise OR.
template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svorr_s32_x(svptrue_b32(), a, b);
}

// Bitwise XOR.
template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return sveor_s32_x(svptrue_b32(), a, b);
}

// a AND NOT b (svbic clears in `a` the bits set in `b`).
template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svbic_s32_x(svptrue_b32(), a, b);
}
179
180template <int N>
181EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) {
182 return svasrd_n_s32_x(svptrue_b32(), a, N);
183}
184
// Lane-wise logical (zero-filling) shift right by immediate N. There is no
// unsigned-result shift on svint32_t, so round-trip through svuint32_t.
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) {
  return svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), svreinterpret_u32_s32(a), N));
}

// Lane-wise shift left by immediate N.
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) {
  return svlsl_n_s32_x(svptrue_b32(), a, N);
}
194
// Aligned load. SVE svld1 has no alignment requirement, so aligned and
// unaligned loads map to the same instruction; only the debug macro differs.
template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

// Unaligned load (same svld1 as above).
template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}
204
205template <>
206EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from) {
207 svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
208 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
209 return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
210}
211
212template <>
213EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from) {
214 svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
215 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
216 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
217 return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
218}
219
// Aligned store; like the loads, svst1 has no alignment requirement.
template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

// Unaligned store (same svst1).
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}
229
// Strided gather: reads from[0], from[stride], from[2*stride], ...
template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride) {
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}

// Strided scatter: writes lane i to to[i*stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from,
                                                                  Index stride) {
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}
244
// Extracts lane 0.
template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a) {
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_s32(svpfalse_b(), a);
}

// Reverses the lane order.
template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) {
  return svrev_s32(a);
}

// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) {
  return svabs_s32_x(svptrue_b32(), a);
}

// Horizontal sum of all lanes. svaddv widens to int64 internally; the result
// is truncated back to int32 to match Eigen's scalar semantics.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a) {
  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}
265
// Horizontal product of all lanes.
// Strategy: multiply the vector by its reverse, then repeatedly fold the
// upper half onto the lower half (svtbl with an offset index ramp) until the
// full product accumulates in lane 0. Each VL test below is a compile-time
// constant, so dead branches are eliminated. Lanes shifted in from past the
// vector end hold garbage/zero but are never read by subsequent steps.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a) {
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse
  svint32_t prod = svmul_s32_x(svptrue_b32(), a, svrev_s32(a));
  svint32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_s32(prod, svindex_u32(32, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_s32(prod, svindex_u32(16, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_s32(prod, svindex_u32(8, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_s32(prod, svindex_u32(4, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXi>(prod);
}
298
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a) {
  return svminv_s32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a) {
  return svmaxv_s32(svptrue_b32(), a);
}
308
309template <int N>
310EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
311 int buffer[packet_traits<numext::int32_t>::size * N] = {0};
312 int i = 0;
313
314 PacketXi stride_index = svindex_s32(0, N);
315
316 for (i = 0; i < N; i++) {
317 svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
318 }
319 for (i = 0; i < N; i++) {
320 kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
321 }
322}
323
/********************************* float32 ************************************/

// Fixed-length SVE float32 vector, pinned to EIGEN_ARM64_SVE_VL bits like
// PacketXi above.
typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
327
// Capabilities of the float SVE packet as advertised to Eigen's vectorizer.
// The math-function flags (sin/cos/tanh/erf/erfc) are only enabled under
// EIGEN_FAST_MATH since the vector implementations trade accuracy for speed.
template <>
struct packet_traits<float> : default_packet_traits {
  typedef PacketXf type;
  typedef PacketXf half;

  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasReduxp = 0,  // Not implemented in SVE

    HasDiv = 1,

    HasCmp = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasPow = 1,
    HasSqrt = 1,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasErfc = EIGEN_FAST_MATH
  };
};
365
// Reverse mapping for PacketXf; integer_packet is used by the bit-level math
// function implementations (frexp/ldexp style manipulation).
template <>
struct unpacket_traits<PacketXf> {
  typedef float type;
  typedef PacketXf half;  // Half not yet implemented
  typedef PacketXi integer_packet;

  enum {
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
380
// Broadcasts a single float to every lane.
template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from) {
  return svdup_n_f32(from);
}

// Broadcasts a raw 32-bit pattern and reinterprets it as float (used to
// build constants such as masks and exponent patterns bit-exactly).
template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from) {
  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), from));
}
390
391template <>
392EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a) {
393 float c[packet_traits<float>::size];
394 for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
395 return svadd_f32_x(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
396}
397
// Lane-wise addition (unmasked `_x` form, all-true predicate).
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svadd_f32_x(svptrue_b32(), a, b);
}

// Lane-wise subtraction a - b.
template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svsub_f32_x(svptrue_b32(), a, b);
}

// Lane-wise negation.
template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) {
  return svneg_f32_x(svptrue_b32(), a);
}

// Conjugate of a real packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) {
  return a;
}

// Lane-wise multiplication.
template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmul_f32_x(svptrue_b32(), a, b);
}

// Lane-wise division.
template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svdiv_f32_x(svptrue_b32(), a, b);
}

// Fused multiply-add: a*b + c (svmla computes c + a*b, single rounding).
template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) {
  return svmla_f32_x(svptrue_b32(), c, a, b);
}
432
// Lane-wise minimum. FMIN returns NaN if either operand is NaN, so the plain
// form already propagates NaNs.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmin_f32_x(svptrue_b32(), a, b);
}

// NaN-propagating minimum: identical to the default svmin behavior.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return pmin<PacketXf>(a, b);
}

// Number-propagating minimum: FMINNM prefers the numeric operand when
// exactly one input is NaN.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svminnm_f32_x(svptrue_b32(), a, b);
}

// Lane-wise maximum (NaN-propagating, mirroring pmin).
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmax_f32_x(svptrue_b32(), a, b);
}

// NaN-propagating maximum: identical to the default svmax behavior.
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return pmax<PacketXf>(a, b);
}

// Number-propagating maximum via FMAXNM.
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmaxnm_f32_x(svptrue_b32(), a, b);
}
462
// Float comparisons in SVE return svbool (predicate). Use svdup to set active
// lanes to 1 (0xffffffffu) and inactive lanes to 0, then reinterpret the
// all-ones/zero integer mask as float (Eigen's boolean-mask convention).
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}

// Mask of lanes where a < b.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}

// Mask of lanes where a == b.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}

// Do a predicate inverse (svnot_b_z) on the predicate resulted from the
// greater/equal comparison (svcmpge_f32). Then fill a float vector with the
// active elements. NOT(a >= b) is true for a < b and also when either operand
// is NaN, which is exactly the lt_or_nan contract.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}
487
// Lane-wise floor (round toward minus infinity; FRINTM).
template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a) {
  return svrintm_f32_x(svptrue_b32(), a);
}

// All-ones bit pattern reinterpreted as float (a NaN pattern, used as mask).
template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/) {
  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), 0xffffffffu));
}
497
// Logical Operations are not supported for float, so reinterpret casts to
// unsigned int, operate there, and cast back (bit patterns are preserved).
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svand_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

// Bitwise OR on float bit patterns.
template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svorr_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

// Bitwise XOR on float bit patterns.
template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(sveor_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

// a AND NOT b on float bit patterns (svbic clears in `a` the bits set in `b`).
template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svbic_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
518
// Aligned load; SVE svld1 has no alignment requirement, so aligned and
// unaligned loads are the same instruction.
template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

// Unaligned load (same svld1 as above).
template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}
528
529template <>
530EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from) {
531 svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
532 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
533 return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
534}
535
536template <>
537EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from) {
538 svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
539 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
540 indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
541 return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
542}
543
// Aligned store; like the loads, svst1 has no alignment requirement.
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from) {
  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

// Unaligned store (same svst1).
template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from) {
  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}
553
// Strided gather: reads from[0], from[stride], from[2*stride], ...
template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride) {
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}

// Strided scatter: writes lane i to to[i*stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride) {
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}
567
// Extracts lane 0.
template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a) {
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_f32(svpfalse_b(), a);
}

// Reverses the lane order.
template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) {
  return svrev_f32(a);
}

// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) {
  return svabs_f32_x(svptrue_b32(), a);
}
583
// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
// Decomposes each lane into mantissa (returned) and exponent (out-param),
// deferring to Eigen's generic bit-twiddling implementation.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent) {
  return pfrexp_generic(a, exponent);
}

// Horizontal sum of all lanes.
template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a) {
  return svaddv_f32(svptrue_b32(), a);
}
595
// Other reduction functions:
// mul
// Only works for SVE VLs that are a multiple of 128 bits.
// Same halving strategy as the int32 predux_mul above: multiply by the
// reversed vector, then repeatedly fold the upper half onto the lower half
// until the full product sits in lane 0. The VL tests are compile-time
// constants, so dead branches are eliminated.
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a) {
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
  // Multiply the vector by its reverse
  svfloat32_t prod = svmul_f32_x(svptrue_b32(), a, svrev_f32(a));
  svfloat32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_f32(prod, svindex_u32(32, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_f32(prod, svindex_u32(16, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_f32(prod, svindex_u32(8, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_f32(prod, svindex_u32(4, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXf>(prod);
}
630
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a) {
  return svminv_f32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a) {
  return svmaxv_f32(svptrue_b32(), a);
}
640
641template <int N>
642EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel) {
643 float buffer[packet_traits<float>::size * N] = {0};
644 int i = 0;
645
646 PacketXi stride_index = svindex_s32(0, N);
647
648 for (i = 0; i < N; i++) {
649 svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
650 }
651
652 for (i = 0; i < N; i++) {
653 kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
654 }
655}
656
// Scales each lane by 2^exponent, deferring to Eigen's generic implementation.
template <>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent) {
  return pldexp_generic(a, exponent);
}

// Lane-wise square root.
template <>
EIGEN_STRONG_INLINE PacketXf psqrt<PacketXf>(const PacketXf& a) {
  return svsqrt_f32_x(svptrue_b32(), a);
}
666
667} // namespace internal
668} // namespace Eigen
669
670#endif // EIGEN_PACKET_MATH_SVE_H
@ Aligned64
Definition Constants.h:239
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82