GeneralBlockPanelKernel.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_GENERAL_BLOCK_PANEL_H
11#define EIGEN_GENERAL_BLOCK_PANEL_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20enum GEBPPacketSizeType { GEBPPacketFull = 0, GEBPPacketHalf, GEBPPacketQuarter };
21
22template <typename LhsScalar_, typename RhsScalar_, bool ConjLhs_ = false, bool ConjRhs_ = false,
23 int Arch = Architecture::Target, int PacketSize_ = GEBPPacketFull>
24class gebp_traits;
25
27inline std::ptrdiff_t manage_caching_sizes_helper(std::ptrdiff_t a, std::ptrdiff_t b) { return a <= 0 ? b : a; }
28
29#if defined(EIGEN_DEFAULT_L1_CACHE_SIZE)
30#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) EIGEN_DEFAULT_L1_CACHE_SIZE
31#else
32#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) val
33#endif // defined(EIGEN_DEFAULT_L1_CACHE_SIZE)
34
35#if defined(EIGEN_DEFAULT_L2_CACHE_SIZE)
36#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) EIGEN_DEFAULT_L2_CACHE_SIZE
37#else
38#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) val
39#endif // defined(EIGEN_DEFAULT_L2_CACHE_SIZE)
40
41#if defined(EIGEN_DEFAULT_L3_CACHE_SIZE)
42#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) EIGEN_DEFAULT_L3_CACHE_SIZE
43#else
44#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) val
45#endif // defined(EIGEN_DEFAULT_L3_CACHE_SIZE)
46
47#if EIGEN_ARCH_i386_OR_x86_64
48const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(32 * 1024);
49const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(256 * 1024);
50const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(2 * 1024 * 1024);
51#elif EIGEN_ARCH_PPC
52const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(64 * 1024);
53#ifdef _ARCH_PWR10
54const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(2 * 1024 * 1024);
55const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(8 * 1024 * 1024);
56#else
57const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512 * 1024);
58const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(4 * 1024 * 1024);
59#endif
60#else
61const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(16 * 1024);
62const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512 * 1024);
63const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(512 * 1024);
64#endif
65
66#undef EIGEN_SET_DEFAULT_L1_CACHE_SIZE
67#undef EIGEN_SET_DEFAULT_L2_CACHE_SIZE
68#undef EIGEN_SET_DEFAULT_L3_CACHE_SIZE
69
71struct CacheSizes {
72 CacheSizes() : m_l1(-1), m_l2(-1), m_l3(-1) {
74 queryCacheSizes(l1CacheSize, l2CacheSize, l3CacheSize);
75 m_l1 = manage_caching_sizes_helper(l1CacheSize, defaultL1CacheSize);
76 m_l2 = manage_caching_sizes_helper(l2CacheSize, defaultL2CacheSize);
77 m_l3 = manage_caching_sizes_helper(l3CacheSize, defaultL3CacheSize);
78 }
79
80 std::ptrdiff_t m_l1;
81 std::ptrdiff_t m_l2;
82 std::ptrdiff_t m_l3;
83};
84
86inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff_t* l2, std::ptrdiff_t* l3) {
87 static CacheSizes m_cacheSizes;
88
89 if (action == SetAction) {
90 // set the cpu cache sizes; all block sizes are computed from these global cache sizes (in bytes)
91 eigen_internal_assert(l1 != 0 && l2 != 0);
92 m_cacheSizes.m_l1 = *l1;
93 m_cacheSizes.m_l2 = *l2;
94 m_cacheSizes.m_l3 = *l3;
95 } else if (action == GetAction) {
96 eigen_internal_assert(l1 != 0 && l2 != 0);
97 *l1 = m_cacheSizes.m_l1;
98 *l2 = m_cacheSizes.m_l2;
99 *l3 = m_cacheSizes.m_l3;
100 } else {
101 eigen_internal_assert(false);
102 }
103}
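// A minimal usage sketch (illustrative only, mirroring how the heuristic below queries the values):
//   std::ptrdiff_t l1, l2, l3;
//   internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);  // returns the detected or default sizes
// The same entry point with SetAction lets user code override the detected cache sizes.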
104
105/* Helper for computeProductBlockingSizes.
106 *
107 * Given an m x k times k x n matrix product of scalar types \c LhsScalar and \c RhsScalar,
108 * this function computes the blocking size parameters along the respective dimensions
109 * for matrix products and related algorithms. The blocking sizes depend on various
110 * parameters:
111 * - the L1 and L2 cache sizes,
112 * - the register level blocking sizes defined by gebp_traits,
113 * - the number of scalars that fit into a packet (when vectorization is enabled).
114 *
115 * \sa setCpuCacheSizes */
116
117template <typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
118void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index num_threads = 1) {
119 typedef gebp_traits<LhsScalar, RhsScalar> Traits;
120
121 // Explanations:
122 // Let's recall that the product algorithms form mc x kc vertical panels A' on the lhs and
123 // kc x nc blocks B' on the rhs. B' has to fit into L2/L3 cache. Moreover, A' is processed
124 // in small horizontal panels of size mr x kc, where mr is the blocking size along the m dimension
125 // at the register level. This small horizontal panel has to stay within L1 cache.
126 std::ptrdiff_t l1, l2, l3;
127 manage_caching_sizes(GetAction, &l1, &l2, &l3);
128#ifdef EIGEN_VECTORIZE_AVX512
129 // We still need to find a proper rationale for this factor, but without the adjustment,
130 // performance with AVX512 is pretty bad, roughly 20% slower.
131 // One reason is that with an increasing packet size, the blocking size k
132 // has to become pretty small if we want one lhs panel to fit within L1.
133 // For instance, with the 3pX4 kernel and double, the size of the lhs+rhs panels is
134 // k*(3*64 + 4*8) bytes; with l1=32kBytes and k%8==0, we get k=144.
135 // This is quite small for a good reuse of the accumulation registers.
136 l1 *= 4;
137#endif
138
139 if (num_threads > 1) {
140 typedef typename Traits::ResScalar ResScalar;
141 enum {
142 kdiv = KcFactor * (Traits::mr * sizeof(LhsScalar) + Traits::nr * sizeof(RhsScalar)),
143 ksub = Traits::mr * (Traits::nr * sizeof(ResScalar)),
144 kr = 8,
145 mr = Traits::mr,
146 nr = Traits::nr
147 };
148 // Increasing k gives us more time to prefetch the content of the "C"
149 // registers. However once the latency is hidden there is no point in
150 // increasing the value of k, so we'll cap it at 320 (value determined
151 // experimentally).
152 // To avoid k vanishing, we make k_cache at least as big as kr
153 const Index k_cache = numext::maxi<Index>(kr, (numext::mini<Index>)((l1 - ksub) / kdiv, 320));
154 if (k_cache < k) {
155 k = k_cache - (k_cache % kr);
156 eigen_internal_assert(k > 0);
157 }
158
159 const Index n_cache = (l2 - l1) / (nr * sizeof(RhsScalar) * k);
160 const Index n_per_thread = numext::div_ceil(n, num_threads);
161 if (n_cache <= n_per_thread) {
162 // Don't exceed the capacity of the l2 cache.
163 eigen_internal_assert(n_cache >= static_cast<Index>(nr));
164 n = n_cache - (n_cache % nr);
165 eigen_internal_assert(n > 0);
166 } else {
167 n = (numext::mini<Index>)(n, (n_per_thread + nr - 1) - ((n_per_thread + nr - 1) % nr));
168 }
169
170 if (l3 > l2) {
171 // l3 is shared between all cores, so we'll give each thread its own chunk of l3.
172 const Index m_cache = (l3 - l2) / (sizeof(LhsScalar) * k * num_threads);
173 const Index m_per_thread = numext::div_ceil(m, num_threads);
174 if (m_cache < m_per_thread && m_cache >= static_cast<Index>(mr)) {
175 m = m_cache - (m_cache % mr);
176 eigen_internal_assert(m > 0);
177 } else {
178 m = (numext::mini<Index>)(m, (m_per_thread + mr - 1) - ((m_per_thread + mr - 1) % mr));
179 }
180 }
181 } else {
182 // In unit tests we do not want to use extra large matrices,
183 // so we reduce the cache sizes to check that the blocking strategy is not flawed
184#ifdef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS
185 l1 = 9 * 1024;
186 l2 = 32 * 1024;
187 l3 = 512 * 1024;
188#endif
189
190 // Early return for small problems because the computations below are time consuming for them.
191 // Perhaps it would make more sense to consider k*n*m??
192 // Note that for very tiny problems, this function should be bypassed anyway
193 // because we use the coefficient-based implementation for them.
194 if ((numext::maxi)(k, (numext::maxi)(m, n)) < 48) return;
195
196 typedef typename Traits::ResScalar ResScalar;
197 enum {
198 k_peeling = 8,
199 k_div = KcFactor * (Traits::mr * sizeof(LhsScalar) + Traits::nr * sizeof(RhsScalar)),
200 k_sub = Traits::mr * (Traits::nr * sizeof(ResScalar))
201 };
202
203 // ---- 1st level of blocking on L1, yields kc ----
204
205 // Blocking on the third dimension (i.e., k) is chosen so that a horizontal panel
206 // of size mr x kc of the lhs plus a vertical panel of kc x nr of the rhs both fit within the L1 cache.
207 // We also include a register-level block of the result (mr x nr).
208 // (In an ideal world only the lhs panel would stay in L1.)
209 // Moreover, kc has to be a multiple of 8 to be compatible with loop peeling, leading to a maximum blocking size of:
210 const Index max_kc = numext::maxi<Index>(((l1 - k_sub) / k_div) & (~(k_peeling - 1)), 1);
211 const Index old_k = k;
212 if (k > max_kc) {
213 // We are really blocking on the third dimension:
214 // -> reduce blocking size to make sure the last block is as large as possible
215 // while keeping the same number of sweeps over the result.
216 k = (k % max_kc) == 0 ? max_kc
217 : max_kc - k_peeling * ((max_kc - 1 - (k % max_kc)) / (k_peeling * (k / max_kc + 1)));
218
219 eigen_internal_assert(((old_k / k) == (old_k / max_kc)) && "the number of sweeps has to remain the same");
220 }
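// Illustrative numbers (a sketch, assuming plain AVX with double, i.e. mr=12, nr=4, 8-byte
// scalars and KcFactor=1): k_div = 12*8 + 4*8 = 128 and k_sub = 12*4*8 = 384, so with
// l1 = 32 KiB we get max_kc = ((32768 - 384) / 128) & ~7 = 248; larger k values are then
// reduced as above so that the last block stays as large as possible.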
221
222// ---- 2nd level of blocking on max(L2,L3), yields nc ----
223
224// TODO find a reliable way to get the actual amount of cache per core to use for 2nd level blocking, that is:
225// actual_l2 = max(l2, l3/nb_core_sharing_l3)
226 // The number below is quite conservative: it is better to underestimate the cache size than to overestimate it.
227// For instance, it corresponds to 6MB of L3 shared among 4 cores.
228#ifdef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS
229 const Index actual_l2 = l3;
230#else
231 const Index actual_l2 = 1572864; // == 1.5 MB
232#endif
233
234 // Here, nc is chosen such that a block of kc x nc of the rhs fits within half of L2.
235 // The second half is implicitly reserved for accessing the result and lhs coefficients.
236 // When k<max_kc, nc could grow arbitrarily. In practice, it seems to be fruitful
237 // to limit this growth: we bound nc to grow by a factor of at most 1.5.
238 // However, if the entire lhs block fits within L1, then we are not going to block on the rows at all,
239 // and it becomes fruitful to keep the packed rhs blocks in L1 if there is enough remaining space.
240 Index max_nc;
241 const Index lhs_bytes = m * k * sizeof(LhsScalar);
242 const Index remaining_l1 = l1 - k_sub - lhs_bytes;
243 if (remaining_l1 >= Index(Traits::nr * sizeof(RhsScalar)) * k) {
244 // L1 blocking
245 max_nc = remaining_l1 / (k * sizeof(RhsScalar));
246 } else {
247 // L2 blocking
248 max_nc = (3 * actual_l2) / (2 * 2 * max_kc * sizeof(RhsScalar));
249 }
250 // WARNING Below, we assume that Traits::nr is a power of two.
251 Index nc = numext::mini<Index>(actual_l2 / (2 * k * sizeof(RhsScalar)), max_nc) & (~(Traits::nr - 1));
252 if (n > nc) {
253 // We are really blocking over the columns:
254 // -> reduce blocking size to make sure the last block is as large as possible
255 // while keeping the same number of sweeps over the packed lhs.
256 // Here we allow one more sweep if this gives us a perfect match, thus the commented "-1"
257 n = (n % nc) == 0 ? nc : (nc - Traits::nr * ((nc /*-1*/ - (n % nc)) / (Traits::nr * (n / nc + 1))));
258 } else if (old_k == k) {
259 // So far, no blocking at all, i.e., kc==k, and nc==n.
260 // In this case, let's perform a blocking over the rows such that the packed lhs data is kept in cache L1/L2
261 // TODO: part of this blocking strategy is now implemented within the kernel itself, so the L1-based heuristic
262 // here should be obsolete.
263 Index problem_size = k * n * sizeof(LhsScalar);
264 Index actual_lm = actual_l2;
265 Index max_mc = m;
266 if (problem_size <= 1024) {
267 // problem is small enough to keep in L1
268 // Let's choose m such that the lhs's block fits in 1/3 of L1
269 actual_lm = l1;
270 } else if (l3 != 0 && problem_size <= 32768) {
271 // we have both L2 and L3, and problem is small enough to be kept in L2
272 // Let's choose m such that the lhs's block fits in 1/3 of L2
273 actual_lm = l2;
274 max_mc = (numext::mini<Index>)(576, max_mc);
275 }
276 Index mc = (numext::mini<Index>)(actual_lm / (3 * k * sizeof(LhsScalar)), max_mc);
277 if (mc > Traits::mr)
278 mc -= mc % Traits::mr;
279 else if (mc == 0)
280 return;
281 m = (m % mc) == 0 ? mc : (mc - Traits::mr * ((mc /*-1*/ - (m % mc)) / (Traits::mr * (m / mc + 1))));
282 }
283 }
284}
285
286template <typename Index>
287inline bool useSpecificBlockingSizes(Index& k, Index& m, Index& n) {
288#ifdef EIGEN_TEST_SPECIFIC_BLOCKING_SIZES
289 if (EIGEN_TEST_SPECIFIC_BLOCKING_SIZES) {
290 k = numext::mini<Index>(k, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_K);
291 m = numext::mini<Index>(m, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_M);
292 n = numext::mini<Index>(n, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_N);
293 return true;
294 }
295#else
296 EIGEN_UNUSED_VARIABLE(k)
297 EIGEN_UNUSED_VARIABLE(m)
298 EIGEN_UNUSED_VARIABLE(n)
299#endif
300 return false;
301}
302
320
321template <typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
322void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_threads = 1) {
323 if (!useSpecificBlockingSizes(k, m, n)) {
324 evaluateProductBlockingSizesHeuristic<LhsScalar, RhsScalar, KcFactor, Index>(k, m, n, num_threads);
325 }
326}
327
328template <typename LhsScalar, typename RhsScalar, typename Index>
329inline void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_threads = 1) {
330 computeProductBlockingSizes<LhsScalar, RhsScalar, 1, Index>(k, m, n, num_threads);
331}
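
// A minimal usage sketch (illustrative only; kc/mc/nc and num_threads are assumed caller-side
// variables): callers typically initialize the blocking parameters to the full problem dimensions
// and let this function clamp them to cache-friendly sizes.
//   Index kc = depth, mc = rows, nc = cols;
//   internal::computeProductBlockingSizes<double, double, 1, Index>(kc, mc, nc, num_threads);
//   // kc, mc and nc now hold the suggested block sizes along the k, m and n dimensions.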
332
333template <typename RhsPacket, typename RhsPacketx4, int registers_taken>
334struct RhsPanelHelper {
335 private:
336 static constexpr int remaining_registers =
337 (std::max)(int(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS) - registers_taken, 0);
338
339 public:
340 typedef std::conditional_t<remaining_registers >= 4, RhsPacketx4, RhsPacket> type;
341};
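
// For example (illustrative): with 15 accumulators already reserved and only 16 addressable SIMD
// registers, no full RhsPacketx4 panel can be kept live, so the plain RhsPacket type is selected;
// with 32 registers (e.g. AVX-512 or AArch64 NEON) at least 4 remain and RhsPacketx4 is used.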
342
343template <typename Packet>
344struct QuadPacket {
345 Packet B_0, B1, B2, B3;
346 const Packet& get(const FixedInt<0>&) const { return B_0; }
347 const Packet& get(const FixedInt<1>&) const { return B1; }
348 const Packet& get(const FixedInt<2>&) const { return B2; }
349 const Packet& get(const FixedInt<3>&) const { return B3; }
350};
351
352template <int N, typename T1, typename T2, typename T3>
353struct packet_conditional {
354 typedef T3 type;
355};
356
357template <typename T1, typename T2, typename T3>
358struct packet_conditional<GEBPPacketFull, T1, T2, T3> {
359 typedef T1 type;
360};
361
362template <typename T1, typename T2, typename T3>
363struct packet_conditional<GEBPPacketHalf, T1, T2, T3> {
364 typedef T2 type;
365};
366
367#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
368 typedef typename packet_conditional< \
369 packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
370 typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix
371
372#define PACKET_DECL_COND(name, packet_size) \
373 typedef typename packet_conditional< \
374 packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
375 typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet
376
377#define PACKET_DECL_COND_SCALAR_POSTFIX(postfix, packet_size) \
378 typedef typename packet_conditional< \
379 packet_size, typename packet_traits<Scalar>::type, typename packet_traits<Scalar>::half, \
380 typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type ScalarPacket##postfix
381
382#define PACKET_DECL_COND_SCALAR(packet_size) \
383 typedef typename packet_conditional< \
384 packet_size, typename packet_traits<Scalar>::type, typename packet_traits<Scalar>::half, \
385 typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type ScalarPacket
386
387/* Vectorization logic
388 * real*real: unpack rhs to constant packets, ...
389 *
390 * cd*cd : unpack rhs to (b_r,b_r), (b_i,b_i), mul to get (a_r b_r,a_i b_r) (a_r b_i,a_i b_i),
391 * storing each res packet into two packets (2x2),
392 * at the end combine them: swap the second and addsub them
393 * cf*cf : same but with 2x4 blocks
394 * cplx*real : unpack rhs to constant packets, ...
395 * real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual
396 */
397template <typename LhsScalar_, typename RhsScalar_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
398class gebp_traits {
399 public:
400 typedef LhsScalar_ LhsScalar;
401 typedef RhsScalar_ RhsScalar;
402 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
403
404 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
405 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
406 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
407
408 enum {
409 ConjLhs = ConjLhs_,
410 ConjRhs = ConjRhs_,
411 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
412 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
413 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
414 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
415
416 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
417
418 // register block size along the N direction must be 1 or 4
419 nr = 4,
420
421 // register block size along the M direction (currently, this one cannot be modified)
422 default_mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize,
423#if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && \
424 !defined(EIGEN_VECTORIZE_VSX) && ((!EIGEN_COMP_MSVC) || (EIGEN_COMP_MSVC >= 1914))
425 // we assume 16 registers or more
426 // See bug 992: if the scalar type is not vectorizable but EIGEN_HAS_SINGLE_INSTRUCTION_MADD is defined,
427 // then using 3*LhsPacketSize triggers non-implemented paths in syrk.
428 // Bug 1515: MSVC prior to v19.14 leads to register spilling.
429 mr = Vectorizable ? 3 * LhsPacketSize : default_mr,
430#else
431 mr = default_mr,
432#endif
433
434 LhsProgress = LhsPacketSize,
435 RhsProgress = 1
436 };
437
438 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
439 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
440 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
441 typedef LhsPacket LhsPacket4Packing;
442
443 typedef QuadPacket<RhsPacket> RhsPacketx4;
444 typedef ResPacket AccPacket;
445
446 EIGEN_STRONG_INLINE void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
447
448 template <typename RhsPacketType>
449 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const {
450 dest = pset1<RhsPacketType>(*b);
451 }
452
453 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
454 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
455 }
456
457 template <typename RhsPacketType>
458 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const {
459 loadRhs(b, dest);
460 }
461
462 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
463
464 EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad<RhsPacket>(b); }
465
466 template <typename LhsPacketType>
467 EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacketType& dest) const {
468 dest = pload<LhsPacketType>(a);
469 }
470
471 template <typename LhsPacketType>
472 EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const {
473 dest = ploadu<LhsPacketType>(a);
474 }
475
476 template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
477 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
478 const LaneIdType&) const {
479 conj_helper<LhsPacketType, RhsPacketType, ConjLhs, ConjRhs> cj;
480 // It would be a lot cleaner to call pmadd all the time. Unfortunately, if we
481 // let gcc allocate the register in which to store the result of the pmul
482 // (in the case where there is no FMA), gcc fails to figure out how to avoid
483 // spilling registers.
484#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
485 EIGEN_UNUSED_VARIABLE(tmp);
486 c = cj.pmadd(a, b, c);
487#else
488 tmp = b;
489 tmp = cj.pmul(a, tmp);
490 c = padd(c, tmp);
491#endif
492 }
493
494 template <typename LhsPacketType, typename AccPacketType, typename LaneIdType>
495 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
496 const LaneIdType& lane) const {
497 madd(a, b.get(lane), c, tmp, lane);
498 }
499
500 EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const {
501 r = pmadd(c, alpha, r);
502 }
503
504 template <typename ResPacketHalf>
505 EIGEN_STRONG_INLINE void acc(const ResPacketHalf& c, const ResPacketHalf& alpha, ResPacketHalf& r) const {
506 r = pmadd(c, alpha, r);
507 }
508};
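
// A minimal sketch of how the gebp kernel below consumes a traits instance (illustrative only;
// blA, blB and alpha are assumptions standing for packed lhs/rhs pointers and the scaling factor):
//   Traits traits;
//   AccPacket c;
//   traits.initAcc(c);
//   LhsPacket a; RhsPacket b, tmp;
//   traits.loadLhs(blA, a);
//   traits.loadRhs(blB, b);
//   traits.madd(a, b, c, tmp, fix<0>);   // c += a * b on lane 0
//   ResPacket r = /* loaded from the result mapper */;
//   ResPacket alphav = pset1<ResPacket>(alpha);
//   traits.acc(c, alphav, r);            // r += alpha * c, then r is stored back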
509
510template <typename RealScalar, bool ConjLhs_, int Arch, int PacketSize_>
511class gebp_traits<std::complex<RealScalar>, RealScalar, ConjLhs_, false, Arch, PacketSize_> {
512 public:
513 typedef std::complex<RealScalar> LhsScalar;
514 typedef RealScalar RhsScalar;
515 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
516
517 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
518 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
519 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
520
521 enum {
522 ConjLhs = ConjLhs_,
523 ConjRhs = false,
524 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
525 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
526 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
527 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
528
529 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
530 nr = 4,
531#if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && !defined(EIGEN_VECTORIZE_VSX)
532 // we assume 16 registers
533 mr = 3 * LhsPacketSize,
534#else
535 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize,
536#endif
537
538 LhsProgress = LhsPacketSize,
539 RhsProgress = 1
540 };
541
542 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
543 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
544 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
545 typedef LhsPacket LhsPacket4Packing;
546
547 typedef QuadPacket<RhsPacket> RhsPacketx4;
548
549 typedef ResPacket AccPacket;
550
551 EIGEN_STRONG_INLINE void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
552
553 template <typename RhsPacketType>
554 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const {
555 dest = pset1<RhsPacketType>(*b);
556 }
557
558 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
559 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
560 }
561
562 template <typename RhsPacketType>
563 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const {
564 loadRhs(b, dest);
565 }
566
567 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
568
569 EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const {
570 loadRhsQuad_impl(b, dest, std::conditional_t<RhsPacketSize == 16, true_type, false_type>());
571 }
572
573 EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const true_type&) const {
574 // FIXME we can do better!
575 // what we want here is a ploadheight
576 RhsScalar tmp[4] = {b[0], b[0], b[1], b[1]};
577 dest = ploadquad<RhsPacket>(tmp);
578 }
579
580 EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const false_type&) const {
581 eigen_internal_assert(RhsPacketSize <= 8);
582 dest = pset1<RhsPacket>(*b);
583 }
584
585 EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const { dest = pload<LhsPacket>(a); }
586
587 template <typename LhsPacketType>
588 EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const {
589 dest = ploadu<LhsPacketType>(a);
590 }
591
592 template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
593 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
594 const LaneIdType&) const {
595 madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable, true_type, false_type>());
596 }
597
598 template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
599 EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c,
600 RhsPacketType& tmp, const true_type&) const {
601#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
602 EIGEN_UNUSED_VARIABLE(tmp);
603 c.v = pmadd(a.v, b, c.v);
604#else
605 tmp = b;
606 tmp = pmul(a.v, tmp);
607 c.v = padd(c.v, tmp);
608#endif
609 }
610
611 EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/,
612 const false_type&) const {
613 c += a * b;
614 }
615
616 template <typename LhsPacketType, typename AccPacketType, typename LaneIdType>
617 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
618 const LaneIdType& lane) const {
619 madd(a, b.get(lane), c, tmp, lane);
620 }
621
622 template <typename ResPacketType, typename AccPacketType>
623 EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const {
624 conj_helper<ResPacketType, ResPacketType, ConjLhs, false> cj;
625 r = cj.pmadd(c, alpha, r);
626 }
627
628 protected:
629};
630
631template <typename Packet>
632struct DoublePacket {
633 Packet first;
634 Packet second;
635};
636
637template <typename Packet>
638DoublePacket<Packet> padd(const DoublePacket<Packet>& a, const DoublePacket<Packet>& b) {
639 DoublePacket<Packet> res;
640 res.first = padd(a.first, b.first);
641 res.second = padd(a.second, b.second);
642 return res;
643}
644
645// note that for DoublePacket<RealPacket> the "4" in "downto4"
646// corresponds to the number of complex numbers, so it means "8"
647// in terms of real coefficients.
648
649template <typename Packet>
650const DoublePacket<Packet>& predux_half_dowto4(const DoublePacket<Packet>& a,
651 std::enable_if_t<unpacket_traits<Packet>::size <= 8>* = 0) {
652 return a;
653}
654
655template <typename Packet>
656DoublePacket<typename unpacket_traits<Packet>::half> predux_half_dowto4(
657 const DoublePacket<Packet>& a, std::enable_if_t<unpacket_traits<Packet>::size == 16>* = 0) {
658 // yes, that's pretty hackish :(
659 DoublePacket<typename unpacket_traits<Packet>::half> res;
660 typedef std::complex<typename unpacket_traits<Packet>::type> Cplx;
661 typedef typename packet_traits<Cplx>::type CplxPacket;
662 res.first = predux_half_dowto4(CplxPacket(a.first)).v;
663 res.second = predux_half_dowto4(CplxPacket(a.second)).v;
664 return res;
665}
666
667// same here, "quad" actually means "8" in terms of real coefficients
668template <typename Scalar, typename RealPacket>
669void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
670 std::enable_if_t<unpacket_traits<RealPacket>::size <= 8>* = 0) {
671 dest.first = pset1<RealPacket>(numext::real(*b));
672 dest.second = pset1<RealPacket>(numext::imag(*b));
673}
674
675template <typename Scalar, typename RealPacket>
676void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
677 std::enable_if_t<unpacket_traits<RealPacket>::size == 16>* = 0) {
678 // yes, that's pretty hackish too :(
679 typedef typename NumTraits<Scalar>::Real RealScalar;
680 RealScalar r[4] = {numext::real(b[0]), numext::real(b[0]), numext::real(b[1]), numext::real(b[1])};
681 RealScalar i[4] = {numext::imag(b[0]), numext::imag(b[0]), numext::imag(b[1]), numext::imag(b[1])};
682 dest.first = ploadquad<RealPacket>(r);
683 dest.second = ploadquad<RealPacket>(i);
684}
685
686template <typename Packet>
687struct unpacket_traits<DoublePacket<Packet> > {
688 typedef DoublePacket<typename unpacket_traits<Packet>::half> half;
689 enum { size = 2 * unpacket_traits<Packet>::size };
690};
691// template<typename Packet>
692// DoublePacket<Packet> pmadd(const DoublePacket<Packet> &a, const DoublePacket<Packet> &b)
693// {
694// DoublePacket<Packet> res;
695// res.first = padd(a.first, b.first);
696// res.second = padd(a.second,b.second);
697// return res;
698// }
699
700template <typename RealScalar, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
701class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, ConjLhs_, ConjRhs_, Arch, PacketSize_> {
702 public:
703 typedef std::complex<RealScalar> Scalar;
704 typedef std::complex<RealScalar> LhsScalar;
705 typedef std::complex<RealScalar> RhsScalar;
706 typedef std::complex<RealScalar> ResScalar;
707
708 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
709 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
710 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
711 PACKET_DECL_COND(Real, PacketSize_);
712 PACKET_DECL_COND_SCALAR(PacketSize_);
713
714 enum {
715 ConjLhs = ConjLhs_,
716 ConjRhs = ConjRhs_,
717 Vectorizable = unpacket_traits<RealPacket>::vectorizable && unpacket_traits<ScalarPacket>::vectorizable,
718 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
719 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
720 RhsPacketSize = Vectorizable ? unpacket_traits<RhsScalar>::size : 1,
721 RealPacketSize = Vectorizable ? unpacket_traits<RealPacket>::size : 1,
722 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
723
724 nr = 4,
725 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * ResPacketSize,
726
727 LhsProgress = ResPacketSize,
728 RhsProgress = 1
729 };
730
731 typedef DoublePacket<RealPacket> DoublePacketType;
732
733 typedef std::conditional_t<Vectorizable, ScalarPacket, Scalar> LhsPacket4Packing;
734 typedef std::conditional_t<Vectorizable, RealPacket, Scalar> LhsPacket;
735 typedef std::conditional_t<Vectorizable, DoublePacketType, Scalar> RhsPacket;
736 typedef std::conditional_t<Vectorizable, ScalarPacket, Scalar> ResPacket;
737 typedef std::conditional_t<Vectorizable, DoublePacketType, Scalar> AccPacket;
738
739 // this actually holds 8 packets!
740 typedef QuadPacket<RhsPacket> RhsPacketx4;
741
742 EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); }
743
744 EIGEN_STRONG_INLINE void initAcc(DoublePacketType& p) {
745 p.first = pset1<RealPacket>(RealScalar(0));
746 p.second = pset1<RealPacket>(RealScalar(0));
747 }
748
749 // Scalar path
750 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ScalarPacket& dest) const { dest = pset1<ScalarPacket>(*b); }
751
752 // Vectorized path
753 template <typename RealPacketType>
754 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const {
755 dest.first = pset1<RealPacketType>(numext::real(*b));
756 dest.second = pset1<RealPacketType>(numext::imag(*b));
757 }
758
759 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
760 loadRhs(b, dest.B_0);
761 loadRhs(b + 1, dest.B1);
762 loadRhs(b + 2, dest.B2);
763 loadRhs(b + 3, dest.B3);
764 }
765
766 // Scalar path
767 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, ScalarPacket& dest) const { loadRhs(b, dest); }
768
769 // Vectorized path
770 template <typename RealPacketType>
771 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, DoublePacket<RealPacketType>& dest) const {
772 loadRhs(b, dest);
773 }
774
775 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
776
777 EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const { loadRhs(b, dest); }
778 EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const {
779 loadQuadToDoublePacket(b, dest);
780 }
781
782 // nothing special here
783 EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const {
784 dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
785 }
786
787 template <typename LhsPacketType>
788 EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const {
789 dest = ploadu<LhsPacketType>((const typename unpacket_traits<LhsPacketType>::type*)(a));
790 }
791
792 template <typename LhsPacketType, typename RhsPacketType, typename ResPacketType, typename TmpType,
793 typename LaneIdType>
794 EIGEN_STRONG_INLINE std::enable_if_t<!is_same<RhsPacketType, RhsPacketx4>::value> madd(const LhsPacketType& a,
795 const RhsPacketType& b,
796 DoublePacket<ResPacketType>& c,
797 TmpType& /*tmp*/,
798 const LaneIdType&) const {
799 c.first = pmadd(a, b.first, c.first);
800 c.second = pmadd(a, b.second, c.second);
801 }
802
803 template <typename LaneIdType>
804 EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/,
805 const LaneIdType&) const {
806 c = cj.pmadd(a, b, c);
807 }
808
809 template <typename LhsPacketType, typename AccPacketType, typename LaneIdType>
810 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
811 const LaneIdType& lane) const {
812 madd(a, b.get(lane), c, tmp, lane);
813 }
814
815 EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; }
816
817 template <typename RealPacketType, typename ResPacketType>
818 EIGEN_STRONG_INLINE void acc(const DoublePacket<RealPacketType>& c, const ResPacketType& alpha,
819 ResPacketType& r) const {
820 // assemble c
821 ResPacketType tmp;
822 if ((!ConjLhs) && (!ConjRhs)) {
823 tmp = pcplxflip(pconj(ResPacketType(c.second)));
824 tmp = padd(ResPacketType(c.first), tmp);
825 } else if ((!ConjLhs) && (ConjRhs)) {
826 tmp = pconj(pcplxflip(ResPacketType(c.second)));
827 tmp = padd(ResPacketType(c.first), tmp);
828 } else if ((ConjLhs) && (!ConjRhs)) {
829 tmp = pcplxflip(ResPacketType(c.second));
830 tmp = padd(pconj(ResPacketType(c.first)), tmp);
831 } else if ((ConjLhs) && (ConjRhs)) {
832 tmp = pcplxflip(ResPacketType(c.second));
833 tmp = psub(pconj(ResPacketType(c.first)), tmp);
834 }
835
836 r = pmadd(tmp, alpha, r);
837 }
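
  // Derivation for the (!ConjLhs && !ConjRhs) branch above (a sketch): per complex lane, c.first
  // holds (a_r*b_r, a_i*b_r) and c.second holds (a_r*b_i, a_i*b_i). pconj negates the imaginary
  // slot and pcplxflip swaps (re,im), so pcplxflip(pconj(c.second)) = (-a_i*b_i, a_r*b_i); adding
  // c.first yields (a_r*b_r - a_i*b_i, a_i*b_r + a_r*b_i), i.e. the complex product a*b.
  // The other branches differ only in which halves get conjugated.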
838
839 protected:
840 conj_helper<LhsScalar, RhsScalar, ConjLhs, ConjRhs> cj;
841};
842
843template <typename RealScalar, bool ConjRhs_, int Arch, int PacketSize_>
844class gebp_traits<RealScalar, std::complex<RealScalar>, false, ConjRhs_, Arch, PacketSize_> {
845 public:
846 typedef std::complex<RealScalar> Scalar;
847 typedef RealScalar LhsScalar;
848 typedef Scalar RhsScalar;
849 typedef Scalar ResScalar;
850
851 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
852 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
853 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
854 PACKET_DECL_COND_POSTFIX(_, Real, PacketSize_);
855 PACKET_DECL_COND_SCALAR_POSTFIX(_, PacketSize_);
856
857#undef PACKET_DECL_COND_SCALAR_POSTFIX
858#undef PACKET_DECL_COND_POSTFIX
859#undef PACKET_DECL_COND_SCALAR
860#undef PACKET_DECL_COND
861
862 enum {
863 ConjLhs = false,
864 ConjRhs = ConjRhs_,
865 Vectorizable = unpacket_traits<RealPacket_>::vectorizable && unpacket_traits<ScalarPacket_>::vectorizable,
866 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
867 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
868 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
869
870 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
871 // FIXME: should depend on NumberOfRegisters
872 nr = 4,
873 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * ResPacketSize,
874
875 LhsProgress = ResPacketSize,
876 RhsProgress = 1
877 };
878
879 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
880 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
881 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
882 typedef LhsPacket LhsPacket4Packing;
883 typedef QuadPacket<RhsPacket> RhsPacketx4;
884 typedef ResPacket AccPacket;
885
886 EIGEN_STRONG_INLINE void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
887
888 template <typename RhsPacketType>
889 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const {
890 dest = pset1<RhsPacketType>(*b);
891 }
892
893 EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
894 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
895 }
896
897 template <typename RhsPacketType>
898 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketType& dest) const {
899 loadRhs(b, dest);
900 }
901
902 EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
903
904 EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const { dest = ploaddup<LhsPacket>(a); }
905
906 EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad<RhsPacket>(b); }
907
908 template <typename LhsPacketType>
909 EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const {
910 dest = ploaddup<LhsPacketType>(a);
911 }
912
913 template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
914 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
915 const LaneIdType&) const {
916 madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable, true_type, false_type>());
917 }
918
919 template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
920 EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c,
921 RhsPacketType& tmp, const true_type&) const {
922#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
923 EIGEN_UNUSED_VARIABLE(tmp);
924 c.v = pmadd(a, b.v, c.v);
925#else
926 tmp = b;
927 tmp.v = pmul(a, tmp.v);
928 c = padd(c, tmp);
929#endif
930 }
931
932 EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/,
933 const false_type&) const {
934 c += a * b;
935 }
936
937 template <typename LhsPacketType, typename AccPacketType, typename LaneIdType>
938 EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
939 const LaneIdType& lane) const {
940 madd(a, b.get(lane), c, tmp, lane);
941 }
942
943 template <typename ResPacketType, typename AccPacketType>
944 EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const {
945 conj_helper<ResPacketType, ResPacketType, false, ConjRhs> cj;
946 r = cj.pmadd(alpha, c, r);
947 }
948
949 protected:
950};
951
952/* optimized General packed Block * packed Panel product kernel
953 *
954 * Mixing type logic: C += A * B
955 * | A | B | comments
956 * |real |cplx | no vectorization yet, would require packing A with duplication
957 * |cplx |real | easy vectorization
958 */
959template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
960 bool ConjugateLhs, bool ConjugateRhs>
961struct gebp_kernel {
962 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
963 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target, GEBPPacketHalf>
964 HalfTraits;
965 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target, GEBPPacketQuarter>
966 QuarterTraits;
967
968 typedef typename Traits::ResScalar ResScalar;
969 typedef typename Traits::LhsPacket LhsPacket;
970 typedef typename Traits::RhsPacket RhsPacket;
971 typedef typename Traits::ResPacket ResPacket;
972 typedef typename Traits::AccPacket AccPacket;
973 typedef typename Traits::RhsPacketx4 RhsPacketx4;
974
975 typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
976 typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 27>::type RhsPanel27;
977
978 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
979
980 typedef typename SwappedTraits::ResScalar SResScalar;
981 typedef typename SwappedTraits::LhsPacket SLhsPacket;
982 typedef typename SwappedTraits::RhsPacket SRhsPacket;
983 typedef typename SwappedTraits::ResPacket SResPacket;
984 typedef typename SwappedTraits::AccPacket SAccPacket;
985
986 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
987 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
988 typedef typename HalfTraits::ResPacket ResPacketHalf;
989 typedef typename HalfTraits::AccPacket AccPacketHalf;
990
991 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
992 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
993 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
994 typedef typename QuarterTraits::AccPacket AccPacketQuarter;
995
996 typedef typename DataMapper::LinearMapper LinearMapper;
997
998 enum {
999 Vectorizable = Traits::Vectorizable,
1000 LhsProgress = Traits::LhsProgress,
1001 LhsProgressHalf = HalfTraits::LhsProgress,
1002 LhsProgressQuarter = QuarterTraits::LhsProgress,
1003 RhsProgress = Traits::RhsProgress,
1004 RhsProgressHalf = HalfTraits::RhsProgress,
1005 RhsProgressQuarter = QuarterTraits::RhsProgress,
1006 ResPacketSize = Traits::ResPacketSize
1007 };
1008
1009 EIGEN_DONT_INLINE void operator()(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, Index rows,
1010 Index depth, Index cols, ResScalar alpha, Index strideA = -1, Index strideB = -1,
1011 Index offsetA = 0, Index offsetB = 0);
1012};
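
// A minimal invocation sketch (illustrative only; res, blockA, blockB and the size/stride variables
// are assumptions standing for a result mapper and packed operand buffers produced by the
// corresponding packing routines):
//   gebp_kernel<double, double, Index, DataMapper, Traits::mr, Traits::nr, false, false> gebp;
//   gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
// strideA/strideB default to -1, which operator() interprets as "stride equals depth".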
1013
1014template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
1015 bool ConjugateLhs, bool ConjugateRhs,
1016 int SwappedLhsProgress =
1017 gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target>::LhsProgress>
1018struct last_row_process_16_packets {
1019 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
1020 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
1021
1022 typedef typename Traits::ResScalar ResScalar;
1023 typedef typename SwappedTraits::LhsPacket SLhsPacket;
1024 typedef typename SwappedTraits::RhsPacket SRhsPacket;
1025 typedef typename SwappedTraits::ResPacket SResPacket;
1026 typedef typename SwappedTraits::AccPacket SAccPacket;
1027
1028 EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits& straits, const LhsScalar* blA,
1029 const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
1030 ResScalar alpha, SAccPacket& C0) {
1031 EIGEN_UNUSED_VARIABLE(res);
1032 EIGEN_UNUSED_VARIABLE(straits);
1033 EIGEN_UNUSED_VARIABLE(blA);
1034 EIGEN_UNUSED_VARIABLE(blB);
1035 EIGEN_UNUSED_VARIABLE(depth);
1036 EIGEN_UNUSED_VARIABLE(endk);
1037 EIGEN_UNUSED_VARIABLE(i);
1038 EIGEN_UNUSED_VARIABLE(j2);
1039 EIGEN_UNUSED_VARIABLE(alpha);
1040 EIGEN_UNUSED_VARIABLE(C0);
1041 }
1042};
1043
1044template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
1045 bool ConjugateLhs, bool ConjugateRhs>
1046struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
1047 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
1048 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
1049
1050 typedef typename Traits::ResScalar ResScalar;
1051 typedef typename SwappedTraits::LhsPacket SLhsPacket;
1052 typedef typename SwappedTraits::RhsPacket SRhsPacket;
1053 typedef typename SwappedTraits::ResPacket SResPacket;
1054 typedef typename SwappedTraits::AccPacket SAccPacket;
1055
1056 EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits& straits, const LhsScalar* blA,
1057 const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
1058 ResScalar alpha, SAccPacket& C0) {
1059 typedef typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half SResPacketQuarter;
1060 typedef typename unpacket_traits<typename unpacket_traits<SLhsPacket>::half>::half SLhsPacketQuarter;
1061 typedef typename unpacket_traits<typename unpacket_traits<SRhsPacket>::half>::half SRhsPacketQuarter;
1062 typedef typename unpacket_traits<typename unpacket_traits<SAccPacket>::half>::half SAccPacketQuarter;
1063
1064 SResPacketQuarter R = res.template gatherPacket<SResPacketQuarter>(i, j2);
1065 SResPacketQuarter alphav = pset1<SResPacketQuarter>(alpha);
1066
1067 if (depth - endk > 0) {
1068 // We have to handle the last row(s) of the rhs, which
1069 // correspond to a half-packet
1070 SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0));
1071
1072 for (Index kk = endk; kk < depth; kk++) {
1073 SLhsPacketQuarter a0;
1074 SRhsPacketQuarter b0;
1075 straits.loadLhsUnaligned(blB, a0);
1076 straits.loadRhs(blA, b0);
1077 straits.madd(a0, b0, c0, b0, fix<0>);
1078 blB += SwappedTraits::LhsProgress / 4;
1079 blA += 1;
1080 }
1081 straits.acc(c0, alphav, R);
1082 } else {
1083 straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R);
1084 }
1085 res.scatterPacket(i, j2, R);
1086 }
1087};
1088
1089template <int nr, Index LhsProgress, Index RhsProgress, typename LhsScalar, typename RhsScalar, typename ResScalar,
1090 typename AccPacket, typename LhsPacket, typename RhsPacket, typename ResPacket, typename GEBPTraits,
1091 typename LinearMapper, typename DataMapper>
1092struct lhs_process_one_packet {
1093 typedef typename GEBPTraits::RhsPacketx4 RhsPacketx4;
1094
1095 EIGEN_STRONG_INLINE void peeled_kc_onestep(Index K, const LhsScalar* blA, const RhsScalar* blB, GEBPTraits traits,
1096 LhsPacket* A0, RhsPacketx4* rhs_panel, RhsPacket* T0, AccPacket* C0,
1097 AccPacket* C1, AccPacket* C2, AccPacket* C3) {
1098 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1X4");
1099 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!");
1100 traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], *A0);
1101 traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], *rhs_panel);
1102 traits.madd(*A0, *rhs_panel, *C0, *T0, fix<0>);
1103 traits.madd(*A0, *rhs_panel, *C1, *T0, fix<1>);
1104 traits.madd(*A0, *rhs_panel, *C2, *T0, fix<2>);
1105 traits.madd(*A0, *rhs_panel, *C3, *T0, fix<3>);
1106#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE) && !(EIGEN_COMP_LCC)
1107 __asm__("" : "+x,m"(*A0));
1108#endif
1109 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1X4");
1110 }
1111
1112 EIGEN_STRONG_INLINE void operator()(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
1113 ResScalar alpha, Index peelStart, Index peelEnd, Index strideA, Index strideB,
1114 Index offsetA, Index offsetB, int prefetch_res_offset, Index peeled_kc, Index pk,
1115 Index cols, Index depth, Index packet_cols4) {
1116 GEBPTraits traits;
1117 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1118 // loops on each largest micro horizontal panel of lhs
1119 // (LhsProgress x depth)
1120 for (Index i = peelStart; i < peelEnd; i += LhsProgress) {
1121#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1122 EIGEN_IF_CONSTEXPR(nr >= 8) {
1123 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1124 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1125 prefetch(&blA[0]);
1126
1127 // gets res block as register
1128 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
1129 traits.initAcc(C0);
1130 traits.initAcc(C1);
1131 traits.initAcc(C2);
1132 traits.initAcc(C3);
1133 traits.initAcc(C4);
1134 traits.initAcc(C5);
1135 traits.initAcc(C6);
1136 traits.initAcc(C7);
1137
1138 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1139 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1140 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1141 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1142 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1143 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1144 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1145 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1146 r0.prefetch(prefetch_res_offset);
1147 r1.prefetch(prefetch_res_offset);
1148 r2.prefetch(prefetch_res_offset);
1149 r3.prefetch(prefetch_res_offset);
1150 r4.prefetch(prefetch_res_offset);
1151 r5.prefetch(prefetch_res_offset);
1152 r6.prefetch(prefetch_res_offset);
1153 r7.prefetch(prefetch_res_offset);
1154 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1155 prefetch(&blB[0]);
1156
1157 LhsPacket A0;
1158 for (Index k = 0; k < peeled_kc; k += pk) {
1159 RhsPacketx4 rhs_panel;
1160 RhsPacket T0;
1161#define EIGEN_GEBGP_ONESTEP(K) \
1162 do { \
1163 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX8"); \
1164 traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], A0); \
1165 traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
1166 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1167 traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
1168 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1169 traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
1170 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1171 traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
1172 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1173 traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
1174 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
1175 traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
1176 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
1177 traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
1178 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
1179 traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
1180 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
1181 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \
1182 } while (false)
1183
1184 EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX8");
1185
1186 EIGEN_GEBGP_ONESTEP(0);
1187 EIGEN_GEBGP_ONESTEP(1);
1188 EIGEN_GEBGP_ONESTEP(2);
1189 EIGEN_GEBGP_ONESTEP(3);
1190 EIGEN_GEBGP_ONESTEP(4);
1191 EIGEN_GEBGP_ONESTEP(5);
1192 EIGEN_GEBGP_ONESTEP(6);
1193 EIGEN_GEBGP_ONESTEP(7);
1194
1195 blB += pk * 8 * RhsProgress;
1196 blA += pk * (1 * LhsProgress);
1197
1198 EIGEN_ASM_COMMENT("end gebp micro kernel 1pX8");
1199 }
1200 // process remaining peeled loop
1201 for (Index k = peeled_kc; k < depth; k++) {
1202 RhsPacketx4 rhs_panel;
1203 RhsPacket T0;
1204 EIGEN_GEBGP_ONESTEP(0);
1205 blB += 8 * RhsProgress;
1206 blA += 1 * LhsProgress;
1207 }
1208
1209#undef EIGEN_GEBGP_ONESTEP
1210
1211 ResPacket R0, R1;
1212 ResPacket alphav = pset1<ResPacket>(alpha);
1213
1214 R0 = r0.template loadPacket<ResPacket>(0);
1215 R1 = r1.template loadPacket<ResPacket>(0);
1216 traits.acc(C0, alphav, R0);
1217 traits.acc(C1, alphav, R1);
1218 r0.storePacket(0, R0);
1219 r1.storePacket(0, R1);
1220
1221 R0 = r2.template loadPacket<ResPacket>(0);
1222 R1 = r3.template loadPacket<ResPacket>(0);
1223 traits.acc(C2, alphav, R0);
1224 traits.acc(C3, alphav, R1);
1225 r2.storePacket(0, R0);
1226 r3.storePacket(0, R1);
1227
1228 R0 = r4.template loadPacket<ResPacket>(0);
1229 R1 = r5.template loadPacket<ResPacket>(0);
1230 traits.acc(C4, alphav, R0);
1231 traits.acc(C5, alphav, R1);
1232 r4.storePacket(0, R0);
1233 r5.storePacket(0, R1);
1234
1235 R0 = r6.template loadPacket<ResPacket>(0);
1236 R1 = r7.template loadPacket<ResPacket>(0);
1237 traits.acc(C6, alphav, R0);
1238 traits.acc(C7, alphav, R1);
1239 r6.storePacket(0, R0);
1240 r7.storePacket(0, R1);
1241 }
1242 }
1243#endif
1244
1245 // loops on each largest micro vertical panel of rhs (depth * nr)
1246 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1247 // We select a LhsProgress x nr micro block of res
1248 // which is entirely stored into 1 x nr registers.
1249
1250 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1251 prefetch(&blA[0]);
1252
1253 // gets res block as register
1254 AccPacket C0, C1, C2, C3;
1255 traits.initAcc(C0);
1256 traits.initAcc(C1);
1257 traits.initAcc(C2);
1258 traits.initAcc(C3);
1259 // To improve instruction pipelining, let's double the accumulation registers:
1260 // even k will accumulate in C*, while odd k will accumulate in D*.
1261 // This trick is crucial to get good performance with FMA; otherwise it is
1262 // actually faster to perform separate MUL+ADD operations because of the naturally
1263 // better instruction-level parallelism.
1264 AccPacket D0, D1, D2, D3;
1265 traits.initAcc(D0);
1266 traits.initAcc(D1);
1267 traits.initAcc(D2);
1268 traits.initAcc(D3);
1269
1270 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1271 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1272 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1273 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1274
1275 r0.prefetch(prefetch_res_offset);
1276 r1.prefetch(prefetch_res_offset);
1277 r2.prefetch(prefetch_res_offset);
1278 r3.prefetch(prefetch_res_offset);
1279
1280 // performs "inner" products
1281 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
1282 prefetch(&blB[0]);
1283 LhsPacket A0, A1;
1284
1285 for (Index k = 0; k < peeled_kc; k += pk) {
1286 EIGEN_ASM_COMMENT("begin gebp micro kernel 1/half/quarterX4");
1287 RhsPacketx4 rhs_panel;
1288 RhsPacket T0;
1289
1290 internal::prefetch(blB + (48 + 0));
1291 peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1292 peeled_kc_onestep(1, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1293 peeled_kc_onestep(2, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1294 peeled_kc_onestep(3, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1295 internal::prefetch(blB + (48 + 16));
1296 peeled_kc_onestep(4, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1297 peeled_kc_onestep(5, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1298 peeled_kc_onestep(6, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1299 peeled_kc_onestep(7, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1300
1301 blB += pk * 4 * RhsProgress;
1302 blA += pk * LhsProgress;
1303
1304 EIGEN_ASM_COMMENT("end gebp micro kernel 1/half/quarterX4");
1305 }
1306 C0 = padd(C0, D0);
1307 C1 = padd(C1, D1);
1308 C2 = padd(C2, D2);
1309 C3 = padd(C3, D3);
1310
1311 // process remaining peeled loop
1312 for (Index k = peeled_kc; k < depth; k++) {
1313 RhsPacketx4 rhs_panel;
1314 RhsPacket T0;
1315 peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1316 blB += 4 * RhsProgress;
1317 blA += LhsProgress;
1318 }
1319
1320 ResPacket R0, R1;
1321 ResPacket alphav = pset1<ResPacket>(alpha);
1322
1323 R0 = r0.template loadPacket<ResPacket>(0);
1324 R1 = r1.template loadPacket<ResPacket>(0);
1325 traits.acc(C0, alphav, R0);
1326 traits.acc(C1, alphav, R1);
1327 r0.storePacket(0, R0);
1328 r1.storePacket(0, R1);
1329
1330 R0 = r2.template loadPacket<ResPacket>(0);
1331 R1 = r3.template loadPacket<ResPacket>(0);
1332 traits.acc(C2, alphav, R0);
1333 traits.acc(C3, alphav, R1);
1334 r2.storePacket(0, R0);
1335 r3.storePacket(0, R1);
1336 }
1337
1338 // Deal with remaining columns of the rhs
1339 for (Index j2 = packet_cols4; j2 < cols; j2++) {
1340 // One column at a time
1341 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1342 prefetch(&blA[0]);
1343
1344 // gets res block as register
1345 AccPacket C0;
1346 traits.initAcc(C0);
1347
1348 LinearMapper r0 = res.getLinearMapper(i, j2);
1349
1350 // performs "inner" products
1351 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
1352 LhsPacket A0;
1353
1354 for (Index k = 0; k < peeled_kc; k += pk) {
1355 EIGEN_ASM_COMMENT("begin gebp micro kernel 1/half/quarterX1");
1356 RhsPacket B_0;
1357
1358#define EIGEN_GEBGP_ONESTEP(K) \
1359 do { \
1360 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1/half/quarterX1"); \
1361 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1362 /* FIXME: why unaligned???? */ \
1363 traits.loadLhsUnaligned(&blA[(0 + 1 * K) * LhsProgress], A0); \
1364 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
1365 traits.madd(A0, B_0, C0, B_0, fix<0>); \
1366 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1/half/quarterX1"); \
1367 } while (false);
1368
1369 EIGEN_GEBGP_ONESTEP(0);
1370 EIGEN_GEBGP_ONESTEP(1);
1371 EIGEN_GEBGP_ONESTEP(2);
1372 EIGEN_GEBGP_ONESTEP(3);
1373 EIGEN_GEBGP_ONESTEP(4);
1374 EIGEN_GEBGP_ONESTEP(5);
1375 EIGEN_GEBGP_ONESTEP(6);
1376 EIGEN_GEBGP_ONESTEP(7);
1377
1378 blB += pk * RhsProgress;
1379 blA += pk * LhsProgress;
1380
1381 EIGEN_ASM_COMMENT("end gebp micro kernel 1/half/quarterX1");
1382 }
1383
1384 // process remaining peeled loop
1385 for (Index k = peeled_kc; k < depth; k++) {
1386 RhsPacket B_0;
1387 EIGEN_GEBGP_ONESTEP(0);
1388 blB += RhsProgress;
1389 blA += LhsProgress;
1390 }
1391#undef EIGEN_GEBGP_ONESTEP
1392 ResPacket R0;
1393 ResPacket alphav = pset1<ResPacket>(alpha);
1394 R0 = r0.template loadPacket<ResPacket>(0);
1395 traits.acc(C0, alphav, R0);
1396 r0.storePacket(0, R0);
1397 }
1398 }
1399 }
1400};
1401
1402template <int nr, Index LhsProgress, Index RhsProgress, typename LhsScalar, typename RhsScalar, typename ResScalar,
1403 typename AccPacket, typename LhsPacket, typename RhsPacket, typename ResPacket, typename GEBPTraits,
1404 typename LinearMapper, typename DataMapper>
1405struct lhs_process_fraction_of_packet
1406 : lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket,
1407 RhsPacket, ResPacket, GEBPTraits, LinearMapper, DataMapper> {
1408 EIGEN_STRONG_INLINE void peeled_kc_onestep(Index K, const LhsScalar* blA, const RhsScalar* blB, GEBPTraits traits,
1409 LhsPacket* A0, RhsPacket* B_0, RhsPacket* B1, RhsPacket* B2, RhsPacket* B3,
1410 AccPacket* C0, AccPacket* C1, AccPacket* C2, AccPacket* C3) {
1411 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1X4");
1412 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!");
1413 traits.loadLhsUnaligned(&blA[(0 + 1 * K) * (LhsProgress)], *A0);
1414 traits.broadcastRhs(&blB[(0 + 4 * K) * RhsProgress], *B_0, *B1, *B2, *B3);
1415 traits.madd(*A0, *B_0, *C0, *B_0);
1416 traits.madd(*A0, *B1, *C1, *B1);
1417 traits.madd(*A0, *B2, *C2, *B2);
1418 traits.madd(*A0, *B3, *C3, *B3);
1419 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1X4");
1420 }
1421};
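// Variant of the one-packet kernel for half- and quarter-sized packets: its per-step helper loads the
// lhs unaligned and broadcasts four rhs coefficients, so the 1-packet x 4-columns update can be
// expressed with the narrower packet types.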
1422
1423template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
1424 bool ConjugateLhs, bool ConjugateRhs>
1425EIGEN_DONT_INLINE void gebp_kernel<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs,
1426 ConjugateRhs>::operator()(const DataMapper& res, const LhsScalar* blockA,
1427 const RhsScalar* blockB, Index rows, Index depth,
1428 Index cols, ResScalar alpha, Index strideA, Index strideB,
1429 Index offsetA, Index offsetB) {
1430 Traits traits;
1431 SwappedTraits straits;
1432
1433 if (strideA == -1) strideA = depth;
1434 if (strideB == -1) strideB = depth;
1435 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
1436 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
1437 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1438 const Index peeled_mc3 = mr >= 3 * Traits::LhsProgress ? (rows / (3 * LhsProgress)) * (3 * LhsProgress) : 0;
1439 const Index peeled_mc2 =
1440 mr >= 2 * Traits::LhsProgress ? peeled_mc3 + ((rows - peeled_mc3) / (2 * LhsProgress)) * (2 * LhsProgress) : 0;
1441 const Index peeled_mc1 =
1442 mr >= 1 * Traits::LhsProgress ? peeled_mc2 + ((rows - peeled_mc2) / (1 * LhsProgress)) * (1 * LhsProgress) : 0;
1443 const Index peeled_mc_half =
1444 mr >= LhsProgressHalf ? peeled_mc1 + ((rows - peeled_mc1) / (LhsProgressHalf)) * (LhsProgressHalf) : 0;
1445 const Index peeled_mc_quarter =
1446 mr >= LhsProgressQuarter
1447 ? peeled_mc_half + ((rows - peeled_mc_half) / (LhsProgressQuarter)) * (LhsProgressQuarter)
1448 : 0;
1449 enum { pk = 8 }; // NOTE Such a large peeling factor is important for large matrices (~ +5% when >1000 on Haswell)
1450 const Index peeled_kc = depth & ~(pk - 1);
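  // peeled_mc3/2/1, peeled_mc_half and peeled_mc_quarter partition the rows into ranges handled by
  // progressively narrower kernels (3, 2 and 1 full packets, then half and quarter packets), while
  // peeled_kc rounds the depth down to a multiple of pk so the k loop can be unrolled by 8; e.g., for a
  // hypothetical depth of 100, peeled_kc is 96 and the last 4 iterations go through the epilogue loops.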
1451 const int prefetch_res_offset = 32 / sizeof(ResScalar);
1452 // const Index depth2 = depth & ~1;
1453
1454 //---------- Process 3 * LhsProgress rows at once ----------
1455 // This corresponds to 3*LhsProgress x nr register blocks.
1456 // Usually, this makes sense only with FMA.
1457 if (mr >= 3 * Traits::LhsProgress) {
1458 // Here, the general idea is to loop on each largest micro horizontal panel of the lhs (3*Traits::LhsProgress x
1459 // depth) and on each largest micro vertical panel of the rhs (depth * nr). The blocking size, i.e., 'depth', has been
1460 // computed so that the micro horizontal panel of the lhs fits in L1. However, if depth is too small, we can extend
1461 // the number of rows of these horizontal panels. This actual number of rows is computed as follows:
1462 const Index l1 = defaultL1CacheSize; // in Bytes, TODO, l1 should be passed to this function.
1463 // The max(1, ...) here is needed because we may be using blocking params larger than what our known l1 cache size
1464 // suggests we should be using: either because our known l1 cache size is inaccurate (e.g. on Android, we can only
1465 // guess), or because we are testing specific blocking sizes.
1466 const Index actual_panel_rows =
1467 (3 * LhsProgress) * std::max<Index>(1, ((l1 - sizeof(ResScalar) * mr * nr - depth * nr * sizeof(RhsScalar)) /
1468 (depth * sizeof(LhsScalar) * 3 * LhsProgress)));
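    // For instance (hypothetical numbers), with float data, an l1 of 32 KB, depth = 256, nr = 4,
    // mr = 24 and LhsProgress = 8: the rhs panel takes 256*4*4 = 4096 bytes and the res block
    // 24*4*4 = 384 bytes, leaving room for a single 24-row lhs panel (24*256*4 = 24576 bytes),
    // so actual_panel_rows evaluates to 24.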
1469 for (Index i1 = 0; i1 < peeled_mc3; i1 += actual_panel_rows) {
1470 const Index actual_panel_end = (std::min)(i1 + actual_panel_rows, peeled_mc3);
1471#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1472 EIGEN_IF_CONSTEXPR(nr >= 8) {
1473 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1474 for (Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1475 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * LhsProgress)];
1476 prefetch(&blA[0]);
1477 // gets res block as register
1478 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11, C12, C13, C14, C15, C16, C17, C18, C19, C20,
1479 C21, C22, C23;
1480 traits.initAcc(C0);
1481 traits.initAcc(C1);
1482 traits.initAcc(C2);
1483 traits.initAcc(C3);
1484 traits.initAcc(C4);
1485 traits.initAcc(C5);
1486 traits.initAcc(C6);
1487 traits.initAcc(C7);
1488 traits.initAcc(C8);
1489 traits.initAcc(C9);
1490 traits.initAcc(C10);
1491 traits.initAcc(C11);
1492 traits.initAcc(C12);
1493 traits.initAcc(C13);
1494 traits.initAcc(C14);
1495 traits.initAcc(C15);
1496 traits.initAcc(C16);
1497 traits.initAcc(C17);
1498 traits.initAcc(C18);
1499 traits.initAcc(C19);
1500 traits.initAcc(C20);
1501 traits.initAcc(C21);
1502 traits.initAcc(C22);
1503 traits.initAcc(C23);
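          // The 24 accumulators cover a (3*LhsProgress) x 8 block of res: for column j (0..7),
          // C{j}, C{8+j} and C{16+j} hold the three vertical packet slices of that column.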
1504
1505 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1506 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1507 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1508 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1509 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1510 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1511 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1512 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1513
1514 r0.prefetch(0);
1515 r1.prefetch(0);
1516 r2.prefetch(0);
1517 r3.prefetch(0);
1518 r4.prefetch(0);
1519 r5.prefetch(0);
1520 r6.prefetch(0);
1521 r7.prefetch(0);
1522
1523 // performs "inner" products
1524 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1525 prefetch(&blB[0]);
1526 LhsPacket A0, A1;
1527 for (Index k = 0; k < peeled_kc; k += pk) {
1528 EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX8");
1529 // 27 registers are taken (24 for acc, 3 for lhs).
1530 RhsPanel27 rhs_panel;
1531 RhsPacket T0;
1532 LhsPacket A2;
1533#if EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0)
1534// see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
1535// without this workaround A0, A1, and A2 are loaded in the same register,
1536// which is not good for pipelining
1537#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND __asm__("" : "+w,m"(A0), "+w,m"(A1), "+w,m"(A2));
1538#else
1539#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND
1540#endif
1541
1542#define EIGEN_GEBP_ONESTEP(K) \
1543 do { \
1544 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX8"); \
1545 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1546 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1547 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1548 EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND traits.loadRhs(blB + (0 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1549 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1550 traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
1551 traits.madd(A2, rhs_panel, C16, T0, fix<0>); \
1552 traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1553 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1554 traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
1555 traits.madd(A2, rhs_panel, C17, T0, fix<1>); \
1556 traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1557 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1558 traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
1559 traits.madd(A2, rhs_panel, C18, T0, fix<2>); \
1560 traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1561 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1562 traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
1563 traits.madd(A2, rhs_panel, C19, T0, fix<3>); \
1564 traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1565 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
1566 traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
1567 traits.madd(A2, rhs_panel, C20, T0, fix<0>); \
1568 traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1569 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
1570 traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
1571 traits.madd(A2, rhs_panel, C21, T0, fix<1>); \
1572 traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1573 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
1574 traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
1575 traits.madd(A2, rhs_panel, C22, T0, fix<2>); \
1576 traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1577 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
1578 traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
1579 traits.madd(A2, rhs_panel, C23, T0, fix<3>); \
1580 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \
1581 } while (false)
1582
1583 EIGEN_GEBP_ONESTEP(0);
1584 EIGEN_GEBP_ONESTEP(1);
1585 EIGEN_GEBP_ONESTEP(2);
1586 EIGEN_GEBP_ONESTEP(3);
1587 EIGEN_GEBP_ONESTEP(4);
1588 EIGEN_GEBP_ONESTEP(5);
1589 EIGEN_GEBP_ONESTEP(6);
1590 EIGEN_GEBP_ONESTEP(7);
1591
1592 blB += pk * 8 * RhsProgress;
1593 blA += pk * 3 * Traits::LhsProgress;
1594 EIGEN_ASM_COMMENT("end gebp micro kernel 3pX8");
1595 }
1596
1597 // process remaining peeled loop
1598 for (Index k = peeled_kc; k < depth; k++) {
1599 RhsPanel27 rhs_panel;
1600 RhsPacket T0;
1601 LhsPacket A2;
1602 EIGEN_GEBP_ONESTEP(0);
1603 blB += 8 * RhsProgress;
1604 blA += 3 * Traits::LhsProgress;
1605 }
1606
1607#undef EIGEN_GEBP_ONESTEP
1608
1609 ResPacket R0, R1, R2;
1610 ResPacket alphav = pset1<ResPacket>(alpha);
1611
1612 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1613 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1614 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1615 traits.acc(C0, alphav, R0);
1616 traits.acc(C8, alphav, R1);
1617 traits.acc(C16, alphav, R2);
1618 r0.storePacket(0 * Traits::ResPacketSize, R0);
1619 r0.storePacket(1 * Traits::ResPacketSize, R1);
1620 r0.storePacket(2 * Traits::ResPacketSize, R2);
1621
1622 R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1623 R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1624 R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1625 traits.acc(C1, alphav, R0);
1626 traits.acc(C9, alphav, R1);
1627 traits.acc(C17, alphav, R2);
1628 r1.storePacket(0 * Traits::ResPacketSize, R0);
1629 r1.storePacket(1 * Traits::ResPacketSize, R1);
1630 r1.storePacket(2 * Traits::ResPacketSize, R2);
1631
1632 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1633 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1634 R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1635 traits.acc(C2, alphav, R0);
1636 traits.acc(C10, alphav, R1);
1637 traits.acc(C18, alphav, R2);
1638 r2.storePacket(0 * Traits::ResPacketSize, R0);
1639 r2.storePacket(1 * Traits::ResPacketSize, R1);
1640 r2.storePacket(2 * Traits::ResPacketSize, R2);
1641
1642 R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1643 R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1644 R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1645 traits.acc(C3, alphav, R0);
1646 traits.acc(C11, alphav, R1);
1647 traits.acc(C19, alphav, R2);
1648 r3.storePacket(0 * Traits::ResPacketSize, R0);
1649 r3.storePacket(1 * Traits::ResPacketSize, R1);
1650 r3.storePacket(2 * Traits::ResPacketSize, R2);
1651
1652 R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1653 R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1654 R2 = r4.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1655 traits.acc(C4, alphav, R0);
1656 traits.acc(C12, alphav, R1);
1657 traits.acc(C20, alphav, R2);
1658 r4.storePacket(0 * Traits::ResPacketSize, R0);
1659 r4.storePacket(1 * Traits::ResPacketSize, R1);
1660 r4.storePacket(2 * Traits::ResPacketSize, R2);
1661
1662 R0 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1663 R1 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1664 R2 = r5.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1665 traits.acc(C5, alphav, R0);
1666 traits.acc(C13, alphav, R1);
1667 traits.acc(C21, alphav, R2);
1668 r5.storePacket(0 * Traits::ResPacketSize, R0);
1669 r5.storePacket(1 * Traits::ResPacketSize, R1);
1670 r5.storePacket(2 * Traits::ResPacketSize, R2);
1671
1672 R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1673 R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1674 R2 = r6.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1675 traits.acc(C6, alphav, R0);
1676 traits.acc(C14, alphav, R1);
1677 traits.acc(C22, alphav, R2);
1678 r6.storePacket(0 * Traits::ResPacketSize, R0);
1679 r6.storePacket(1 * Traits::ResPacketSize, R1);
1680 r6.storePacket(2 * Traits::ResPacketSize, R2);
1681
1682 R0 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1683 R1 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1684 R2 = r7.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1685 traits.acc(C7, alphav, R0);
1686 traits.acc(C15, alphav, R1);
1687 traits.acc(C23, alphav, R2);
1688 r7.storePacket(0 * Traits::ResPacketSize, R0);
1689 r7.storePacket(1 * Traits::ResPacketSize, R1);
1690 r7.storePacket(2 * Traits::ResPacketSize, R2);
1691 }
1692 }
1693 }
1694#endif
1695 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1696 for (Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1697 // We selected a 3*Traits::LhsProgress x nr micro block of res which is entirely
1698 // stored into 3 x nr registers.
1699
1700 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * LhsProgress)];
1701 prefetch(&blA[0]);
1702
1703 // gets res block as register
1704 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11;
1705 traits.initAcc(C0);
1706 traits.initAcc(C1);
1707 traits.initAcc(C2);
1708 traits.initAcc(C3);
1709 traits.initAcc(C4);
1710 traits.initAcc(C5);
1711 traits.initAcc(C6);
1712 traits.initAcc(C7);
1713 traits.initAcc(C8);
1714 traits.initAcc(C9);
1715 traits.initAcc(C10);
1716 traits.initAcc(C11);
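          // The 12 accumulators cover a (3*LhsProgress) x 4 block of res: for column j (0..3),
          // C{j}, C{4+j} and C{8+j} hold the three vertical packet slices of that column.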
1717
1718 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1719 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1720 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1721 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1722
1723 r0.prefetch(0);
1724 r1.prefetch(0);
1725 r2.prefetch(0);
1726 r3.prefetch(0);
1727
1728 // performs "inner" products
1729 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
1730 prefetch(&blB[0]);
1731 LhsPacket A0, A1;
1732
1733 for (Index k = 0; k < peeled_kc; k += pk) {
1734 EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4");
1735 // 15 registers are taken (12 for acc, 3 for lhs).
1736 RhsPanel15 rhs_panel;
1737 RhsPacket T0;
1738 LhsPacket A2;
1739#if EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0)
1740// see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
1741// without this workaround A0, A1, and A2 are loaded in the same register,
1742// which is not good for pipelining
1743#define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND __asm__("" : "+w,m"(A0), "+w,m"(A1), "+w,m"(A2));
1744#else
1745#define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND
1746#endif
1747#define EIGEN_GEBP_ONESTEP(K) \
1748 do { \
1749 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \
1750 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1751 internal::prefetch(blA + (3 * K + 16) * LhsProgress); \
1752 if (EIGEN_ARCH_ARM || EIGEN_ARCH_MIPS) { \
1753 internal::prefetch(blB + (4 * K + 16) * RhsProgress); \
1754 } /* Bug 953 */ \
1755 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1756 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1757 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1758 EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND \
1759 traits.loadRhs(blB + (0 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1760 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1761 traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
1762 traits.madd(A2, rhs_panel, C8, T0, fix<0>); \
1763 traits.updateRhs(blB + (1 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1764 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1765 traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
1766 traits.madd(A2, rhs_panel, C9, T0, fix<1>); \
1767 traits.updateRhs(blB + (2 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1768 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1769 traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
1770 traits.madd(A2, rhs_panel, C10, T0, fix<2>); \
1771 traits.updateRhs(blB + (3 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1772 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1773 traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
1774 traits.madd(A2, rhs_panel, C11, T0, fix<3>); \
1775 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \
1776 } while (false)
1777
1778 internal::prefetch(blB);
1779 EIGEN_GEBP_ONESTEP(0);
1780 EIGEN_GEBP_ONESTEP(1);
1781 EIGEN_GEBP_ONESTEP(2);
1782 EIGEN_GEBP_ONESTEP(3);
1783 EIGEN_GEBP_ONESTEP(4);
1784 EIGEN_GEBP_ONESTEP(5);
1785 EIGEN_GEBP_ONESTEP(6);
1786 EIGEN_GEBP_ONESTEP(7);
1787
1788 blB += pk * 4 * RhsProgress;
1789 blA += pk * 3 * Traits::LhsProgress;
1790
1791 EIGEN_ASM_COMMENT("end gebp micro kernel 3pX4");
1792 }
1793 // process remaining peeled loop
1794 for (Index k = peeled_kc; k < depth; k++) {
1795 RhsPanel15 rhs_panel;
1796 RhsPacket T0;
1797 LhsPacket A2;
1798 EIGEN_GEBP_ONESTEP(0);
1799 blB += 4 * RhsProgress;
1800 blA += 3 * Traits::LhsProgress;
1801 }
1802
1803#undef EIGEN_GEBP_ONESTEP
1804
1805 ResPacket R0, R1, R2;
1806 ResPacket alphav = pset1<ResPacket>(alpha);
1807
1808 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1809 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1810 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1811 traits.acc(C0, alphav, R0);
1812 traits.acc(C4, alphav, R1);
1813 traits.acc(C8, alphav, R2);
1814 r0.storePacket(0 * Traits::ResPacketSize, R0);
1815 r0.storePacket(1 * Traits::ResPacketSize, R1);
1816 r0.storePacket(2 * Traits::ResPacketSize, R2);
1817
1818 R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1819 R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1820 R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1821 traits.acc(C1, alphav, R0);
1822 traits.acc(C5, alphav, R1);
1823 traits.acc(C9, alphav, R2);
1824 r1.storePacket(0 * Traits::ResPacketSize, R0);
1825 r1.storePacket(1 * Traits::ResPacketSize, R1);
1826 r1.storePacket(2 * Traits::ResPacketSize, R2);
1827
1828 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1829 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1830 R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1831 traits.acc(C2, alphav, R0);
1832 traits.acc(C6, alphav, R1);
1833 traits.acc(C10, alphav, R2);
1834 r2.storePacket(0 * Traits::ResPacketSize, R0);
1835 r2.storePacket(1 * Traits::ResPacketSize, R1);
1836 r2.storePacket(2 * Traits::ResPacketSize, R2);
1837
1838 R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1839 R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1840 R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1841 traits.acc(C3, alphav, R0);
1842 traits.acc(C7, alphav, R1);
1843 traits.acc(C11, alphav, R2);
1844 r3.storePacket(0 * Traits::ResPacketSize, R0);
1845 r3.storePacket(1 * Traits::ResPacketSize, R1);
1846 r3.storePacket(2 * Traits::ResPacketSize, R2);
1847 }
1848 }
1849
1850 // Deal with remaining columns of the rhs
1851 for (Index j2 = packet_cols4; j2 < cols; j2++) {
1852 for (Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1853 // One column at a time
1854 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * Traits::LhsProgress)];
1855 prefetch(&blA[0]);
1856
1857 // gets res block as register
1858 AccPacket C0, C4, C8;
1859 traits.initAcc(C0);
1860 traits.initAcc(C4);
1861 traits.initAcc(C8);
1862
1863 LinearMapper r0 = res.getLinearMapper(i, j2);
1864 r0.prefetch(0);
1865
1866 // performs "inner" products
1867 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
1868 LhsPacket A0, A1, A2;
1869
1870 for (Index k = 0; k < peeled_kc; k += pk) {
1871 EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX1");
1872 RhsPacket B_0;
1873#define EIGEN_GEBGP_ONESTEP(K) \
1874 do { \
1875 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \
1876 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1877 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1878 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1879 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1880 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
1881 traits.madd(A0, B_0, C0, B_0, fix<0>); \
1882 traits.madd(A1, B_0, C4, B_0, fix<0>); \
1883 traits.madd(A2, B_0, C8, B_0, fix<0>); \
1884 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \
1885 } while (false)
1886
1887 EIGEN_GEBGP_ONESTEP(0);
1888 EIGEN_GEBGP_ONESTEP(1);
1889 EIGEN_GEBGP_ONESTEP(2);
1890 EIGEN_GEBGP_ONESTEP(3);
1891 EIGEN_GEBGP_ONESTEP(4);
1892 EIGEN_GEBGP_ONESTEP(5);
1893 EIGEN_GEBGP_ONESTEP(6);
1894 EIGEN_GEBGP_ONESTEP(7);
1895
1896 blB += int(pk) * int(RhsProgress);
1897 blA += int(pk) * 3 * int(Traits::LhsProgress);
1898
1899 EIGEN_ASM_COMMENT("end gebp micro kernel 3pX1");
1900 }
1901
1902 // process remaining peeled loop
1903 for (Index k = peeled_kc; k < depth; k++) {
1904 RhsPacket B_0;
1905 EIGEN_GEBGP_ONESTEP(0);
1906 blB += RhsProgress;
1907 blA += 3 * Traits::LhsProgress;
1908 }
1909#undef EIGEN_GEBGP_ONESTEP
1910 ResPacket R0, R1, R2;
1911 ResPacket alphav = pset1<ResPacket>(alpha);
1912
1913 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1914 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1915 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1916 traits.acc(C0, alphav, R0);
1917 traits.acc(C4, alphav, R1);
1918 traits.acc(C8, alphav, R2);
1919 r0.storePacket(0 * Traits::ResPacketSize, R0);
1920 r0.storePacket(1 * Traits::ResPacketSize, R1);
1921 r0.storePacket(2 * Traits::ResPacketSize, R2);
1922 }
1923 }
1924 }
1925 }
1926
1927 //---------- Process 2 * LhsProgress rows at once ----------
1928 if (mr >= 2 * Traits::LhsProgress) {
1929 const Index l1 = defaultL1CacheSize; // in Bytes, TODO, l1 should be passed to this function.
1930 // The max(1, ...) here is needed because we may be using blocking params larger than what our known l1 cache size
1931 // suggests we should be using: either because our known l1 cache size is inaccurate (e.g. on Android, we can only
1932 // guess), or because we are testing specific blocking sizes.
1933 Index actual_panel_rows =
1934 (2 * LhsProgress) * std::max<Index>(1, ((l1 - sizeof(ResScalar) * mr * nr - depth * nr * sizeof(RhsScalar)) /
1935 (depth * sizeof(LhsScalar) * 2 * LhsProgress)));
1936
1937 for (Index i1 = peeled_mc3; i1 < peeled_mc2; i1 += actual_panel_rows) {
1938 Index actual_panel_end = (std::min)(i1 + actual_panel_rows, peeled_mc2);
1939#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1940 EIGEN_IF_CONSTEXPR(nr >= 8) {
1941 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1942 for (Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
1943 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
1944 prefetch(&blA[0]);
1945
1946 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11, C12, C13, C14, C15;
1947 traits.initAcc(C0);
1948 traits.initAcc(C1);
1949 traits.initAcc(C2);
1950 traits.initAcc(C3);
1951 traits.initAcc(C4);
1952 traits.initAcc(C5);
1953 traits.initAcc(C6);
1954 traits.initAcc(C7);
1955 traits.initAcc(C8);
1956 traits.initAcc(C9);
1957 traits.initAcc(C10);
1958 traits.initAcc(C11);
1959 traits.initAcc(C12);
1960 traits.initAcc(C13);
1961 traits.initAcc(C14);
1962 traits.initAcc(C15);
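          // The 16 accumulators cover a (2*LhsProgress) x 8 block of res: for column j (0..7),
          // C{j} holds the upper packet slice and C{8+j} the lower one.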
1963
1964 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1965 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1966 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1967 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1968 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1969 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1970 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1971 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1972 r0.prefetch(prefetch_res_offset);
1973 r1.prefetch(prefetch_res_offset);
1974 r2.prefetch(prefetch_res_offset);
1975 r3.prefetch(prefetch_res_offset);
1976 r4.prefetch(prefetch_res_offset);
1977 r5.prefetch(prefetch_res_offset);
1978 r6.prefetch(prefetch_res_offset);
1979 r7.prefetch(prefetch_res_offset);
1980
1981 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1982 prefetch(&blB[0]);
1983 LhsPacket A0, A1;
1984 for (Index k = 0; k < peeled_kc; k += pk) {
1985 RhsPacketx4 rhs_panel;
1986 RhsPacket T0;
1987// NOTE: the begin/end asm comments below work around bug 935!
1988// but they are not enough for gcc>=6 without FMA (bug 1637)
1989#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE)
1990#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__("" : [a0] "+x,m"(A0), [a1] "+x,m"(A1));
1991#else
1992#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND
1993#endif
1994#define EIGEN_GEBGP_ONESTEP(K) \
1995 do { \
1996 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \
1997 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
1998 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
1999 traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
2000 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
2001 traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
2002 traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
2003 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
2004 traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
2005 traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
2006 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
2007 traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
2008 traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
2009 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
2010 traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
2011 traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
2012 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
2013 traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
2014 traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
2015 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
2016 traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
2017 traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
2018 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
2019 traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
2020 traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
2021 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
2022 traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
2023 EIGEN_GEBP_2Px8_SPILLING_WORKAROUND EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \
2024 } while (false)
2025
2026 EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX8");
2027
2028 EIGEN_GEBGP_ONESTEP(0);
2029 EIGEN_GEBGP_ONESTEP(1);
2030 EIGEN_GEBGP_ONESTEP(2);
2031 EIGEN_GEBGP_ONESTEP(3);
2032 EIGEN_GEBGP_ONESTEP(4);
2033 EIGEN_GEBGP_ONESTEP(5);
2034 EIGEN_GEBGP_ONESTEP(6);
2035 EIGEN_GEBGP_ONESTEP(7);
2036
2037 blB += pk * 8 * RhsProgress;
2038 blA += pk * (2 * Traits::LhsProgress);
2039
2040 EIGEN_ASM_COMMENT("end gebp micro kernel 2pX8");
2041 }
2042 // process remaining peeled loop
2043 for (Index k = peeled_kc; k < depth; k++) {
2044 RhsPacketx4 rhs_panel;
2045 RhsPacket T0;
2046 EIGEN_GEBGP_ONESTEP(0);
2047 blB += 8 * RhsProgress;
2048 blA += 2 * Traits::LhsProgress;
2049 }
2050
2051#undef EIGEN_GEBGP_ONESTEP
2052
2053 ResPacket R0, R1, R2, R3;
2054 ResPacket alphav = pset1<ResPacket>(alpha);
2055
2056 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2057 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2058 R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2059 R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2060 traits.acc(C0, alphav, R0);
2061 traits.acc(C8, alphav, R1);
2062 traits.acc(C1, alphav, R2);
2063 traits.acc(C9, alphav, R3);
2064 r0.storePacket(0 * Traits::ResPacketSize, R0);
2065 r0.storePacket(1 * Traits::ResPacketSize, R1);
2066 r1.storePacket(0 * Traits::ResPacketSize, R2);
2067 r1.storePacket(1 * Traits::ResPacketSize, R3);
2068
2069 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2070 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2071 R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2072 R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2073 traits.acc(C2, alphav, R0);
2074 traits.acc(C10, alphav, R1);
2075 traits.acc(C3, alphav, R2);
2076 traits.acc(C11, alphav, R3);
2077 r2.storePacket(0 * Traits::ResPacketSize, R0);
2078 r2.storePacket(1 * Traits::ResPacketSize, R1);
2079 r3.storePacket(0 * Traits::ResPacketSize, R2);
2080 r3.storePacket(1 * Traits::ResPacketSize, R3);
2081
2082 R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2083 R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2084 R2 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2085 R3 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2086 traits.acc(C4, alphav, R0);
2087 traits.acc(C12, alphav, R1);
2088 traits.acc(C5, alphav, R2);
2089 traits.acc(C13, alphav, R3);
2090 r4.storePacket(0 * Traits::ResPacketSize, R0);
2091 r4.storePacket(1 * Traits::ResPacketSize, R1);
2092 r5.storePacket(0 * Traits::ResPacketSize, R2);
2093 r5.storePacket(1 * Traits::ResPacketSize, R3);
2094
2095 R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2096 R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2097 R2 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2098 R3 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2099 traits.acc(C6, alphav, R0);
2100 traits.acc(C14, alphav, R1);
2101 traits.acc(C7, alphav, R2);
2102 traits.acc(C15, alphav, R3);
2103 r6.storePacket(0 * Traits::ResPacketSize, R0);
2104 r6.storePacket(1 * Traits::ResPacketSize, R1);
2105 r7.storePacket(0 * Traits::ResPacketSize, R2);
2106 r7.storePacket(1 * Traits::ResPacketSize, R3);
2107 }
2108 }
2109 }
2110#endif
2111 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2112 for (Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
2113 // We selected a 2*Traits::LhsProgress x nr micro block of res which is entirely
2114 // stored into 2 x nr registers.
2115
2116 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
2117 prefetch(&blA[0]);
2118
2119 // gets res block as register
2120 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
2121 traits.initAcc(C0);
2122 traits.initAcc(C1);
2123 traits.initAcc(C2);
2124 traits.initAcc(C3);
2125 traits.initAcc(C4);
2126 traits.initAcc(C5);
2127 traits.initAcc(C6);
2128 traits.initAcc(C7);
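          // The 8 accumulators cover a (2*LhsProgress) x 4 block of res: C{j} holds the upper packet
          // slice of column j and C{4+j} the lower one.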
2129
2130 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
2131 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
2132 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
2133 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
2134
2135 r0.prefetch(prefetch_res_offset);
2136 r1.prefetch(prefetch_res_offset);
2137 r2.prefetch(prefetch_res_offset);
2138 r3.prefetch(prefetch_res_offset);
2139
2140 // performs "inner" products
2141 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
2142 prefetch(&blB[0]);
2143 LhsPacket A0, A1;
2144
2145 for (Index k = 0; k < peeled_kc; k += pk) {
2146 EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX4");
2147 RhsPacketx4 rhs_panel;
2148 RhsPacket T0;
2149
2150// NOTE: the begin/end asm comments below work around bug 935!
2151// but they are not enough for gcc>=6 without FMA (bug 1637)
2152#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE) && !(EIGEN_COMP_LCC)
2153#define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND __asm__("" : [a0] "+x,m"(A0), [a1] "+x,m"(A1));
2154#else
2155#define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND
2156#endif
2157#define EIGEN_GEBGP_ONESTEP(K) \
2158 do { \
2159 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \
2160 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
2161 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
2162 traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], rhs_panel); \
2163 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
2164 traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
2165 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
2166 traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
2167 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
2168 traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
2169 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
2170 traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
2171 EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \
2172 EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \
2173 } while (false)
2174
2175 internal::prefetch(blB + (48 + 0));
2176 EIGEN_GEBGP_ONESTEP(0);
2177 EIGEN_GEBGP_ONESTEP(1);
2178 EIGEN_GEBGP_ONESTEP(2);
2179 EIGEN_GEBGP_ONESTEP(3);
2180 internal::prefetch(blB + (48 + 16));
2181 EIGEN_GEBGP_ONESTEP(4);
2182 EIGEN_GEBGP_ONESTEP(5);
2183 EIGEN_GEBGP_ONESTEP(6);
2184 EIGEN_GEBGP_ONESTEP(7);
2185
2186 blB += pk * 4 * RhsProgress;
2187 blA += pk * (2 * Traits::LhsProgress);
2188
2189 EIGEN_ASM_COMMENT("end gebp micro kernel 2pX4");
2190 }
2191 // process remaining peeled loop
2192 for (Index k = peeled_kc; k < depth; k++) {
2193 RhsPacketx4 rhs_panel;
2194 RhsPacket T0;
2195 EIGEN_GEBGP_ONESTEP(0);
2196 blB += 4 * RhsProgress;
2197 blA += 2 * Traits::LhsProgress;
2198 }
2199#undef EIGEN_GEBGP_ONESTEP
2200
2201 ResPacket R0, R1, R2, R3;
2202 ResPacket alphav = pset1<ResPacket>(alpha);
2203
2204 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2205 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2206 R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2207 R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2208 traits.acc(C0, alphav, R0);
2209 traits.acc(C4, alphav, R1);
2210 traits.acc(C1, alphav, R2);
2211 traits.acc(C5, alphav, R3);
2212 r0.storePacket(0 * Traits::ResPacketSize, R0);
2213 r0.storePacket(1 * Traits::ResPacketSize, R1);
2214 r1.storePacket(0 * Traits::ResPacketSize, R2);
2215 r1.storePacket(1 * Traits::ResPacketSize, R3);
2216
2217 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2218 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2219 R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2220 R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2221 traits.acc(C2, alphav, R0);
2222 traits.acc(C6, alphav, R1);
2223 traits.acc(C3, alphav, R2);
2224 traits.acc(C7, alphav, R3);
2225 r2.storePacket(0 * Traits::ResPacketSize, R0);
2226 r2.storePacket(1 * Traits::ResPacketSize, R1);
2227 r3.storePacket(0 * Traits::ResPacketSize, R2);
2228 r3.storePacket(1 * Traits::ResPacketSize, R3);
2229 }
2230 }
2231
2232 // Deal with remaining columns of the rhs
2233 for (Index j2 = packet_cols4; j2 < cols; j2++) {
2234 for (Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
2235 // One column at a time
2236 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
2237 prefetch(&blA[0]);
2238
2239 // gets res block as register
2240 AccPacket C0, C4;
2241 traits.initAcc(C0);
2242 traits.initAcc(C4);
2243
2244 LinearMapper r0 = res.getLinearMapper(i, j2);
2245 r0.prefetch(prefetch_res_offset);
2246
2247 // performs "inner" products
2248 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
2249 LhsPacket A0, A1;
2250
2251 for (Index k = 0; k < peeled_kc; k += pk) {
2252 EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX1");
2253 RhsPacket B_0, B1;
2254
2255#define EIGEN_GEBGP_ONESTEP(K) \
2256 do { \
2257 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX1"); \
2258 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
2259 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
2260 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
2261 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
2262 traits.madd(A0, B_0, C0, B1, fix<0>); \
2263 traits.madd(A1, B_0, C4, B_0, fix<0>); \
2264 EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX1"); \
2265 } while (false)
2266
2267 EIGEN_GEBGP_ONESTEP(0);
2268 EIGEN_GEBGP_ONESTEP(1);
2269 EIGEN_GEBGP_ONESTEP(2);
2270 EIGEN_GEBGP_ONESTEP(3);
2271 EIGEN_GEBGP_ONESTEP(4);
2272 EIGEN_GEBGP_ONESTEP(5);
2273 EIGEN_GEBGP_ONESTEP(6);
2274 EIGEN_GEBGP_ONESTEP(7);
2275
2276 blB += int(pk) * int(RhsProgress);
2277 blA += int(pk) * 2 * int(Traits::LhsProgress);
2278
2279 EIGEN_ASM_COMMENT("end gebp micro kernel 2pX1");
2280 }
2281
2282 // process remaining peeled loop
2283 for (Index k = peeled_kc; k < depth; k++) {
2284 RhsPacket B_0, B1;
2285 EIGEN_GEBGP_ONESTEP(0);
2286 blB += RhsProgress;
2287 blA += 2 * Traits::LhsProgress;
2288 }
2289#undef EIGEN_GEBGP_ONESTEP
2290 ResPacket R0, R1;
2291 ResPacket alphav = pset1<ResPacket>(alpha);
2292
2293 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2294 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2295 traits.acc(C0, alphav, R0);
2296 traits.acc(C4, alphav, R1);
2297 r0.storePacket(0 * Traits::ResPacketSize, R0);
2298 r0.storePacket(1 * Traits::ResPacketSize, R1);
2299 }
2300 }
2301 }
2302 }
2303 //---------- Process 1 * LhsProgress rows at once ----------
2304 if (mr >= 1 * Traits::LhsProgress) {
2305 lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket,
2306 RhsPacket, ResPacket, Traits, LinearMapper, DataMapper>
2307 p;
2308 p(res, blockA, blockB, alpha, peeled_mc2, peeled_mc1, strideA, strideB, offsetA, offsetB, prefetch_res_offset,
2309 peeled_kc, pk, cols, depth, packet_cols4);
2310 }
2311 //---------- Process LhsProgressHalf rows at once ----------
2312 if ((LhsProgressHalf < LhsProgress) && mr >= LhsProgressHalf) {
2313 lhs_process_fraction_of_packet<nr, LhsProgressHalf, RhsProgressHalf, LhsScalar, RhsScalar, ResScalar, AccPacketHalf,
2314 LhsPacketHalf, RhsPacketHalf, ResPacketHalf, HalfTraits, LinearMapper, DataMapper>
2315 p;
2316 p(res, blockA, blockB, alpha, peeled_mc1, peeled_mc_half, strideA, strideB, offsetA, offsetB, prefetch_res_offset,
2317 peeled_kc, pk, cols, depth, packet_cols4);
2318 }
2319 //---------- Process LhsProgressQuarter rows at once ----------
2320 if ((LhsProgressQuarter < LhsProgressHalf) && mr >= LhsProgressQuarter) {
2321 lhs_process_fraction_of_packet<nr, LhsProgressQuarter, RhsProgressQuarter, LhsScalar, RhsScalar, ResScalar,
2322 AccPacketQuarter, LhsPacketQuarter, RhsPacketQuarter, ResPacketQuarter,
2323 QuarterTraits, LinearMapper, DataMapper>
2324 p;
2325 p(res, blockA, blockB, alpha, peeled_mc_half, peeled_mc_quarter, strideA, strideB, offsetA, offsetB,
2326 prefetch_res_offset, peeled_kc, pk, cols, depth, packet_cols4);
2327 }
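  // When the scalar type provides half- and/or quarter-width packets, the two blocks above mop up row
  // ranges that are too small for a full packet but still wide enough for a fractional one; rows left
  // over after that are handled one at a time below.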
2328 //---------- Process remaining rows, 1 at once ----------
2329 if (peeled_mc_quarter < rows) {
2330#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
2331 EIGEN_IF_CONSTEXPR(nr >= 8) {
2332 // loop on each panel of the rhs
2333 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
2334 // loop on each row of the lhs (1*LhsProgress x depth)
2335 for (Index i = peeled_mc_quarter; i < rows; i += 1) {
2336 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2337 prefetch(&blA[0]);
2338 // gets a 1 x 8 res block as registers
2339 ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
2340 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
2341 for (Index k = 0; k < depth; k++) {
2342 LhsScalar A0 = blA[k];
2343 RhsScalar B_0;
2344
2345 B_0 = blB[0];
2346 C0 = cj.pmadd(A0, B_0, C0);
2347
2348 B_0 = blB[1];
2349 C1 = cj.pmadd(A0, B_0, C1);
2350
2351 B_0 = blB[2];
2352 C2 = cj.pmadd(A0, B_0, C2);
2353
2354 B_0 = blB[3];
2355 C3 = cj.pmadd(A0, B_0, C3);
2356
2357 B_0 = blB[4];
2358 C4 = cj.pmadd(A0, B_0, C4);
2359
2360 B_0 = blB[5];
2361 C5 = cj.pmadd(A0, B_0, C5);
2362
2363 B_0 = blB[6];
2364 C6 = cj.pmadd(A0, B_0, C6);
2365
2366 B_0 = blB[7];
2367 C7 = cj.pmadd(A0, B_0, C7);
2368
2369 blB += 8;
2370 }
2371 res(i, j2 + 0) += alpha * C0;
2372 res(i, j2 + 1) += alpha * C1;
2373 res(i, j2 + 2) += alpha * C2;
2374 res(i, j2 + 3) += alpha * C3;
2375 res(i, j2 + 4) += alpha * C4;
2376 res(i, j2 + 5) += alpha * C5;
2377 res(i, j2 + 6) += alpha * C6;
2378 res(i, j2 + 7) += alpha * C7;
2379 }
2380 }
2381 }
2382#endif
2383
2384 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2385 // loop on each row of the lhs (1*LhsProgress x depth)
2386 for (Index i = peeled_mc_quarter; i < rows; i += 1) {
2387 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2388 prefetch(&blA[0]);
2389 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
2390
2391 // If LhsProgress is 8 or 16, the code below assumes that there is a
2392 // half or quarter packet, respectively, of the same size as
2393 // nr (which is currently 4) for the return type.
2394 const int SResPacketHalfSize = unpacket_traits<typename unpacket_traits<SResPacket>::half>::size;
2395 const int SResPacketQuarterSize =
2396 unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
2397 // The following code assumes we can load SRhsPacket in such a way that
2398 // it multiplies blocks of 4 elements in SLhsPacket. This is not the
2399 // case for some customized kernels (e.g., NEON fp16). If the assumption
2400 // fails, drop down to the scalar path.
2401 constexpr bool kCanLoadSRhsQuad =
2402 (unpacket_traits<SLhsPacket>::size < 4) ||
2403 (unpacket_traits<SRhsPacket>::size % ((std::max<int>)(unpacket_traits<SLhsPacket>::size, 4) / 4)) == 0;
2404 if (kCanLoadSRhsQuad && (SwappedTraits::LhsProgress % 4) == 0 && (SwappedTraits::LhsProgress <= 16) &&
2405 (SwappedTraits::LhsProgress != 8 || SResPacketHalfSize == nr) &&
2406 (SwappedTraits::LhsProgress != 16 || SResPacketQuarterSize == nr)) {
2407 SAccPacket C0, C1, C2, C3;
2408 straits.initAcc(C0);
2409 straits.initAcc(C1);
2410 straits.initAcc(C2);
2411 straits.initAcc(C3);
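          // Swapped-traits path: the rhs panel (depth x 4) plays the role of the "lhs", so the four
          // result coefficients of row i, columns j2..j2+3 end up in a single vector register (after a
          // reduction when the swapped packet is wider than 4); blB is loaded as packets while groups
          // of lhs coefficients from blA are broadcast against them.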
2412
2413 const Index spk = (std::max)(1, SwappedTraits::LhsProgress / 4);
2414 const Index endk = (depth / spk) * spk;
2415 const Index endk4 = (depth / (spk * 4)) * (spk * 4);
2416
2417 Index k = 0;
2418 for (; k < endk4; k += 4 * spk) {
2419 SLhsPacket A0, A1;
2420 SRhsPacket B_0, B_1;
2421
2422 straits.loadLhsUnaligned(blB + 0 * SwappedTraits::LhsProgress, A0);
2423 straits.loadLhsUnaligned(blB + 1 * SwappedTraits::LhsProgress, A1);
2424
2425 straits.loadRhsQuad(blA + 0 * spk, B_0);
2426 straits.loadRhsQuad(blA + 1 * spk, B_1);
2427 straits.madd(A0, B_0, C0, B_0, fix<0>);
2428 straits.madd(A1, B_1, C1, B_1, fix<0>);
2429
2430 straits.loadLhsUnaligned(blB + 2 * SwappedTraits::LhsProgress, A0);
2431 straits.loadLhsUnaligned(blB + 3 * SwappedTraits::LhsProgress, A1);
2432 straits.loadRhsQuad(blA + 2 * spk, B_0);
2433 straits.loadRhsQuad(blA + 3 * spk, B_1);
2434 straits.madd(A0, B_0, C2, B_0, fix<0>);
2435 straits.madd(A1, B_1, C3, B_1, fix<0>);
2436
2437 blB += 4 * SwappedTraits::LhsProgress;
2438 blA += 4 * spk;
2439 }
2440 C0 = padd(padd(C0, C1), padd(C2, C3));
2441 for (; k < endk; k += spk) {
2442 SLhsPacket A0;
2443 SRhsPacket B_0;
2444
2445 straits.loadLhsUnaligned(blB, A0);
2446 straits.loadRhsQuad(blA, B_0);
2447 straits.madd(A0, B_0, C0, B_0, fix<0>);
2448
2449 blB += SwappedTraits::LhsProgress;
2450 blA += spk;
2451 }
2452 if (SwappedTraits::LhsProgress == 8) {
2453 // Special case where we have to first reduce the accumulation register C0
2454 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8, typename unpacket_traits<SResPacket>::half,
2455 SResPacket>
2456 SResPacketHalf;
2457 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8, typename unpacket_traits<SLhsPacket>::half,
2458 SLhsPacket>
2459 SLhsPacketHalf;
2460 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8, typename unpacket_traits<SRhsPacket>::half,
2461 SRhsPacket>
2462 SRhsPacketHalf;
2463 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8, typename unpacket_traits<SAccPacket>::half,
2464 SAccPacket>
2465 SAccPacketHalf;
2466
2467 SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
2468 SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
2469
2470 if (depth - endk > 0) {
2471 // We have to handle the last row of the rhs which corresponds to a half-packet
2472 SLhsPacketHalf a0;
2473 SRhsPacketHalf b0;
2474 straits.loadLhsUnaligned(blB, a0);
2475 straits.loadRhs(blA, b0);
2476 SAccPacketHalf c0 = predux_half_dowto4(C0);
2477 straits.madd(a0, b0, c0, b0, fix<0>);
2478 straits.acc(c0, alphav, R);
2479 } else {
2480 straits.acc(predux_half_dowto4(C0), alphav, R);
2481 }
2482 res.scatterPacket(i, j2, R);
2483 } else if (SwappedTraits::LhsProgress == 16) {
2484 // Special case where we have to first reduce the
2485 // accumulation register C0. We specialize the block in
2486 // template form, so that LhsProgress < 16 paths don't
2487 // fail to compile
2488 last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
2489 p(res, straits, blA, blB, depth, endk, i, j2, alpha, C0);
2490 } else {
2491 SResPacket R = res.template gatherPacket<SResPacket>(i, j2);
2492 SResPacket alphav = pset1<SResPacket>(alpha);
2493 straits.acc(C0, alphav, R);
2494 res.scatterPacket(i, j2, R);
2495 }
2496 } else // scalar path
2497 {
2498 // get a 1 x 4 res block as registers
2499 ResScalar C0(0), C1(0), C2(0), C3(0);
2500
2501 for (Index k = 0; k < depth; k++) {
2502 LhsScalar A0;
2503 RhsScalar B_0, B_1;
2504
2505 A0 = blA[k];
2506
2507 B_0 = blB[0];
2508 B_1 = blB[1];
2509 C0 = cj.pmadd(A0, B_0, C0);
2510 C1 = cj.pmadd(A0, B_1, C1);
2511
2512 B_0 = blB[2];
2513 B_1 = blB[3];
2514 C2 = cj.pmadd(A0, B_0, C2);
2515 C3 = cj.pmadd(A0, B_1, C3);
2516
2517 blB += 4;
2518 }
2519 res(i, j2 + 0) += alpha * C0;
2520 res(i, j2 + 1) += alpha * C1;
2521 res(i, j2 + 2) += alpha * C2;
2522 res(i, j2 + 3) += alpha * C3;
2523 }
2524 }
2525 }
2526 // remaining columns
2527 for (Index j2 = packet_cols4; j2 < cols; j2++) {
2528 // loop on each row of the lhs (1*LhsProgress x depth)
2529 for (Index i = peeled_mc_quarter; i < rows; i += 1) {
2530 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2531 prefetch(&blA[0]);
2532 // gets a 1 x 1 res block as registers
2533 ResScalar C0(0);
2534 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
2535 for (Index k = 0; k < depth; k++) {
2536 LhsScalar A0 = blA[k];
2537 RhsScalar B_0 = blB[k];
2538 C0 = cj.pmadd(A0, B_0, C0);
2539 }
2540 res(i, j2) += alpha * C0;
2541 }
2542 }
2543 }
2544}
2545
2546// pack a block of the lhs
2547 // The traversal is as follows (mr==4):
2548// 0 4 8 12 ...
2549// 1 5 9 13 ...
2550// 2 6 10 14 ...
2551// 3 7 11 15 ...
2552//
2553// 16 20 24 28 ...
2554// 17 21 25 29 ...
2555// 18 22 26 30 ...
2556// 19 23 27 31 ...
2557//
2558// 32 33 34 35 ...
2559 // 36 37 38 39 ...
2560template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate,
2561 bool PanelMode>
2562struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
2563 typedef typename DataMapper::LinearMapper LinearMapper;
2564 EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
2565 Index offset = 0);
2566};
2567
2568template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate,
2569 bool PanelMode>
2570EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
2571 PanelMode>::operator()(Scalar* blockA, const DataMapper& lhs, Index depth,
2572 Index rows, Index stride, Index offset) {
2573 typedef typename unpacket_traits<Packet>::half HalfPacket;
2574 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
2575 enum {
2576 PacketSize = unpacket_traits<Packet>::size,
2577 HalfPacketSize = unpacket_traits<HalfPacket>::size,
2578 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
2579 HasHalf = (int)HalfPacketSize < (int)PacketSize,
2580 HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize
2581 };
2582
2583 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
2584 EIGEN_UNUSED_VARIABLE(stride);
2585 EIGEN_UNUSED_VARIABLE(offset);
2586 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
2587 eigen_assert(((Pack1 % PacketSize) == 0 && Pack1 <= 4 * PacketSize) || (Pack1 <= 4));
2588 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
2589 Index count = 0;
2590
2591 const Index peeled_mc3 = Pack1 >= 3 * PacketSize ? (rows / (3 * PacketSize)) * (3 * PacketSize) : 0;
2592 const Index peeled_mc2 =
2593 Pack1 >= 2 * PacketSize ? peeled_mc3 + ((rows - peeled_mc3) / (2 * PacketSize)) * (2 * PacketSize) : 0;
2594 const Index peeled_mc1 =
2595 Pack1 >= 1 * PacketSize ? peeled_mc2 + ((rows - peeled_mc2) / (1 * PacketSize)) * (1 * PacketSize) : 0;
2596 const Index peeled_mc_half =
2597 Pack1 >= HalfPacketSize ? peeled_mc1 + ((rows - peeled_mc1) / (HalfPacketSize)) * (HalfPacketSize) : 0;
2598 const Index peeled_mc_quarter = Pack1 >= QuarterPacketSize ? (rows / (QuarterPacketSize)) * (QuarterPacketSize) : 0;
2599 const Index last_lhs_progress = rows > peeled_mc_quarter ? (rows - peeled_mc_quarter) & ~1 : 0;
2600 const Index peeled_mc0 = Pack2 >= PacketSize ? peeled_mc_quarter
2601 : Pack2 > 1 && last_lhs_progress ? (rows / last_lhs_progress) * last_lhs_progress
2602 : 0;
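  // In PanelMode, every packed row reserves 'stride' coefficient slots: 'offset' slots are skipped in
  // front, 'depth' slots are filled by the k loop, and the remainder is left as padding. The count
  // adjustments before and after each k loop below implement exactly this bookkeeping.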
2603
2604 Index i = 0;
2605
2606 // Pack 3 packets
2607 if (Pack1 >= 3 * PacketSize) {
2608 for (; i < peeled_mc3; i += 3 * PacketSize) {
2609 if (PanelMode) count += (3 * PacketSize) * offset;
2610
2611 for (Index k = 0; k < depth; k++) {
2612 Packet A, B, C;
2613 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2614 B = lhs.template loadPacket<Packet>(i + 1 * PacketSize, k);
2615 C = lhs.template loadPacket<Packet>(i + 2 * PacketSize, k);
2616 pstore(blockA + count, cj.pconj(A));
2617 count += PacketSize;
2618 pstore(blockA + count, cj.pconj(B));
2619 count += PacketSize;
2620 pstore(blockA + count, cj.pconj(C));
2621 count += PacketSize;
2622 }
2623 if (PanelMode) count += (3 * PacketSize) * (stride - offset - depth);
2624 }
2625 }
2626 // Pack 2 packets
2627 if (Pack1 >= 2 * PacketSize) {
2628 for (; i < peeled_mc2; i += 2 * PacketSize) {
2629 if (PanelMode) count += (2 * PacketSize) * offset;
2630
2631 for (Index k = 0; k < depth; k++) {
2632 Packet A, B;
2633 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2634 B = lhs.template loadPacket<Packet>(i + 1 * PacketSize, k);
2635 pstore(blockA + count, cj.pconj(A));
2636 count += PacketSize;
2637 pstore(blockA + count, cj.pconj(B));
2638 count += PacketSize;
2639 }
2640 if (PanelMode) count += (2 * PacketSize) * (stride - offset - depth);
2641 }
2642 }
2643 // Pack 1 packets
2644 if (Pack1 >= 1 * PacketSize) {
2645 for (; i < peeled_mc1; i += 1 * PacketSize) {
2646 if (PanelMode) count += (1 * PacketSize) * offset;
2647
2648 for (Index k = 0; k < depth; k++) {
2649 Packet A;
2650 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2651 pstore(blockA + count, cj.pconj(A));
2652 count += PacketSize;
2653 }
2654 if (PanelMode) count += (1 * PacketSize) * (stride - offset - depth);
2655 }
2656 }
2657 // Pack half packets
2658 if (HasHalf && Pack1 >= HalfPacketSize) {
2659 for (; i < peeled_mc_half; i += HalfPacketSize) {
2660 if (PanelMode) count += (HalfPacketSize)*offset;
2661
2662 for (Index k = 0; k < depth; k++) {
2663 HalfPacket A;
2664 A = lhs.template loadPacket<HalfPacket>(i + 0 * (HalfPacketSize), k);
2665 pstoreu(blockA + count, cj.pconj(A));
2666 count += HalfPacketSize;
2667 }
2668 if (PanelMode) count += (HalfPacketSize) * (stride - offset - depth);
2669 }
2670 }
2671 // Pack quarter packets
2672 if (HasQuarter && Pack1 >= QuarterPacketSize) {
2673 for (; i < peeled_mc_quarter; i += QuarterPacketSize) {
2674 if (PanelMode) count += (QuarterPacketSize)*offset;
2675
2676 for (Index k = 0; k < depth; k++) {
2677 QuarterPacket A;
2678 A = lhs.template loadPacket<QuarterPacket>(i + 0 * (QuarterPacketSize), k);
2679 pstoreu(blockA + count, cj.pconj(A));
2680 count += QuarterPacketSize;
2681 }
2682 if (PanelMode) count += (QuarterPacketSize) * (stride - offset - depth);
2683 }
2684 }
2685 // Pack2 may be *smaller* than PacketSize—that happens for
2686 // products like real * complex, where we have to go half the
2687 // progress on the lhs in order to duplicate those operands to
2688 // address both real & imaginary parts on the rhs. This portion will
2689 // pack those half ones until they match the number expected on the
2690 // last peeling loop at this point (for the rhs).
2691 if (Pack2 < PacketSize && Pack2 > 1) {
2692 for (; i < peeled_mc0; i += last_lhs_progress) {
2693 if (PanelMode) count += last_lhs_progress * offset;
2694
2695 for (Index k = 0; k < depth; k++)
2696 for (Index w = 0; w < last_lhs_progress; w++) blockA[count++] = cj(lhs(i + w, k));
2697
2698 if (PanelMode) count += last_lhs_progress * (stride - offset - depth);
2699 }
2700 }
2701 // Pack scalars
2702 for (; i < rows; i++) {
2703 if (PanelMode) count += offset;
2704 for (Index k = 0; k < depth; k++) blockA[count++] = cj(lhs(i, k));
2705 if (PanelMode) count += (stride - offset - depth);
2706 }
2707}
2708
2709template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate,
2710 bool PanelMode>
2711struct gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
2712 typedef typename DataMapper::LinearMapper LinearMapper;
2713 EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
2714 Index offset = 0);
2715};
2716
2717template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate,
2718 bool PanelMode>
2719EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
2720 PanelMode>::operator()(Scalar* blockA, const DataMapper& lhs, Index depth,
2721 Index rows, Index stride, Index offset) {
2722 typedef typename unpacket_traits<Packet>::half HalfPacket;
2723 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
2724 enum {
2725 PacketSize = unpacket_traits<Packet>::size,
2726 HalfPacketSize = unpacket_traits<HalfPacket>::size,
2727 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
2728 HasHalf = (int)HalfPacketSize < (int)PacketSize,
2729 HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize
2730 };
2731
2732 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
2733 EIGEN_UNUSED_VARIABLE(stride);
2734 EIGEN_UNUSED_VARIABLE(offset);
2735 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
2736 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
2737 Index count = 0;
2738 bool gone_half = false, gone_quarter = false, gone_last = false;
2739
2740 Index i = 0;
2741 Index pack = Pack1;
2742 Index psize = PacketSize;
2743 while (pack > 0) {
2744 Index remaining_rows = rows - i;
2745 Index peeled_mc = gone_last ? (Pack2 > 1 ? (rows / pack) * pack : 0) : i + (remaining_rows / pack) * pack;
2746 Index starting_pos = i;
2747 for (; i < peeled_mc; i += pack) {
2748 if (PanelMode) count += pack * offset;
2749
2750 Index k = 0;
2751 if (pack >= psize && psize >= QuarterPacketSize) {
2752 const Index peeled_k = (depth / psize) * psize;
2753 for (; k < peeled_k; k += psize) {
2754 for (Index m = 0; m < pack; m += psize) {
2755 if (psize == PacketSize) {
2756 PacketBlock<Packet> kernel;
2757 for (Index p = 0; p < psize; ++p) kernel.packet[p] = lhs.template loadPacket<Packet>(i + p + m, k);
2758 ptranspose(kernel);
2759 for (Index p = 0; p < psize; ++p) pstore(blockA + count + m + (pack)*p, cj.pconj(kernel.packet[p]));
2760 } else if (HasHalf && psize == HalfPacketSize) {
2761 gone_half = true;
2762 PacketBlock<HalfPacket> kernel_half;
2763 for (Index p = 0; p < psize; ++p)
2764 kernel_half.packet[p] = lhs.template loadPacket<HalfPacket>(i + p + m, k);
2765 ptranspose(kernel_half);
2766 for (Index p = 0; p < psize; ++p) pstore(blockA + count + m + (pack)*p, cj.pconj(kernel_half.packet[p]));
2767 } else if (HasQuarter && psize == QuarterPacketSize) {
2768 gone_quarter = true;
2769 PacketBlock<QuarterPacket> kernel_quarter;
2770 for (Index p = 0; p < psize; ++p)
2771 kernel_quarter.packet[p] = lhs.template loadPacket<QuarterPacket>(i + p + m, k);
2772 ptranspose(kernel_quarter);
2773 for (Index p = 0; p < psize; ++p)
2774 pstore(blockA + count + m + (pack)*p, cj.pconj(kernel_quarter.packet[p]));
2775 }
2776 }
2777 count += psize * pack;
2778 }
2779 }
2780
2781 for (; k < depth; k++) {
2782 Index w = 0;
2783 for (; w < pack - 3; w += 4) {
2784 Scalar a(cj(lhs(i + w + 0, k))), b(cj(lhs(i + w + 1, k))), c(cj(lhs(i + w + 2, k))), d(cj(lhs(i + w + 3, k)));
2785 blockA[count++] = a;
2786 blockA[count++] = b;
2787 blockA[count++] = c;
2788 blockA[count++] = d;
2789 }
2790 if (pack % 4)
2791 for (; w < pack; ++w) blockA[count++] = cj(lhs(i + w, k));
2792 }
2793
2794 if (PanelMode) count += pack * (stride - offset - depth);
2795 }
2796
2797 pack -= psize;
2798 Index left = rows - i;
2799 if (pack <= 0) {
2800 if (!gone_last && (starting_pos == i || left >= psize / 2 || left >= psize / 4) &&
2801 ((psize / 2 == HalfPacketSize && HasHalf && !gone_half) ||
2802 (psize / 2 == QuarterPacketSize && HasQuarter && !gone_quarter))) {
2803 psize /= 2;
2804 pack = psize;
2805 continue;
2806 }
2807 // Pack2 may be *smaller* than PacketSize: this happens for
2808 // products like real * complex, where the lhs must advance at
2809 // half the usual rate so that its operands can be duplicated to
2810 // match both the real and imaginary parts packed on the rhs.
2811 // This final pass packs the remaining rows at the reduced
2812 // progress so they line up with the last peeling loop of the rhs.
2813 if (Pack2 < PacketSize && !gone_last) {
2814 gone_last = true;
2815 psize = pack = left & ~1;
2816 }
2817 }
2818 }
2819
2820 for (; i < rows; i++) {
2821 if (PanelMode) count += offset;
2822 for (Index k = 0; k < depth; k++) blockA[count++] = cj(lhs(i, k));
2823 if (PanelMode) count += (stride - offset - depth);
2824 }
2825}
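
For a row-major lhs the target layout in blockA is the same as in the column-major case, but the psize entries lhs(i..i+psize-1, k) that must land contiguously are strided in memory, whereas lhs(i+p, k..k+psize-1) is contiguous. The packer therefore loads psize consecutive rows as packets and transposes them in registers, turning contiguous loads into the depth-interleaved stores. A 4x4 scalar illustration of that in-register step (hypothetical helper, not part of Eigen):

    // `tile` holds lhs(i + w, k + p) at tile[4 * w + p] (a 4x4 row-major tile);
    // after the transpose, out[4 * p + w] = lhs(i + w, k + p), i.e. each group
    // of four consecutive outputs covers rows i..i+3 for a single depth index
    // k + p, which is exactly the order blockA expects.
    inline void transpose4x4_sketch(const float* tile, float* out) {
      for (int p = 0; p < 4; ++p)
        for (int w = 0; w < 4; ++w)
          out[4 * p + w] = tile[4 * w + p];
    }
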
2826
2827// copy a complete panel of the rhs
2828// this version is optimized for column major matrices
2829 // The traversal order is as follows (nr==4):
2830// 0 1 2 3 12 13 14 15 24 27
2831// 4 5 6 7 16 17 18 19 25 28
2832// 8 9 10 11 20 21 22 23 26 29
2833// . . . . . . . . . .
2834template <typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2835struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
2836 typedef typename packet_traits<Scalar>::type Packet;
2837 typedef typename DataMapper::LinearMapper LinearMapper;
2838 enum { PacketSize = packet_traits<Scalar>::size };
2839 EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
2840 Index offset = 0);
2841};
2842
2843template <typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2844EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
2845 Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
2846 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
2847 EIGEN_UNUSED_VARIABLE(stride);
2848 EIGEN_UNUSED_VARIABLE(offset);
2849 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
2850 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
2851 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
2852 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
2853 Index count = 0;
2854 const Index peeled_k = (depth / PacketSize) * PacketSize;
2855
2856#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
2857 EIGEN_IF_CONSTEXPR(nr >= 8) {
2858 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
2859 // skip what we have before
2860 if (PanelMode) count += 8 * offset;
2861 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
2862 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
2863 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
2864 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
2865 const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
2866 const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
2867 const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
2868 const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
2869 Index k = 0;
2870 if (PacketSize % 2 == 0 && PacketSize <= 8) // 2 4 8
2871 {
2872 for (; k < peeled_k; k += PacketSize) {
2873 if (PacketSize == 2) {
2874 PacketBlock<Packet, PacketSize == 2 ? 2 : PacketSize> kernel0, kernel1, kernel2, kernel3;
2875 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2876 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2877 kernel1.packet[0 % PacketSize] = dm2.template loadPacket<Packet>(k);
2878 kernel1.packet[1 % PacketSize] = dm3.template loadPacket<Packet>(k);
2879 kernel2.packet[0 % PacketSize] = dm4.template loadPacket<Packet>(k);
2880 kernel2.packet[1 % PacketSize] = dm5.template loadPacket<Packet>(k);
2881 kernel3.packet[0 % PacketSize] = dm6.template loadPacket<Packet>(k);
2882 kernel3.packet[1 % PacketSize] = dm7.template loadPacket<Packet>(k);
2883 ptranspose(kernel0);
2884 ptranspose(kernel1);
2885 ptranspose(kernel2);
2886 ptranspose(kernel3);
2887
2888 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2889 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
2890 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize]));
2891 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize]));
2892
2893 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2894 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
2895 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize]));
2896 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize]));
2897 count += 8 * PacketSize;
2898 } else if (PacketSize == 4) {
2899 PacketBlock<Packet, PacketSize == 4 ? 4 : PacketSize> kernel0, kernel1;
2900
2901 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2902 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2903 kernel0.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2904 kernel0.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2905 kernel1.packet[0 % PacketSize] = dm4.template loadPacket<Packet>(k);
2906 kernel1.packet[1 % PacketSize] = dm5.template loadPacket<Packet>(k);
2907 kernel1.packet[2 % PacketSize] = dm6.template loadPacket<Packet>(k);
2908 kernel1.packet[3 % PacketSize] = dm7.template loadPacket<Packet>(k);
2909 ptranspose(kernel0);
2910 ptranspose(kernel1);
2911
2912 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2913 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
2914 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2915 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
2916 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[2 % PacketSize]));
2917 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[2 % PacketSize]));
2918 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel0.packet[3 % PacketSize]));
2919 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel1.packet[3 % PacketSize]));
2920 count += 8 * PacketSize;
2921 } else if (PacketSize == 8) {
2922 PacketBlock<Packet, PacketSize == 8 ? 8 : PacketSize> kernel0;
2923
2924 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2925 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2926 kernel0.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2927 kernel0.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2928 kernel0.packet[4 % PacketSize] = dm4.template loadPacket<Packet>(k);
2929 kernel0.packet[5 % PacketSize] = dm5.template loadPacket<Packet>(k);
2930 kernel0.packet[6 % PacketSize] = dm6.template loadPacket<Packet>(k);
2931 kernel0.packet[7 % PacketSize] = dm7.template loadPacket<Packet>(k);
2932 ptranspose(kernel0);
2933
2934 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2935 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2936 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel0.packet[2 % PacketSize]));
2937 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel0.packet[3 % PacketSize]));
2938 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[4 % PacketSize]));
2939 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel0.packet[5 % PacketSize]));
2940 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel0.packet[6 % PacketSize]));
2941 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel0.packet[7 % PacketSize]));
2942 count += 8 * PacketSize;
2943 }
2944 }
2945 }
2946
2947 for (; k < depth; k++) {
2948 blockB[count + 0] = cj(dm0(k));
2949 blockB[count + 1] = cj(dm1(k));
2950 blockB[count + 2] = cj(dm2(k));
2951 blockB[count + 3] = cj(dm3(k));
2952 blockB[count + 4] = cj(dm4(k));
2953 blockB[count + 5] = cj(dm5(k));
2954 blockB[count + 6] = cj(dm6(k));
2955 blockB[count + 7] = cj(dm7(k));
2956 count += 8;
2957 }
2958 // skip what we have after
2959 if (PanelMode) count += 8 * (stride - offset - depth);
2960 }
2961 }
2962#endif
2963
2964 EIGEN_IF_CONSTEXPR(nr >= 4) {
2965 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2966 // skip what we have before
2967 if (PanelMode) count += 4 * offset;
2968 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
2969 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
2970 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
2971 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
2972
2973 Index k = 0;
2974 if ((PacketSize % 4) == 0) // TODO enable vectorized transposition for PacketSize==2 ??
2975 {
2976 for (; k < peeled_k; k += PacketSize) {
2977 PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
2978 kernel.packet[0] = dm0.template loadPacket<Packet>(k);
2979 kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2980 kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2981 kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2982 ptranspose(kernel);
2983 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
2984 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
2985 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
2986 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
2987 count += 4 * PacketSize;
2988 }
2989 }
2990 for (; k < depth; k++) {
2991 blockB[count + 0] = cj(dm0(k));
2992 blockB[count + 1] = cj(dm1(k));
2993 blockB[count + 2] = cj(dm2(k));
2994 blockB[count + 3] = cj(dm3(k));
2995 count += 4;
2996 }
2997 // skip what we have after
2998 if (PanelMode) count += 4 * (stride - offset - depth);
2999 }
3000 }
3001
3002 // copy the remaining columns one at a time (nr==1)
3003 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
3004 if (PanelMode) count += offset;
3005 const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
3006 for (Index k = 0; k < depth; k++) {
3007 blockB[count] = cj(dm0(k));
3008 count += 1;
3009 }
3010 if (PanelMode) count += (stride - offset - depth);
3011 }
3012}
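
The traversal order drawn in the comment before this packer corresponds to packing columns four at a time, interleaved along the depth, with any leftover columns appended one by one. The scalar sketch below (a hypothetical helper, not part of Eigen) reproduces the numbering 0..29 shown in that comment for depth = 3 and cols = 10, assuming a column-major rhs with leading dimension rhsStride:

    template <typename Scalar, typename Index>
    void pack_rhs_colmajor_sketch(Scalar* blockB, const Scalar* rhs, Index rhsStride,
                                  Index depth, Index cols) {
      Index count = 0;
      Index j2 = 0;
      for (; j2 + 4 <= cols; j2 += 4)                  // groups of 4 columns (nr == 4)
        for (Index k = 0; k < depth; ++k)
          for (Index j = 0; j < 4; ++j)
            blockB[count++] = rhs[k + (j2 + j) * rhsStride];  // rhs(k, j2 + j)
      for (; j2 < cols; ++j2)                          // remaining columns, one at a time
        for (Index k = 0; k < depth; ++k)
          blockB[count++] = rhs[k + j2 * rhsStride];
    }

The vectorized paths above produce the same order; they merely use packet loads plus ptranspose to gather the four column entries of each depth step at once.
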
3013
3014// this version is optimized for row major matrices
3015template <typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3016struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3017 typedef typename packet_traits<Scalar>::type Packet;
3018 typedef typename unpacket_traits<Packet>::half HalfPacket;
3019 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
3020 typedef typename DataMapper::LinearMapper LinearMapper;
3021 enum {
3022 PacketSize = packet_traits<Scalar>::size,
3023 HalfPacketSize = unpacket_traits<HalfPacket>::size,
3024 QuarterPacketSize = unpacket_traits<QuarterPacket>::size
3025 };
3026 EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3027 Index offset = 0) {
3028 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
3029 EIGEN_UNUSED_VARIABLE(stride);
3030 EIGEN_UNUSED_VARIABLE(offset);
3031 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
3032 const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
3033 const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
3034 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
3035 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
3036 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
3037 Index count = 0;
3038
3039#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
3040 EIGEN_IF_CONSTEXPR(nr >= 8) {
3041 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
3042 // skip what we have before
3043 if (PanelMode) count += 8 * offset;
3044 for (Index k = 0; k < depth; k++) {
3045 if (PacketSize == 8) {
3046 Packet A = rhs.template loadPacket<Packet>(k, j2);
3047 pstoreu(blockB + count, cj.pconj(A));
3048 count += PacketSize;
3049 } else if (PacketSize == 4) {
3050 Packet A = rhs.template loadPacket<Packet>(k, j2);
3051 Packet B = rhs.template loadPacket<Packet>(k, j2 + 4);
3052 pstoreu(blockB + count, cj.pconj(A));
3053 pstoreu(blockB + count + PacketSize, cj.pconj(B));
3054 count += 2 * PacketSize;
3055 } else {
3056 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
3057 blockB[count + 0] = cj(dm0(0));
3058 blockB[count + 1] = cj(dm0(1));
3059 blockB[count + 2] = cj(dm0(2));
3060 blockB[count + 3] = cj(dm0(3));
3061 blockB[count + 4] = cj(dm0(4));
3062 blockB[count + 5] = cj(dm0(5));
3063 blockB[count + 6] = cj(dm0(6));
3064 blockB[count + 7] = cj(dm0(7));
3065 count += 8;
3066 }
3067 }
3068 // skip what we have after
3069 if (PanelMode) count += 8 * (stride - offset - depth);
3070 }
3071 }
3072#endif
3073
3074 if (nr >= 4) {
3075 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
3076 // skip what we have before
3077 if (PanelMode) count += 4 * offset;
3078 for (Index k = 0; k < depth; k++) {
3079 if (PacketSize == 4) {
3080 Packet A = rhs.template loadPacket<Packet>(k, j2);
3081 pstoreu(blockB + count, cj.pconj(A));
3082 count += PacketSize;
3083 } else if (HasHalf && HalfPacketSize == 4) {
3084 HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
3085 pstoreu(blockB + count, cj.pconj(A));
3086 count += HalfPacketSize;
3087 } else if (HasQuarter && QuarterPacketSize == 4) {
3088 QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
3089 pstoreu(blockB + count, cj.pconj(A));
3090 count += QuarterPacketSize;
3091 } else {
3092 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
3093 blockB[count + 0] = cj(dm0(0));
3094 blockB[count + 1] = cj(dm0(1));
3095 blockB[count + 2] = cj(dm0(2));
3096 blockB[count + 3] = cj(dm0(3));
3097 count += 4;
3098 }
3099 }
3100 // skip what we have after
3101 if (PanelMode) count += 4 * (stride - offset - depth);
3102 }
3103 }
3104 // copy the remaining columns one at a time (nr==1)
3105 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
3106 if (PanelMode) count += offset;
3107 for (Index k = 0; k < depth; k++) {
3108 blockB[count] = cj(rhs(k, j2));
3109 count += 1;
3110 }
3111 if (PanelMode) count += stride - offset - depth;
3112 }
3113 }
3114};
3115
3116} // end namespace internal
3117
3120inline std::ptrdiff_t l1CacheSize() {
3121 std::ptrdiff_t l1, l2, l3;
3122 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3123 return l1;
3124}
3125
3128inline std::ptrdiff_t l2CacheSize() {
3129 std::ptrdiff_t l1, l2, l3;
3130 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3131 return l2;
3132}
3133
3136inline std::ptrdiff_t l3CacheSize() {
3137 std::ptrdiff_t l1, l2, l3;
3138 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3139 return l3;
3140}
3141
3147inline void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2, std::ptrdiff_t l3) {
3148 internal::manage_caching_sizes(SetAction, &l1, &l2, &l3);
3149}
3150
3151} // end namespace Eigen
3152
3153#endif // EIGEN_GENERAL_BLOCK_PANEL_H
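
The l1CacheSize(), l2CacheSize(), l3CacheSize() and setCpuCacheSizes() functions defined at the end of this header are the public knobs behind the blocking sizes used throughout this file. A minimal usage sketch, assuming a standard Eigen installation providing <Eigen/Dense>:

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      // Report the cache sizes Eigen currently assumes (queried or defaults).
      std::cout << "L1: " << Eigen::l1CacheSize() << " bytes\n"
                << "L2: " << Eigen::l2CacheSize() << " bytes\n"
                << "L3: " << Eigen::l3CacheSize() << " bytes\n";

      // Override them, e.g. when automatic detection is wrong for this machine;
      // subsequent products derive their blocking parameters from these values.
      Eigen::setCpuCacheSizes(48 * 1024, 1024 * 1024, 32 * 1024 * 1024);

      Eigen::MatrixXd a = Eigen::MatrixXd::Random(512, 512);
      Eigen::MatrixXd b = Eigen::MatrixXd::Random(512, 512);
      Eigen::MatrixXd c = a * b;  // GEMM now blocks using the overridden sizes
      std::cout << c(0, 0) << "\n";
      return 0;
    }
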