$darkmode
Eigen  5.0.1-dev
PacketMath.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2020, Arm Limited and Contributors
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_PACKET_MATH_SVE_H
11 #define EIGEN_PACKET_MATH_SVE_H
12 
13 // IWYU pragma: private
14 #include "../../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 namespace internal {
18 #ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
19 #define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
20 #endif
21 
22 #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
23 #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
24 #endif
25 
26 #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
27 
28 template <typename Scalar, int SVEVectorLength>
29 struct sve_packet_size_selector {
30  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
31 };
32 
33 /********************************* int32 **************************************/
34 typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
35 
36 template <>
37 struct packet_traits<numext::int32_t> : default_packet_traits {
38  typedef PacketXi type;
39  typedef PacketXi half; // Half not implemented yet
40  enum {
41  Vectorizable = 1,
42  AlignedOnScalar = 1,
43  size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
44 
45  HasAdd = 1,
46  HasSub = 1,
47  HasShift = 1,
48  HasMul = 1,
49  HasNegate = 1,
50  HasAbs = 1,
51  HasArg = 0,
52  HasAbs2 = 1,
53  HasMin = 1,
54  HasMax = 1,
55  HasConj = 1,
56  HasSetLinear = 0,
57  HasBlend = 0,
58  HasReduxp = 0 // Not implemented in SVE
59  };
60 };
61 
62 template <>
63 struct unpacket_traits<PacketXi> {
64  typedef numext::int32_t type;
65  typedef PacketXi half; // Half not yet implemented
66  enum {
67  size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
68  alignment = Aligned64,
69  vectorizable = true,
70  masked_load_available = false,
71  masked_store_available = false
72  };
73 };
74 
75 template <>
76 EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr) {
77  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
78 }
79 
80 template <>
81 EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from) {
82  return svdup_n_s32(from);
83 }
84 
85 template <>
86 EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a) {
87  numext::int32_t c[packet_traits<numext::int32_t>::size];
88  for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
89  return svadd_s32_x(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
90 }
91 
92 template <>
93 EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b) {
94  return svadd_s32_x(svptrue_b32(), a, b);
95 }
96 
97 template <>
98 EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b) {
99  return svsub_s32_x(svptrue_b32(), a, b);
100 }
101 
102 template <>
103 EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) {
104  return svneg_s32_x(svptrue_b32(), a);
105 }
106 
107 template <>
108 EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) {
109  return a;
110 }
111 
112 template <>
113 EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b) {
114  return svmul_s32_x(svptrue_b32(), a, b);
115 }
116 
117 template <>
118 EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b) {
119  return svdiv_s32_x(svptrue_b32(), a, b);
120 }
121 
122 template <>
123 EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) {
124  return svmla_s32_x(svptrue_b32(), c, a, b);
125 }
126 
127 template <>
128 EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b) {
129  return svmin_s32_x(svptrue_b32(), a, b);
130 }
131 
132 template <>
133 EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b) {
134  return svmax_s32_x(svptrue_b32(), a, b);
135 }
136 
137 template <>
138 EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b) {
139  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
140 }
141 
142 template <>
143 EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b) {
144  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
145 }
146 
147 template <>
148 EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b) {
149  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
150 }
151 
152 template <>
153 EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/) {
154  return svdup_n_s32_x(svptrue_b32(), 0xffffffffu);
155 }
156 
157 template <>
158 EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/) {
159  return svdup_n_s32_x(svptrue_b32(), 0);
160 }
161 
162 template <>
163 EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b) {
164  return svand_s32_x(svptrue_b32(), a, b);
165 }
166 
167 template <>
168 EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b) {
169  return svorr_s32_x(svptrue_b32(), a, b);
170 }
171 
172 template <>
173 EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b) {
174  return sveor_s32_x(svptrue_b32(), a, b);
175 }
176 
177 template <>
178 EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b) {
179  return svbic_s32_x(svptrue_b32(), a, b);
180 }
181 
182 template <int N>
183 EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) {
184  return svasrd_n_s32_x(svptrue_b32(), a, N);
185 }
186 
187 template <int N>
188 EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) {
189  return svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), svreinterpret_u32_s32(a), N));
190 }
191 
192 template <int N>
193 EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) {
194  return svlsl_n_s32_x(svptrue_b32(), a, N);
195 }
196 
197 template <>
198 EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from) {
199  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
200 }
201 
202 template <>
203 EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from) {
204  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
205 }
206 
207 template <>
208 EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from) {
209  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
210  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
211  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
212 }
213 
214 template <>
215 EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from) {
216  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
217  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
218  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
219  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
220 }
221 
222 template <>
223 EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
224  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
225 }
226 
227 template <>
228 EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from) {
229  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
230 }
231 
232 template <>
233 EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride) {
234  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
235  svint32_t indices = svindex_s32(0, stride);
236  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
237 }
238 
239 template <>
240 EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from,
241  Index stride) {
242  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
243  svint32_t indices = svindex_s32(0, stride);
244  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
245 }
246 
247 template <>
248 EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a) {
249  // svlasta returns the first element if all predicate bits are 0
250  return svlasta_s32(svpfalse_b(), a);
251 }
252 
253 template <>
254 EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) {
255  return svrev_s32(a);
256 }
257 
258 template <>
259 EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) {
260  return svabs_s32_x(svptrue_b32(), a);
261 }
262 
263 template <>
264 EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a) {
265  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
266 }
267 
268 template <>
269 EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a) {
270  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
271 
272  // Multiply the vector by its reverse
273  svint32_t prod = svmul_s32_x(svptrue_b32(), a, svrev_s32(a));
274  svint32_t half_prod;
275 
276  // Extract the high half of the vector. Depending on the VL more reductions need to be done
277  if (EIGEN_ARM64_SVE_VL >= 2048) {
278  half_prod = svtbl_s32(prod, svindex_u32(32, 1));
279  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
280  }
281  if (EIGEN_ARM64_SVE_VL >= 1024) {
282  half_prod = svtbl_s32(prod, svindex_u32(16, 1));
283  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
284  }
285  if (EIGEN_ARM64_SVE_VL >= 512) {
286  half_prod = svtbl_s32(prod, svindex_u32(8, 1));
287  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
288  }
289  if (EIGEN_ARM64_SVE_VL >= 256) {
290  half_prod = svtbl_s32(prod, svindex_u32(4, 1));
291  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
292  }
293  // Last reduction
294  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
295  prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
296 
297  // The reduction is done to the first element.
298  return pfirst<PacketXi>(prod);
299 }
300 
301 template <>
302 EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a) {
303  return svminv_s32(svptrue_b32(), a);
304 }
305 
306 template <>
307 EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a) {
308  return svmaxv_s32(svptrue_b32(), a);
309 }
310 
311 template <int N>
312 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
313  int buffer[packet_traits<numext::int32_t>::size * N] = {0};
314  int i = 0;
315 
316  PacketXi stride_index = svindex_s32(0, N);
317 
318  for (i = 0; i < N; i++) {
319  svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
320  }
321  for (i = 0; i < N; i++) {
322  kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
323  }
324 }
325 
326 /********************************* float32 ************************************/
327 
328 typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
329 
330 template <>
331 struct packet_traits<float> : default_packet_traits {
332  typedef PacketXf type;
333  typedef PacketXf half;
334 
335  enum {
336  Vectorizable = 1,
337  AlignedOnScalar = 1,
338  size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
339 
340  HasAdd = 1,
341  HasSub = 1,
342  HasShift = 1,
343  HasMul = 1,
344  HasNegate = 1,
345  HasAbs = 1,
346  HasArg = 0,
347  HasAbs2 = 1,
348  HasMin = 1,
349  HasMax = 1,
350  HasConj = 1,
351  HasSetLinear = 0,
352  HasBlend = 0,
353  HasReduxp = 0, // Not implemented in SVE
354 
355  HasDiv = 1,
356 
357  HasCmp = 1,
358  HasSin = EIGEN_FAST_MATH,
359  HasCos = EIGEN_FAST_MATH,
360  HasLog = 1,
361  HasExp = 1,
362  HasPow = 1,
363  HasSqrt = 1,
364  HasTanh = EIGEN_FAST_MATH,
365  HasErf = EIGEN_FAST_MATH,
366  HasErfc = EIGEN_FAST_MATH
367  };
368 };
369 
370 template <>
371 struct unpacket_traits<PacketXf> {
372  typedef float type;
373  typedef PacketXf half; // Half not yet implemented
374  typedef PacketXi integer_packet;
375 
376  enum {
377  size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
378  alignment = Aligned64,
379  vectorizable = true,
380  masked_load_available = false,
381  masked_store_available = false
382  };
383 };
384 
385 template <>
386 EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from) {
387  return svdup_n_f32(from);
388 }
389 
390 template <>
391 EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from) {
392  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), from));
393 }
394 
395 template <>
396 EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a) {
397  float c[packet_traits<float>::size];
398  for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
399  return svadd_f32_x(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
400 }
401 
402 template <>
403 EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b) {
404  return svadd_f32_x(svptrue_b32(), a, b);
405 }
406 
407 template <>
408 EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b) {
409  return svsub_f32_x(svptrue_b32(), a, b);
410 }
411 
412 template <>
413 EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) {
414  return svneg_f32_x(svptrue_b32(), a);
415 }
416 
417 template <>
418 EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) {
419  return a;
420 }
421 
422 template <>
423 EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b) {
424  return svmul_f32_x(svptrue_b32(), a, b);
425 }
426 
427 template <>
428 EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b) {
429  return svdiv_f32_x(svptrue_b32(), a, b);
430 }
431 
432 template <>
433 EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) {
434  return svmla_f32_x(svptrue_b32(), c, a, b);
435 }
436 
437 template <>
438 EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b) {
439  return svmin_f32_x(svptrue_b32(), a, b);
440 }
441 
442 template <>
443 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
444  return pmin<PacketXf>(a, b);
445 }
446 
447 template <>
448 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
449  return svminnm_f32_x(svptrue_b32(), a, b);
450 }
451 
452 template <>
453 EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b) {
454  return svmax_f32_x(svptrue_b32(), a, b);
455 }
456 
457 template <>
458 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b) {
459  return pmax<PacketXf>(a, b);
460 }
461 
462 template <>
463 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
464  return svmaxnm_f32_x(svptrue_b32(), a, b);
465 }
466 
467 // Float comparisons in SVE return svbool (predicate). Use svdup to set active
468 // lanes to 1 (0xffffffffu) and inactive lanes to 0.
469 template <>
470 EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b) {
471  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
472 }
473 
474 template <>
475 EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b) {
476  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
477 }
478 
479 template <>
480 EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b) {
481  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
482 }
483 
484 // Do a predicate inverse (svnot_b_z) on the predicate resulted from the
485 // greater/equal comparison (svcmpge_f32). Then fill a float vector with the
486 // active elements.
487 template <>
488 EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b) {
489  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
490 }
491 
492 template <>
493 EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a) {
494  return svrintm_f32_x(svptrue_b32(), a);
495 }
496 
497 template <>
498 EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/) {
499  return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), 0xffffffffu));
500 }
501 
502 // Logical Operations are not supported for float, so reinterpret casts
503 template <>
504 EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b) {
505  return svreinterpret_f32_u32(svand_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
506 }
507 
508 template <>
509 EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b) {
510  return svreinterpret_f32_u32(svorr_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
511 }
512 
513 template <>
514 EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b) {
515  return svreinterpret_f32_u32(sveor_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
516 }
517 
518 template <>
519 EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b) {
520  return svreinterpret_f32_u32(svbic_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
521 }
522 
523 template <>
524 EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from) {
525  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
526 }
527 
528 template <>
529 EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from) {
530  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
531 }
532 
533 template <>
534 EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from) {
535  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
536  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
537  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
538 }
539 
540 template <>
541 EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from) {
542  svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
543  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
544  indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
545  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
546 }
547 
548 template <>
549 EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from) {
550  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
551 }
552 
553 template <>
554 EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from) {
555  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
556 }
557 
558 template <>
559 EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride) {
560  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
561  svint32_t indices = svindex_s32(0, stride);
562  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
563 }
564 
565 template <>
566 EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride) {
567  // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
568  svint32_t indices = svindex_s32(0, stride);
569  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
570 }
571 
572 template <>
573 EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a) {
574  // svlasta returns the first element if all predicate bits are 0
575  return svlasta_f32(svpfalse_b(), a);
576 }
577 
578 template <>
579 EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) {
580  return svrev_f32(a);
581 }
582 
583 template <>
584 EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) {
585  return svabs_f32_x(svptrue_b32(), a);
586 }
587 
588 // TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
589 // all vector extensions and the generic version.
590 template <>
591 EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent) {
592  return pfrexp_generic(a, exponent);
593 }
594 
595 template <>
596 EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a) {
597  return svaddv_f32(svptrue_b32(), a);
598 }
599 
600 // Other reduction functions:
601 // mul
602 // Only works for SVE Vls multiple of 128
603 template <>
604 EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a) {
605  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
606  // Multiply the vector by its reverse
607  svfloat32_t prod = svmul_f32_x(svptrue_b32(), a, svrev_f32(a));
608  svfloat32_t half_prod;
609 
610  // Extract the high half of the vector. Depending on the VL more reductions need to be done
611  if (EIGEN_ARM64_SVE_VL >= 2048) {
612  half_prod = svtbl_f32(prod, svindex_u32(32, 1));
613  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
614  }
615  if (EIGEN_ARM64_SVE_VL >= 1024) {
616  half_prod = svtbl_f32(prod, svindex_u32(16, 1));
617  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
618  }
619  if (EIGEN_ARM64_SVE_VL >= 512) {
620  half_prod = svtbl_f32(prod, svindex_u32(8, 1));
621  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
622  }
623  if (EIGEN_ARM64_SVE_VL >= 256) {
624  half_prod = svtbl_f32(prod, svindex_u32(4, 1));
625  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
626  }
627  // Last reduction
628  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
629  prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
630 
631  // The reduction is done to the first element.
632  return pfirst<PacketXf>(prod);
633 }
634 
635 template <>
636 EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a) {
637  return svminv_f32(svptrue_b32(), a);
638 }
639 
640 template <>
641 EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a) {
642  return svmaxv_f32(svptrue_b32(), a);
643 }
644 
645 template <int N>
646 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel) {
647  float buffer[packet_traits<float>::size * N] = {0};
648  int i = 0;
649 
650  PacketXi stride_index = svindex_s32(0, N);
651 
652  for (i = 0; i < N; i++) {
653  svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
654  }
655 
656  for (i = 0; i < N; i++) {
657  kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
658  }
659 }
660 
661 template <>
662 EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent) {
663  return pldexp_generic(a, exponent);
664 }
665 
666 template <>
667 EIGEN_STRONG_INLINE PacketXf psqrt<PacketXf>(const PacketXf& a) {
668  return svsqrt_f32_x(svptrue_b32(), a);
669 }
670 
671 } // namespace internal
672 } // namespace Eigen
673 
674 #endif // EIGEN_PACKET_MATH_SVE_H
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Definition: Constants.h:239
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82