Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
PacketMath.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2020, Arm Limited and Contributors
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_SVE_H
11#define EIGEN_PACKET_MATH_SVE_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17namespace internal {
18#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
19#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
20#endif
21
22#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
23#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
24#endif
25
26#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
27
// Number of Scalar elements that fit in an SVE vector of the given bit width:
// the vector length (in bits) divided by the scalar width (in bits).
template <typename T, int VectorLengthInBits>
struct sve_packet_size_selector {
  enum { size = VectorLengthInBits / (CHAR_BIT * sizeof(T)) };
};
32
/********************************* int32 **************************************/

// Fixed-length alias of the sizeless svint32_t type: the arm_sve_vector_bits
// attribute pins the vector to EIGEN_ARM64_SVE_VL bits so the packet has a
// fixed size and can be stored in structs and passed by value.
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

// Capabilities of int32 vectorization with SVE packets.
template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
  typedef PacketXi type;
  typedef PacketXi half;  // Half not implemented yet
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    // Lane count depends on the configured SVE vector length.
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0  // Not implemented in SVE
  };
};

// Reverse mapping: properties of the PacketXi packet type itself.
template <>
struct unpacket_traits<PacketXi> {
  typedef numext::int32_t type;
  typedef PacketXi half;  // Half not yet implemented
  enum {
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    // NOTE(review): alignment is capped at 64 bytes even though an SVE vector
    // may be wider (up to 256 bytes at VL = 2048) — confirm this is intended.
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
74
// Prefetch a cache line of 32-bit words for reading into L1 with temporal
// locality ("keep").
template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr) {
  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}

// Broadcast a single int32 to every lane of the packet.
template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from) {
  return svdup_n_s32(from);
}
84
85template <>
86EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a) {
87 numext::int32_t c[packet_traits<numext::int32_t>::size];
88 for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
89 return svadd_s32_x(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
90}
91
// Element-wise int32 arithmetic. svptrue_b32() activates every lane; the
// "_x" intrinsic forms leave "inactive" lanes undefined (there are none).
template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svadd_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svsub_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) {
  return svneg_s32_x(svptrue_b32(), a);
}

// Conjugate of a real (integer) packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) {
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmul_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdiv_s32_x(svptrue_b32(), a, b);
}

// Multiply-add: returns a*b + c (svmla computes c + a*b).
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) {
  return svmla_s32_x(svptrue_b32(), c, a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmin_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svmax_s32_x(svptrue_b32(), a, b);
}
136
// Integer comparisons return an all-ones (0xffffffff) mask in lanes where the
// comparison holds and zero elsewhere: the svbool_t produced by svcmp*
// selects which lanes the zeroing ("_z") svdup fills with -1.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}

// All-ones packet (the argument only selects the overload).
template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/) {
  return svdup_n_s32_x(svptrue_b32(), 0xffffffffu);
}

// All-zeros packet.
template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/) {
  return svdup_n_s32_x(svptrue_b32(), 0);
}

template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svand_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svorr_s32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return sveor_s32_x(svptrue_b32(), a, b);
}

// a & ~b (BIC = bit clear).
template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b) {
  return svbic_s32_x(svptrue_b32(), a, b);
}
181
182template <int N>
183EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) {
184 return svasrd_n_s32_x(svptrue_b32(), a, N);
185}
186
// Logical (zero-filling) right shift: done on the unsigned reinterpretation
// because svlsr is only defined for unsigned element types.
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) {
  return svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), svreinterpret_u32_s32(a), N));
}

// Left shift by the compile-time constant N.
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) {
  return svlsl_n_s32_x(svptrue_b32(), a, N);
}
196
// SVE has a single, alignment-agnostic contiguous load/store, so the aligned
// and unaligned variants differ only in their debug instrumentation.
template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

// Load the first size/2 elements with each duplicated twice, via a gather
// whose index vector is an iota interleaved (zip1) with itself.
template <>
EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from) {
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

// Load the first size/4 elements with each duplicated four times (two zips).
template <>
EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from) {
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}
231
// Strided load via a gather with indices {0, stride, 2*stride, ...}.
// NOTE(review): svindex_s32 narrows `stride` from Index (typically int64) to
// int32 — assumes stride and all resulting offsets fit in 32 bits; confirm.
template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride) {
  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}

// Strided store via a scatter with the same index pattern.
template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from,
                                                                  Index stride) {
  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}
246
// Extract lane 0.
template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a) {
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_s32(svpfalse_b(), a);
}

// Reverse the lane order.
template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) {
  return svrev_s32(a);
}

// Element-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) {
  return svabs_s32_x(svptrue_b32(), a);
}

// Horizontal sum of all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a) {
  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}
267
// Multiplicative horizontal reduction. SVE has no product-reduction
// instruction, so a log2(size) tree of "multiply the vector by a shifted
// copy of itself" steps is used; only VLs that are a multiple of 128 bits
// are supported (enforced below).
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a) {
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse
  svint32_t prod = svmul_s32_x(svptrue_b32(), a, svrev_s32(a));
  svint32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done.
  // The conditions are compile-time constants, so untaken steps fold away.
  // svtbl yields zero for out-of-range indices, but those lanes are never
  // consumed afterwards: each step only reads the low half of the previous.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_s32(prod, svindex_u32(32, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_s32(prod, svindex_u32(16, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_s32(prod, svindex_u32(8, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_s32(prod, svindex_u32(4, 1));
    prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXi>(prod);
}
300
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a) {
  return svminv_s32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a) {
  return svmaxv_s32(svptrue_b32(), a);
}
310
311template <int N>
312EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
313 int buffer[packet_traits<numext::int32_t>::size * N] = {0};
314 int i = 0;
315
316 PacketXi stride_index = svindex_s32(0, N);
317
318 for (i = 0; i < N; i++) {
319 svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
320 }
321 for (i = 0; i < N; i++) {
322 kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
323 }
324}
325
/********************************* float32 ************************************/

// Fixed-length alias of the sizeless svfloat32_t type (see PacketXi above).
typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

// Capabilities of float vectorization with SVE packets.
template <>
struct packet_traits<float> : default_packet_traits {
  typedef PacketXf type;
  typedef PacketXf half;

  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    // Lane count depends on the configured SVE vector length.
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0,  // Not implemented in SVE

    HasDiv = 1,

    HasCmp = 1,
    // Transcendentals that are only accurate enough under fast-math.
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasPow = 1,
    HasSqrt = 1,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasErfc = EIGEN_FAST_MATH
  };
};

// Reverse mapping: properties of the PacketXf packet type itself.
template <>
struct unpacket_traits<PacketXf> {
  typedef float type;
  typedef PacketXf half;  // Half not yet implemented
  // Integer packet of matching width, used by the packet math functions.
  typedef PacketXi integer_packet;

  enum {
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    // NOTE(review): alignment is capped at 64 bytes even though an SVE vector
    // may be wider (up to 256 bytes at VL = 2048) — confirm this is intended.
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
384
// Broadcast a single float to every lane of the packet.
template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from) {
  return svdup_n_f32(from);
}

// Broadcast a raw 32-bit pattern, reinterpreted as float (used to build
// constants such as masks and exponent fields).
template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from) {
  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), from));
}
394
395template <>
396EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a) {
397 float c[packet_traits<float>::size];
398 for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
399 return svadd_f32_x(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
400}
401
// Element-wise float arithmetic; all lanes active, "_x" forms leave no
// inactive lanes to merge.
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svadd_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svsub_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) {
  return svneg_f32_x(svptrue_b32(), a);
}

// Conjugate of a real packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) {
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmul_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svdiv_f32_x(svptrue_b32(), a, b);
}

// Fused multiply-add: returns a*b + c (svmla computes c + a*b).
template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) {
  return svmla_f32_x(svptrue_b32(), c, a, b);
}

// FMIN returns NaN when either operand is NaN, which matches both the default
// and the PropagateNaN flavour below.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmin_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return pmin<PacketXf>(a, b);
}

// FMINNM returns the non-NaN operand when exactly one input is NaN.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svminnm_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmax_f32_x(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return pmax<PacketXf>(a, b);
}

// FMAXNM: mirror of FMINNM for the maximum.
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svmaxnm_f32_x(svptrue_b32(), a, b);
}
466
// Float comparisons in SVE return svbool (predicate). Use svdup to set active
// lanes to all-ones (0xffffffffu) and inactive lanes to 0, then reinterpret
// the resulting integer mask as float.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}

// "a < b or either is NaN" == NOT(a >= b): do a predicate inverse (svnot_b_z)
// on the predicate resulting from the greater/equal comparison (svcmpge_f32),
// then fill a float vector with the active elements.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}

// Round towards negative infinity (FRINTM).
template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a) {
  return svrintm_f32_x(svptrue_b32(), a);
}

// All-ones bit pattern reinterpreted as float (the argument only selects the
// overload).
template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/) {
  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), 0xffffffffu));
}
501
// Logical operations are not supported for float, so reinterpret-cast to
// unsigned, operate there, and cast back.
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svand_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svorr_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(sveor_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

// a & ~b (BIC = bit clear).
template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b) {
  return svreinterpret_f32_u32(svbic_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
522
// SVE has a single, alignment-agnostic contiguous load/store, so the aligned
// and unaligned variants differ only in their debug instrumentation.
template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

// Load the first size/2 elements with each duplicated twice, via a gather
// whose index vector is an iota interleaved (zip1) with itself.
template <>
EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from) {
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

// Load the first size/4 elements with each duplicated four times (two zips).
template <>
EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from) {
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from) {
  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from) {
  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}
557
// Strided load via a gather with indices {0, stride, 2*stride, ...}.
// NOTE(review): svindex_s32 narrows `stride` from Index (typically int64) to
// int32 — assumes stride and all resulting offsets fit in 32 bits; confirm.
template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride) {
  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}

// Strided store via a scatter with the same index pattern.
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride) {
  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}
571
// Extract lane 0.
template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a) {
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_f32(svpfalse_b(), a);
}

// Reverse the lane order.
template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) {
  return svrev_f32(a);
}

// Element-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) {
  return svabs_f32_x(svptrue_b32(), a);
}

// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
// Decompose a into mantissa and exponent via the packet-generic helper.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent) {
  return pfrexp_generic(a, exponent);
}

// Horizontal sum of all lanes.
template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a) {
  return svaddv_f32(svptrue_b32(), a);
}
599
// Other reduction functions:
// mul
// Multiplicative horizontal reduction: SVE has no product-reduction
// instruction, so a log2(size) tree of "multiply the vector by a shifted
// copy of itself" steps is used.
// Only works for SVE VLs multiple of 128 (enforced below).
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a) {
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
  // Multiply the vector by its reverse
  svfloat32_t prod = svmul_f32_x(svptrue_b32(), a, svrev_f32(a));
  svfloat32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done.
  // The conditions are compile-time constants, so untaken steps fold away.
  // svtbl yields zero for out-of-range indices, but those lanes are never
  // consumed afterwards: each step only reads the low half of the previous.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_f32(prod, svindex_u32(32, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_f32(prod, svindex_u32(16, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_f32(prod, svindex_u32(8, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_f32(prod, svindex_u32(4, 1));
    prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXf>(prod);
}
634
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a) {
  return svminv_f32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a) {
  return svmaxv_f32(svptrue_b32(), a);
}
644
// Transpose an N x size block in place via a scratch buffer: packet i is
// scattered with stride N starting at offset i (writing the transposed
// layout), then the rows are reloaded contiguously.
template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel) {
  float buffer[packet_traits<float>::size * N] = {0};
  int i = 0;

  // Scatter indices {0, N, 2N, ...}.
  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }

  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
  }
}
660
// Scale a by 2^exponent via the packet-generic helper.
template <>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent) {
  return pldexp_generic(a, exponent);
}

// Element-wise square root.
template <>
EIGEN_STRONG_INLINE PacketXf psqrt<PacketXf>(const PacketXf& a) {
  return svsqrt_f32_x(svptrue_b32(), a);
}
670
671} // namespace internal
672} // namespace Eigen
673
674#endif // EIGEN_PACKET_MATH_SVE_H
@ Aligned64
Definition Constants.h:239
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82