10#ifndef EIGEN_GENERAL_BLOCK_PANEL_H
11#define EIGEN_GENERAL_BLOCK_PANEL_H
14#include "../InternalHeaderCheck.h"
// Selects which SIMD packet width a GEBP kernel instantiation operates on:
// the full native packet, its half, or its quarter.
enum GEBPPacketSizeType {
  GEBPPacketFull = 0,
  GEBPPacketHalf,
  GEBPPacketQuarter
};
22template <
typename LhsScalar_,
typename RhsScalar_,
bool ConjLhs_ =
false,
bool ConjRhs_ =
false,
23 int Arch = Architecture::Target,
int PacketSize_ = GEBPPacketFull>
// Returns the provided cache size `a` when it is positive; otherwise falls
// back to the default value `b`. Non-positive `a` means "unknown/unset".
inline std::ptrdiff_t manage_caching_sizes_helper(std::ptrdiff_t a, std::ptrdiff_t b) {
  if (a > 0) return a;
  return b;
}
// If the user pre-defined EIGEN_DEFAULT_L*_CACHE_SIZE, honor that value;
// otherwise EIGEN_SET_DEFAULT_L*_CACHE_SIZE(val) expands to the fallback
// `val` supplied at the use site below.
// NOTE(review): the #else/#endif lines of these conditionals are not visible
// in this extraction; surviving lines kept verbatim — verify against upstream.
#if defined(EIGEN_DEFAULT_L1_CACHE_SIZE)
#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) EIGEN_DEFAULT_L1_CACHE_SIZE
#define EIGEN_SET_DEFAULT_L1_CACHE_SIZE(val) val
#if defined(EIGEN_DEFAULT_L2_CACHE_SIZE)
#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) EIGEN_DEFAULT_L2_CACHE_SIZE
#define EIGEN_SET_DEFAULT_L2_CACHE_SIZE(val) val
#if defined(EIGEN_DEFAULT_L3_CACHE_SIZE)
#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) EIGEN_DEFAULT_L3_CACHE_SIZE
#define EIGEN_SET_DEFAULT_L3_CACHE_SIZE(val) val
// Per-architecture default cache sizes, used when no positive value is
// provided at runtime (see manage_caching_sizes_helper above).
#if EIGEN_ARCH_i386_OR_x86_64
const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(32 * 1024);
const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(256 * 1024);
const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(2 * 1024 * 1024);
// Defaults for other architectures follow; their guarding #elif/#else lines
// are not visible in this extraction.
const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(64 * 1024);
const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(2 * 1024 * 1024);
const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(8 * 1024 * 1024);
const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512 * 1024);
const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(4 * 1024 * 1024);
const std::ptrdiff_t defaultL1CacheSize = EIGEN_SET_DEFAULT_L1_CACHE_SIZE(16 * 1024);
const std::ptrdiff_t defaultL2CacheSize = EIGEN_SET_DEFAULT_L2_CACHE_SIZE(512 * 1024);
const std::ptrdiff_t defaultL3CacheSize = EIGEN_SET_DEFAULT_L3_CACHE_SIZE(512 * 1024);
// The helper macros are only needed to initialize the constants above.
#undef EIGEN_SET_DEFAULT_L1_CACHE_SIZE
#undef EIGEN_SET_DEFAULT_L2_CACHE_SIZE
#undef EIGEN_SET_DEFAULT_L3_CACHE_SIZE
// Constructor of the cache-size holder: each level is resolved to either a
// positive queried/user value or the architecture default (see
// manage_caching_sizes_helper). The -1 initializers mark "unset".
// NOTE(review): the enclosing struct declaration and the lines declaring
// l1CacheSize/l2CacheSize/l3CacheSize are not visible in this extraction —
// presumably a cache-size query fills them; verify against upstream.
CacheSizes() : m_l1(-1), m_l2(-1), m_l3(-1) {
  m_l1 = manage_caching_sizes_helper(l1CacheSize, defaultL1CacheSize);
  m_l2 = manage_caching_sizes_helper(l2CacheSize, defaultL2CacheSize);
  m_l3 = manage_caching_sizes_helper(l3CacheSize, defaultL3CacheSize);
// Central accessor for the process-wide cached L1/L2/L3 sizes.
// SetAction stores *l1/*l2/*l3 into the function-local static; GetAction
// reads them back into the pointees. Any other action is a programming
// error, caught by the trailing assert.
// NOTE(review): the else-branch braces around the final assert and the
// function's closing brace are not visible in this extraction.
inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff_t* l2, std::ptrdiff_t* l3) {
  static CacheSizes m_cacheSizes;  // lazily initialized on first use
  if (action == SetAction) {
    eigen_internal_assert(l1 != 0 && l2 != 0);
    m_cacheSizes.m_l1 = *l1;
    m_cacheSizes.m_l2 = *l2;
    m_cacheSizes.m_l3 = *l3;
  } else if (action == GetAction) {
    eigen_internal_assert(l1 != 0 && l2 != 0);
    *l1 = m_cacheSizes.m_l1;
    *l2 = m_cacheSizes.m_l2;
    *l3 = m_cacheSizes.m_l3;
    eigen_internal_assert(false);
/** Heuristically computes cache-friendly blocking sizes for a GEBP product.
 * k/m/n are in-out parameters: they arrive as the full problem dimensions
 * (depth, rows, cols) and are reduced to per-panel block sizes based on the
 * L1/L2/L3 cache sizes and the number of threads.
 * NOTE(review): several interior lines (the `enum {`/`const Index` wrappers
 * declaring kdiv/ksub/kr/nr/mr/k_peeling/max_nc/max_mc, plus various
 * #else/#endif directives and closing braces) are not visible in this
 * extraction; surviving lines are kept verbatim. */
template <typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index num_threads = 1) {
  typedef gebp_traits<LhsScalar, RhsScalar> Traits;
  // Query the (possibly user-overridden) cache sizes.
  std::ptrdiff_t l1, l2, l3;
  manage_caching_sizes(GetAction, &l1, &l2, &l3);
#ifdef EIGEN_VECTORIZE_AVX512
  // Multi-threaded path: size each dimension from a different cache level.
  if (num_threads > 1) {
    typedef typename Traits::ResScalar ResScalar;
    // Bytes consumed per unit of k by one mr x nr register block (kdiv), and
    // the bytes of the result block itself (ksub), both resident in L1.
    kdiv = KcFactor * (Traits::mr * sizeof(LhsScalar) + Traits::nr * sizeof(RhsScalar)),
    ksub = Traits::mr * (Traits::nr * sizeof(ResScalar)),
    // k: what fits in L1 after the result block, capped at 320 and rounded
    // down to a multiple of the peeling granularity kr.
    const Index k_cache = numext::maxi<Index>(kr, (numext::mini<Index>)((l1 - ksub) / kdiv, 320));
    k = k_cache - (k_cache % kr);
    eigen_internal_assert(k > 0);
    // n: what fits in the L2 slice not occupied by the L1 working set,
    // balanced against the per-thread share of the full n.
    const Index n_cache = (l2 - l1) / (nr * sizeof(RhsScalar) * k);
    const Index n_per_thread = numext::div_ceil(n, num_threads);
    if (n_cache <= n_per_thread) {
      eigen_internal_assert(n_cache >= static_cast<Index>(nr));
      n = n_cache - (n_cache % nr);
      eigen_internal_assert(n > 0);
      // Round the per-thread share up to a multiple of nr.
      n = (numext::mini<Index>)(n, (n_per_thread + nr - 1) - ((n_per_thread + nr - 1) % nr));
    // m: analogous sizing from the L3 slice, shared among all threads.
    const Index m_cache = (l3 - l2) / (sizeof(LhsScalar) * k * num_threads);
    const Index m_per_thread = numext::div_ceil(m, num_threads);
    if (m_cache < m_per_thread && m_cache >= static_cast<Index>(mr)) {
      m = m_cache - (m_cache % mr);
      eigen_internal_assert(m > 0);
      m = (numext::mini<Index>)(m, (m_per_thread + mr - 1) - ((m_per_thread + mr - 1) % mr));
#ifdef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS
  // Small problems: keep the dimensions as-is, blocking buys nothing.
  if ((numext::maxi)(k, (numext::maxi)(m, n)) < 48) return;
  typedef typename Traits::ResScalar ResScalar;
  // Same per-unit-of-k L1 footprint as above (single-threaded path).
  k_div = KcFactor * (Traits::mr * sizeof(LhsScalar) + Traits::nr * sizeof(RhsScalar)),
  k_sub = Traits::mr * (Traits::nr * sizeof(ResScalar))
  // Largest k block fitting in L1, rounded down to the peeling granularity.
  const Index max_kc = numext::maxi<Index>(((l1 - k_sub) / k_div) & (~(k_peeling - 1)), 1);
  const Index old_k = k;
  // Split k into equally-sized sweeps no larger than max_kc, keeping the
  // chosen k a multiple of k_peeling.
  k = (k % max_kc) == 0 ? max_kc
                        : max_kc - k_peeling * ((max_kc - 1 - (k % max_kc)) / (k_peeling * (k / max_kc + 1)));
  eigen_internal_assert(((old_k / k) == (old_k / max_kc)) && "the number of sweeps has to remain the same");
#ifdef EIGEN_DEBUG_SMALL_PRODUCT_BLOCKS
  const Index actual_l2 = l3;
  // Hard-coded effective L2 budget (1.5 MiB) used by the non-debug path.
  const Index actual_l2 = 1572864;
  // n: prefer keeping the rhs panel in the L1 space left after the lhs block.
  const Index lhs_bytes = m * k * sizeof(LhsScalar);
  const Index remaining_l1 = l1 - k_sub - lhs_bytes;
  if (remaining_l1 >= Index(Traits::nr * sizeof(RhsScalar)) * k) {
    max_nc = remaining_l1 / (k * sizeof(RhsScalar));
    // Otherwise bound n by a fraction of the effective L2.
    max_nc = (3 * actual_l2) / (2 * 2 * max_kc * sizeof(RhsScalar));
  // Round down to a multiple of nr, then split n into equal sweeps.
  Index nc = numext::mini<Index>(actual_l2 / (2 * k * sizeof(RhsScalar)), max_nc) & (~(Traits::nr - 1));
  n = (n % nc) == 0 ? nc : (nc - Traits::nr * ((nc - (n % nc)) / (Traits::nr * (n / nc + 1))));
  } else if (old_k == k) {
    // k was not reduced: optionally block m as well, depending on the
    // overall problem footprint.
    Index problem_size = k * n * sizeof(LhsScalar);
    Index actual_lm = actual_l2;
    if (problem_size <= 1024) {
    } else if (l3 != 0 && problem_size <= 32768) {
      max_mc = (numext::mini<Index>)(576, max_mc);
    // Split m into equal sweeps of at most mc rows, mc a multiple of mr.
    Index mc = (numext::mini<Index>)(actual_lm / (3 * k * sizeof(LhsScalar)), max_mc);
    mc -= mc % Traits::mr;
    m = (m % mc) == 0 ? mc : (mc - Traits::mr * ((mc - (m % mc)) / (Traits::mr * (m / mc + 1))));
// Test hook: when EIGEN_TEST_SPECIFIC_BLOCKING_SIZES is enabled, clamps
// k/m/n to the compile-time-specified blocking sizes; otherwise the
// parameters are deliberately unused.
// NOTE(review): the function signature line (presumably
// `inline bool useSpecificBlockingSizes(Index& k, Index& m, Index& n)`,
// cf. its caller below) is not visible in this extraction.
template <typename Index>
#ifdef EIGEN_TEST_SPECIFIC_BLOCKING_SIZES
  if (EIGEN_TEST_SPECIFIC_BLOCKING_SIZES) {
    k = numext::mini<Index>(k, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_K);
    m = numext::mini<Index>(m, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_M);
    n = numext::mini<Index>(n, EIGEN_TEST_SPECIFIC_BLOCKING_SIZE_N);
  EIGEN_UNUSED_VARIABLE(k)
  EIGEN_UNUSED_VARIABLE(m)
  EIGEN_UNUSED_VARIABLE(n)
// computeProductBlockingSizes<..., KcFactor, ...>: uses test-specified
// blocking sizes when requested, otherwise falls back to the cache-based
// heuristic above.
// NOTE(review): the function signature line itself is not visible in this
// extraction; only the body survives.
template <typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
  if (!useSpecificBlockingSizes(k, m, n)) {
    evaluateProductBlockingSizesHeuristic<LhsScalar, RhsScalar, KcFactor, Index>(k, m, n, num_threads);
// Convenience overload forwarding with KcFactor == 1.
// NOTE(review): the signature line is not visible in this extraction.
template <typename LhsScalar, typename RhsScalar, typename Index>
  computeProductBlockingSizes<LhsScalar, RhsScalar, 1, Index>(k, m, n, num_threads);
// Chooses the RHS-panel representation for the GEBP kernel: if at least 4
// architectural registers remain after `registers_taken`, the 4-wide
// RhsPacketx4 is used, otherwise the plain RhsPacket.
template <typename RhsPacket, typename RhsPacketx4, int registers_taken>
struct RhsPanelHelper {
  // Registers left over for broadcasting the rhs (never negative).
  static constexpr int remaining_registers =
      (std::max)(int(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS) - registers_taken, 0);
  typedef std::conditional_t<remaining_registers >= 4, RhsPacketx4, RhsPacket> type;
// Holds four broadcast RHS packets; get(FixedInt<i>) selects one at compile
// time with no runtime branching.
// NOTE(review): the enclosing `struct QuadPacket {` header line is not
// visible in this extraction — the members clearly belong to it, cf. the
// `QuadPacket<RhsPacket> RhsPacketx4` typedefs and `.B_0/.B1/.B2/.B3` uses
// elsewhere in this file.
template <typename Packet>
  Packet B_0, B1, B2, B3;
  const Packet& get(const FixedInt<0>&) const { return B_0; }
  const Packet& get(const FixedInt<1>&) const { return B1; }
  const Packet& get(const FixedInt<2>&) const { return B2; }
  const Packet& get(const FixedInt<3>&) const { return B3; }
// packet_conditional<N, T1, T2, T3>: maps a GEBPPacketSizeType onto one of
// three candidate packet types (used by the PACKET_DECL_COND* macros below
// to pick full/half/quarter packets).
// NOTE(review): the `typedef ... type;` member of each struct is not visible
// in this extraction; presumably the primary template selects T3 (quarter)
// and the two specializations select T1 (full) and T2 (half) — verify
// against upstream.
template <int N, typename T1, typename T2, typename T3>
struct packet_conditional {
template <typename T1, typename T2, typename T3>
struct packet_conditional<GEBPPacketFull, T1, T2, T3> {
template <typename T1, typename T2, typename T3>
struct packet_conditional<GEBPPacketHalf, T1, T2, T3> {
// The four macros below declare a packet typedef whose width — full, half or
// quarter of the native packet for the given scalar — is selected by
// `packet_size` through packet_conditional.
// Declares `name##Packet##postfix` for scalar type `name##Scalar`.
#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
  typedef typename packet_conditional< \
      packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
      typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix
// Same, without the postfix: declares `name##Packet`.
#define PACKET_DECL_COND(name, packet_size) \
  typedef typename packet_conditional< \
      packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
      typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet
// Variants keyed on the enclosing class's `Scalar` type: `ScalarPacket##postfix`.
#define PACKET_DECL_COND_SCALAR_POSTFIX(postfix, packet_size) \
  typedef typename packet_conditional< \
      packet_size, typename packet_traits<Scalar>::type, typename packet_traits<Scalar>::half, \
      typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type ScalarPacket##postfix
// ... and `ScalarPacket` without postfix.
#define PACKET_DECL_COND_SCALAR(packet_size) \
  typedef typename packet_conditional< \
      packet_size, typename packet_traits<Scalar>::type, typename packet_traits<Scalar>::half, \
      typename unpacket_traits<typename packet_traits<Scalar>::half>::half>::type ScalarPacket
397template <
typename LhsScalar_,
typename RhsScalar_,
bool ConjLhs_,
bool ConjRhs_,
int Arch,
int PacketSize_>
400 typedef LhsScalar_ LhsScalar;
401 typedef RhsScalar_ RhsScalar;
402 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
404 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
405 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
406 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
411 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
412 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
413 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
414 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
416 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
422 default_mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize,
423#
if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && \
424 !defined(EIGEN_VECTORIZE_VSX) && ((!EIGEN_COMP_MSVC) || (EIGEN_COMP_MSVC >= 1914))
429 mr = Vectorizable ? 3 * LhsPacketSize : default_mr,
434 LhsProgress = LhsPacketSize,
438 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
439 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
440 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
441 typedef LhsPacket LhsPacket4Packing;
443 typedef QuadPacket<RhsPacket> RhsPacketx4;
444 typedef ResPacket AccPacket;
446 EIGEN_STRONG_INLINE
void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
448 template <
typename RhsPacketType>
449 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
450 dest = pset1<RhsPacketType>(*b);
453 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketx4& dest)
const {
454 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
457 template <
typename RhsPacketType>
458 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
462 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar*, RhsPacketx4&)
const {}
464 EIGEN_STRONG_INLINE
void loadRhsQuad(
const RhsScalar* b, RhsPacket& dest)
const { dest = ploadquad<RhsPacket>(b); }
466 template <
typename LhsPacketType>
467 EIGEN_STRONG_INLINE
void loadLhs(
const LhsScalar* a, LhsPacketType& dest)
const {
468 dest = pload<LhsPacketType>(a);
471 template <
typename LhsPacketType>
472 EIGEN_STRONG_INLINE
void loadLhsUnaligned(
const LhsScalar* a, LhsPacketType& dest)
const {
473 dest = ploadu<LhsPacketType>(a);
476 template <
typename LhsPacketType,
typename RhsPacketType,
typename AccPacketType,
typename LaneIdType>
477 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
478 const LaneIdType&)
const {
479 conj_helper<LhsPacketType, RhsPacketType, ConjLhs, ConjRhs> cj;
484#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
485 EIGEN_UNUSED_VARIABLE(tmp);
486 c = cj.pmadd(a, b, c);
489 tmp = cj.pmul(a, tmp);
494 template <
typename LhsPacketType,
typename AccPacketType,
typename LaneIdType>
495 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
496 const LaneIdType& lane)
const {
497 madd(a, b.get(lane), c, tmp, lane);
500 EIGEN_STRONG_INLINE
void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const {
501 r = pmadd(c, alpha, r);
504 template <
typename ResPacketHalf>
505 EIGEN_STRONG_INLINE
void acc(
const ResPacketHalf& c,
const ResPacketHalf& alpha, ResPacketHalf& r)
const {
506 r = pmadd(c, alpha, r);
510template <
typename RealScalar,
bool ConjLhs_,
int Arch,
int PacketSize_>
511class gebp_traits<std::complex<RealScalar>, RealScalar, ConjLhs_, false, Arch, PacketSize_> {
513 typedef std::complex<RealScalar> LhsScalar;
514 typedef RealScalar RhsScalar;
515 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
517 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
518 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
519 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
524 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
525 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
526 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
527 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
529 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
531#if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) && !defined(EIGEN_VECTORIZE_ALTIVEC) && !defined(EIGEN_VECTORIZE_VSX)
533 mr = 3 * LhsPacketSize,
535 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize,
538 LhsProgress = LhsPacketSize,
542 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
543 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
544 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
545 typedef LhsPacket LhsPacket4Packing;
547 typedef QuadPacket<RhsPacket> RhsPacketx4;
549 typedef ResPacket AccPacket;
551 EIGEN_STRONG_INLINE
void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
553 template <
typename RhsPacketType>
554 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
555 dest = pset1<RhsPacketType>(*b);
558 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketx4& dest)
const {
559 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
562 template <
typename RhsPacketType>
563 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
567 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar*, RhsPacketx4&)
const {}
569 EIGEN_STRONG_INLINE
void loadRhsQuad(
const RhsScalar* b, RhsPacket& dest)
const {
570 loadRhsQuad_impl(b, dest, std::conditional_t<RhsPacketSize == 16, true_type, false_type>());
573 EIGEN_STRONG_INLINE
void loadRhsQuad_impl(
const RhsScalar* b, RhsPacket& dest,
const true_type&)
const {
576 RhsScalar tmp[4] = {b[0], b[0], b[1], b[1]};
577 dest = ploadquad<RhsPacket>(tmp);
580 EIGEN_STRONG_INLINE
void loadRhsQuad_impl(
const RhsScalar* b, RhsPacket& dest,
const false_type&)
const {
581 eigen_internal_assert(RhsPacketSize <= 8);
582 dest = pset1<RhsPacket>(*b);
585 EIGEN_STRONG_INLINE
void loadLhs(
const LhsScalar* a, LhsPacket& dest)
const { dest = pload<LhsPacket>(a); }
587 template <
typename LhsPacketType>
588 EIGEN_STRONG_INLINE
void loadLhsUnaligned(
const LhsScalar* a, LhsPacketType& dest)
const {
589 dest = ploadu<LhsPacketType>(a);
592 template <
typename LhsPacketType,
typename RhsPacketType,
typename AccPacketType,
typename LaneIdType>
593 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
594 const LaneIdType&)
const {
595 madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable, true_type, false_type>());
598 template <
typename LhsPacketType,
typename RhsPacketType,
typename AccPacketType>
599 EIGEN_STRONG_INLINE
void madd_impl(
const LhsPacketType& a,
const RhsPacketType& b, AccPacketType& c,
600 RhsPacketType& tmp,
const true_type&)
const {
601#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
602 EIGEN_UNUSED_VARIABLE(tmp);
603 c.v = pmadd(a.v, b, c.v);
606 tmp = pmul(a.v, tmp);
607 c.v = padd(c.v, tmp);
611 EIGEN_STRONG_INLINE
void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
612 const false_type&)
const {
616 template <
typename LhsPacketType,
typename AccPacketType,
typename LaneIdType>
617 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
618 const LaneIdType& lane)
const {
619 madd(a, b.get(lane), c, tmp, lane);
622 template <
typename ResPacketType,
typename AccPacketType>
623 EIGEN_STRONG_INLINE
void acc(
const AccPacketType& c,
const ResPacketType& alpha, ResPacketType& r)
const {
624 conj_helper<ResPacketType, ResPacketType, ConjLhs, false> cj;
625 r = cj.pmadd(c, alpha, r);
631template <
typename Packet>
// Component-wise sum of two DoublePackets: real parts (.first) and imaginary
// parts (.second) are added in their separate packets.
// NOTE(review): the trailing `return res;` and closing brace are not visible
// in this extraction.
template <typename Packet>
DoublePacket<Packet> padd(const DoublePacket<Packet>& a, const DoublePacket<Packet>& b) {
  DoublePacket<Packet> res;
  res.first = padd(a.first, b.first);
  res.second = padd(a.second, b.second);
// For packets of at most 8 elements the DoublePacket is passed through
// unchanged (returned by const reference; the body is not visible in this
// extraction).
template <typename Packet>
const DoublePacket<Packet>& predux_half_dowto4(const DoublePacket<Packet>& a,
                                               std::enable_if_t<unpacket_traits<Packet>::size <= 8>* = 0) {
// Halves a 16-element DoublePacket: each real/imag packet is viewed as a
// complex packet, reduced with the complex predux_half_dowto4, and the raw
// half-width packet (.v) is stored back.
// NOTE(review): the trailing `return res;` and closing brace are not visible
// in this extraction.
template <typename Packet>
DoublePacket<typename unpacket_traits<Packet>::half> predux_half_dowto4(
    const DoublePacket<Packet>& a, std::enable_if_t<unpacket_traits<Packet>::size == 16>* = 0) {
  DoublePacket<typename unpacket_traits<Packet>::half> res;
  typedef std::complex<typename unpacket_traits<Packet>::type> Cplx;
  typedef typename packet_traits<Cplx>::type CplxPacket;
  res.first = predux_half_dowto4(CplxPacket(a.first)).v;
  res.second = predux_half_dowto4(CplxPacket(a.second)).v;
// Broadcasts the real and imaginary parts of *b into dest.first/dest.second
// (variant for packets of at most 8 elements).
// NOTE(review): the closing brace is not visible in this extraction.
template <typename Scalar, typename RealPacket>
void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
                            std::enable_if_t<unpacket_traits<RealPacket>::size <= 8>* = 0) {
  dest.first = pset1<RealPacket>(numext::real(*b));
  dest.second = pset1<RealPacket>(numext::imag(*b));
// 16-element variant: loads two scalars b[0], b[1] in the duplicated pattern
// {b0, b0, b1, b1} (quad-load) into the real and imaginary packets.
// NOTE(review): the closing brace is not visible in this extraction.
template <typename Scalar, typename RealPacket>
void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
                            std::enable_if_t<unpacket_traits<RealPacket>::size == 16>* = 0) {
  typedef typename NumTraits<Scalar>::Real RealScalar;
  RealScalar r[4] = {numext::real(b[0]), numext::real(b[0]), numext::real(b[1]), numext::real(b[1])};
  RealScalar i[4] = {numext::imag(b[0]), numext::imag(b[0]), numext::imag(b[1]), numext::imag(b[1])};
  dest.first = ploadquad<RealPacket>(r);
  dest.second = ploadquad<RealPacket>(i);
// unpacket_traits for DoublePacket: its half is the DoublePacket of the half
// packet, and it holds twice as many real scalars as the underlying packet
// (one packet of real parts plus one of imaginary parts).
// NOTE(review): the closing `};` is not visible in this extraction.
template <typename Packet>
struct unpacket_traits<DoublePacket<Packet> > {
  typedef DoublePacket<typename unpacket_traits<Packet>::half> half;
  enum { size = 2 * unpacket_traits<Packet>::size };
700template <
typename RealScalar,
bool ConjLhs_,
bool ConjRhs_,
int Arch,
int PacketSize_>
701class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, ConjLhs_, ConjRhs_, Arch, PacketSize_> {
703 typedef std::complex<RealScalar> Scalar;
704 typedef std::complex<RealScalar> LhsScalar;
705 typedef std::complex<RealScalar> RhsScalar;
706 typedef std::complex<RealScalar> ResScalar;
708 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
709 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
710 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
711 PACKET_DECL_COND(Real, PacketSize_);
712 PACKET_DECL_COND_SCALAR(PacketSize_);
717 Vectorizable = unpacket_traits<RealPacket>::vectorizable && unpacket_traits<ScalarPacket>::vectorizable,
718 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
719 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
720 RhsPacketSize = Vectorizable ? unpacket_traits<RhsScalar>::size : 1,
721 RealPacketSize = Vectorizable ? unpacket_traits<RealPacket>::size : 1,
722 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
725 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * ResPacketSize,
727 LhsProgress = ResPacketSize,
731 typedef DoublePacket<RealPacket> DoublePacketType;
733 typedef std::conditional_t<Vectorizable, ScalarPacket, Scalar> LhsPacket4Packing;
734 typedef std::conditional_t<Vectorizable, RealPacket, Scalar> LhsPacket;
735 typedef std::conditional_t<Vectorizable, DoublePacketType, Scalar> RhsPacket;
736 typedef std::conditional_t<Vectorizable, ScalarPacket, Scalar> ResPacket;
737 typedef std::conditional_t<Vectorizable, DoublePacketType, Scalar> AccPacket;
740 typedef QuadPacket<RhsPacket> RhsPacketx4;
742 EIGEN_STRONG_INLINE
void initAcc(Scalar& p) { p = Scalar(0); }
744 EIGEN_STRONG_INLINE
void initAcc(DoublePacketType& p) {
745 p.first = pset1<RealPacket>(RealScalar(0));
746 p.second = pset1<RealPacket>(RealScalar(0));
750 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, ScalarPacket& dest)
const { dest = pset1<ScalarPacket>(*b); }
753 template <
typename RealPacketType>
754 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, DoublePacket<RealPacketType>& dest)
const {
755 dest.first = pset1<RealPacketType>(numext::real(*b));
756 dest.second = pset1<RealPacketType>(numext::imag(*b));
759 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketx4& dest)
const {
760 loadRhs(b, dest.B_0);
761 loadRhs(b + 1, dest.B1);
762 loadRhs(b + 2, dest.B2);
763 loadRhs(b + 3, dest.B3);
767 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar* b, ScalarPacket& dest)
const { loadRhs(b, dest); }
770 template <
typename RealPacketType>
771 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar* b, DoublePacket<RealPacketType>& dest)
const {
775 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar*, RhsPacketx4&)
const {}
777 EIGEN_STRONG_INLINE
void loadRhsQuad(
const RhsScalar* b, ResPacket& dest)
const { loadRhs(b, dest); }
778 EIGEN_STRONG_INLINE
void loadRhsQuad(
const RhsScalar* b, DoublePacketType& dest)
const {
779 loadQuadToDoublePacket(b, dest);
783 EIGEN_STRONG_INLINE
void loadLhs(
const LhsScalar* a, LhsPacket& dest)
const {
784 dest = pload<LhsPacket>((
const typename unpacket_traits<LhsPacket>::type*)(a));
787 template <
typename LhsPacketType>
788 EIGEN_STRONG_INLINE
void loadLhsUnaligned(
const LhsScalar* a, LhsPacketType& dest)
const {
789 dest = ploadu<LhsPacketType>((
const typename unpacket_traits<LhsPacketType>::type*)(a));
792 template <
typename LhsPacketType,
typename RhsPacketType,
typename ResPacketType,
typename TmpType,
794 EIGEN_STRONG_INLINE std::enable_if_t<!is_same<RhsPacketType, RhsPacketx4>::value> madd(
const LhsPacketType& a,
795 const RhsPacketType& b,
796 DoublePacket<ResPacketType>& c,
798 const LaneIdType&)
const {
799 c.first = pmadd(a, b.first, c.first);
800 c.second = pmadd(a, b.second, c.second);
803 template <
typename LaneIdType>
804 EIGEN_STRONG_INLINE
void madd(
const LhsPacket& a,
const RhsPacket& b, ResPacket& c, RhsPacket& ,
805 const LaneIdType&)
const {
806 c = cj.pmadd(a, b, c);
809 template <
typename LhsPacketType,
typename AccPacketType,
typename LaneIdType>
810 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
811 const LaneIdType& lane)
const {
812 madd(a, b.get(lane), c, tmp, lane);
815 EIGEN_STRONG_INLINE
void acc(
const Scalar& c,
const Scalar& alpha, Scalar& r)
const { r += alpha * c; }
817 template <
typename RealPacketType,
typename ResPacketType>
818 EIGEN_STRONG_INLINE
void acc(
const DoublePacket<RealPacketType>& c,
const ResPacketType& alpha,
819 ResPacketType& r)
const {
822 if ((!ConjLhs) && (!ConjRhs)) {
823 tmp = pcplxflip(pconj(ResPacketType(c.second)));
824 tmp = padd(ResPacketType(c.first), tmp);
825 }
else if ((!ConjLhs) && (ConjRhs)) {
826 tmp = pconj(pcplxflip(ResPacketType(c.second)));
827 tmp = padd(ResPacketType(c.first), tmp);
828 }
else if ((ConjLhs) && (!ConjRhs)) {
829 tmp = pcplxflip(ResPacketType(c.second));
830 tmp = padd(pconj(ResPacketType(c.first)), tmp);
831 }
else if ((ConjLhs) && (ConjRhs)) {
832 tmp = pcplxflip(ResPacketType(c.second));
833 tmp = psub(pconj(ResPacketType(c.first)), tmp);
836 r = pmadd(tmp, alpha, r);
840 conj_helper<LhsScalar, RhsScalar, ConjLhs, ConjRhs> cj;
843template <
typename RealScalar,
bool ConjRhs_,
int Arch,
int PacketSize_>
844class gebp_traits<RealScalar, std::complex<RealScalar>, false, ConjRhs_, Arch, PacketSize_> {
846 typedef std::complex<RealScalar> Scalar;
847 typedef RealScalar LhsScalar;
848 typedef Scalar RhsScalar;
849 typedef Scalar ResScalar;
851 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
852 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
853 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
854 PACKET_DECL_COND_POSTFIX(_, Real, PacketSize_);
855 PACKET_DECL_COND_SCALAR_POSTFIX(_, PacketSize_);
857#undef PACKET_DECL_COND_SCALAR_POSTFIX
858#undef PACKET_DECL_COND_POSTFIX
859#undef PACKET_DECL_COND_SCALAR
860#undef PACKET_DECL_COND
865 Vectorizable = unpacket_traits<RealPacket_>::vectorizable && unpacket_traits<ScalarPacket_>::vectorizable,
866 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
867 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
868 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
870 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
873 mr = (plain_enum_min(16, NumberOfRegisters) / 2 / nr) * ResPacketSize,
875 LhsProgress = ResPacketSize,
879 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
880 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
881 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
882 typedef LhsPacket LhsPacket4Packing;
883 typedef QuadPacket<RhsPacket> RhsPacketx4;
884 typedef ResPacket AccPacket;
886 EIGEN_STRONG_INLINE
void initAcc(AccPacket& p) { p = pset1<ResPacket>(ResScalar(0)); }
888 template <
typename RhsPacketType>
889 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
890 dest = pset1<RhsPacketType>(*b);
893 EIGEN_STRONG_INLINE
void loadRhs(
const RhsScalar* b, RhsPacketx4& dest)
const {
894 pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
897 template <
typename RhsPacketType>
898 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar* b, RhsPacketType& dest)
const {
902 EIGEN_STRONG_INLINE
void updateRhs(
const RhsScalar*, RhsPacketx4&)
const {}
904 EIGEN_STRONG_INLINE
void loadLhs(
const LhsScalar* a, LhsPacket& dest)
const { dest = ploaddup<LhsPacket>(a); }
906 EIGEN_STRONG_INLINE
void loadRhsQuad(
const RhsScalar* b, RhsPacket& dest)
const { dest = ploadquad<RhsPacket>(b); }
908 template <
typename LhsPacketType>
909 EIGEN_STRONG_INLINE
void loadLhsUnaligned(
const LhsScalar* a, LhsPacketType& dest)
const {
910 dest = ploaddup<LhsPacketType>(a);
913 template <
typename LhsPacketType,
typename RhsPacketType,
typename AccPacketType,
typename LaneIdType>
914 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp,
915 const LaneIdType&)
const {
916 madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable, true_type, false_type>());
919 template <
typename LhsPacketType,
typename RhsPacketType,
typename AccPacketType>
920 EIGEN_STRONG_INLINE
void madd_impl(
const LhsPacketType& a,
const RhsPacketType& b, AccPacketType& c,
921 RhsPacketType& tmp,
const true_type&)
const {
922#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
923 EIGEN_UNUSED_VARIABLE(tmp);
924 c.v = pmadd(a, b.v, c.v);
927 tmp.v = pmul(a, tmp.v);
932 EIGEN_STRONG_INLINE
void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
933 const false_type&)
const {
937 template <
typename LhsPacketType,
typename AccPacketType,
typename LaneIdType>
938 EIGEN_STRONG_INLINE
void madd(
const LhsPacketType& a,
const RhsPacketx4& b, AccPacketType& c, RhsPacket& tmp,
939 const LaneIdType& lane)
const {
940 madd(a, b.get(lane), c, tmp, lane);
943 template <
typename ResPacketType,
typename AccPacketType>
944 EIGEN_STRONG_INLINE
void acc(
const AccPacketType& c,
const ResPacketType& alpha, ResPacketType& r)
const {
945 conj_helper<ResPacketType, ResPacketType, false, ConjRhs> cj;
946 r = cj.pmadd(alpha, c, r);
959template <
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
960 bool ConjugateLhs,
bool ConjugateRhs>
962 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
963 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target, GEBPPacketHalf>
965 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target, GEBPPacketQuarter>
968 typedef typename Traits::ResScalar ResScalar;
969 typedef typename Traits::LhsPacket LhsPacket;
970 typedef typename Traits::RhsPacket RhsPacket;
971 typedef typename Traits::ResPacket ResPacket;
972 typedef typename Traits::AccPacket AccPacket;
973 typedef typename Traits::RhsPacketx4 RhsPacketx4;
975 typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
976 typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 27>::type RhsPanel27;
978 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
980 typedef typename SwappedTraits::ResScalar SResScalar;
981 typedef typename SwappedTraits::LhsPacket SLhsPacket;
982 typedef typename SwappedTraits::RhsPacket SRhsPacket;
983 typedef typename SwappedTraits::ResPacket SResPacket;
984 typedef typename SwappedTraits::AccPacket SAccPacket;
986 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
987 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
988 typedef typename HalfTraits::ResPacket ResPacketHalf;
989 typedef typename HalfTraits::AccPacket AccPacketHalf;
991 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
992 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
993 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
994 typedef typename QuarterTraits::AccPacket AccPacketQuarter;
996 typedef typename DataMapper::LinearMapper LinearMapper;
999 Vectorizable = Traits::Vectorizable,
1000 LhsProgress = Traits::LhsProgress,
1001 LhsProgressHalf = HalfTraits::LhsProgress,
1002 LhsProgressQuarter = QuarterTraits::LhsProgress,
1003 RhsProgress = Traits::RhsProgress,
1004 RhsProgressHalf = HalfTraits::RhsProgress,
1005 RhsProgressQuarter = QuarterTraits::RhsProgress,
1006 ResPacketSize = Traits::ResPacketSize
1009 EIGEN_DONT_INLINE
void operator()(
const DataMapper& res,
const LhsScalar* blockA,
const RhsScalar* blockB, Index rows,
1010 Index depth, Index cols, ResScalar alpha, Index strideA = -1, Index strideB = -1,
1011 Index offsetA = 0, Index offsetB = 0);
1014template <
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
1015 bool ConjugateLhs,
bool ConjugateRhs,
1016 int SwappedLhsProgress =
1017 gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target>::LhsProgress>
1018struct last_row_process_16_packets {
1019 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
1020 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
1022 typedef typename Traits::ResScalar ResScalar;
1023 typedef typename SwappedTraits::LhsPacket SLhsPacket;
1024 typedef typename SwappedTraits::RhsPacket SRhsPacket;
1025 typedef typename SwappedTraits::ResPacket SResPacket;
1026 typedef typename SwappedTraits::AccPacket SAccPacket;
1028 EIGEN_STRONG_INLINE
void operator()(
const DataMapper& res, SwappedTraits& straits,
const LhsScalar* blA,
1029 const RhsScalar* blB, Index depth,
const Index endk, Index i, Index j2,
1030 ResScalar alpha, SAccPacket& C0) {
1031 EIGEN_UNUSED_VARIABLE(res);
1032 EIGEN_UNUSED_VARIABLE(straits);
1033 EIGEN_UNUSED_VARIABLE(blA);
1034 EIGEN_UNUSED_VARIABLE(blB);
1035 EIGEN_UNUSED_VARIABLE(depth);
1036 EIGEN_UNUSED_VARIABLE(endk);
1037 EIGEN_UNUSED_VARIABLE(i);
1038 EIGEN_UNUSED_VARIABLE(j2);
1039 EIGEN_UNUSED_VARIABLE(alpha);
1040 EIGEN_UNUSED_VARIABLE(C0);
1044template <
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
1045 bool ConjugateLhs,
bool ConjugateRhs>
1046struct last_row_process_16_packets<LhsScalar, RhsScalar,
Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
1047 typedef gebp_traits<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs, Architecture::Target> Traits;
1048 typedef gebp_traits<RhsScalar, LhsScalar, ConjugateRhs, ConjugateLhs, Architecture::Target> SwappedTraits;
1050 typedef typename Traits::ResScalar ResScalar;
1051 typedef typename SwappedTraits::LhsPacket SLhsPacket;
1052 typedef typename SwappedTraits::RhsPacket SRhsPacket;
1053 typedef typename SwappedTraits::ResPacket SResPacket;
1054 typedef typename SwappedTraits::AccPacket SAccPacket;
1056 EIGEN_STRONG_INLINE
void operator()(
const DataMapper& res, SwappedTraits& straits,
const LhsScalar* blA,
1057 const RhsScalar* blB, Index depth,
const Index endk, Index i, Index j2,
1058 ResScalar alpha, SAccPacket& C0) {
1059 typedef typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half SResPacketQuarter;
1060 typedef typename unpacket_traits<typename unpacket_traits<SLhsPacket>::half>::half SLhsPacketQuarter;
1061 typedef typename unpacket_traits<typename unpacket_traits<SRhsPacket>::half>::half SRhsPacketQuarter;
1062 typedef typename unpacket_traits<typename unpacket_traits<SAccPacket>::half>::half SAccPacketQuarter;
1064 SResPacketQuarter R = res.template gatherPacket<SResPacketQuarter>(i, j2);
1065 SResPacketQuarter alphav = pset1<SResPacketQuarter>(alpha);
1067 if (depth - endk > 0) {
1070 SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0));
1072 for (Index kk = endk; kk < depth; kk++) {
1073 SLhsPacketQuarter a0;
1074 SRhsPacketQuarter b0;
1075 straits.loadLhsUnaligned(blB, a0);
1076 straits.loadRhs(blA, b0);
1077 straits.madd(a0, b0, c0, b0, fix<0>);
1078 blB += SwappedTraits::LhsProgress / 4;
1081 straits.acc(c0, alphav, R);
1083 straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R);
1085 res.scatterPacket(i, j2, R);
1089template <
int nr,
Index LhsProgress,
Index RhsProgress,
typename LhsScalar,
typename RhsScalar,
typename ResScalar,
1090 typename AccPacket,
typename LhsPacket,
typename RhsPacket,
typename ResPacket,
typename GEBPTraits,
1091 typename LinearMapper,
typename DataMapper>
1092struct lhs_process_one_packet {
1093 typedef typename GEBPTraits::RhsPacketx4 RhsPacketx4;
1095 EIGEN_STRONG_INLINE
void peeled_kc_onestep(Index K,
const LhsScalar* blA,
const RhsScalar* blB, GEBPTraits traits,
1096 LhsPacket* A0, RhsPacketx4* rhs_panel, RhsPacket* T0, AccPacket* C0,
1097 AccPacket* C1, AccPacket* C2, AccPacket* C3) {
1098 EIGEN_ASM_COMMENT(
"begin step of gebp micro kernel 1X4");
1099 EIGEN_ASM_COMMENT(
"Note: these asm comments work around bug 935!");
1100 traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], *A0);
1101 traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], *rhs_panel);
1102 traits.madd(*A0, *rhs_panel, *C0, *T0, fix<0>);
1103 traits.madd(*A0, *rhs_panel, *C1, *T0, fix<1>);
1104 traits.madd(*A0, *rhs_panel, *C2, *T0, fix<2>);
1105 traits.madd(*A0, *rhs_panel, *C3, *T0, fix<3>);
1106#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE) && !(EIGEN_COMP_LCC)
1107 __asm__(
"" :
"+x,m"(*A0));
1109 EIGEN_ASM_COMMENT(
"end step of gebp micro kernel 1X4");
1112 EIGEN_STRONG_INLINE
void operator()(
const DataMapper& res,
const LhsScalar* blockA,
const RhsScalar* blockB,
1113 ResScalar alpha, Index peelStart, Index peelEnd, Index strideA, Index strideB,
1114 Index offsetA, Index offsetB,
int prefetch_res_offset, Index peeled_kc, Index pk,
1115 Index cols, Index depth, Index packet_cols4) {
1117 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1120 for (Index i = peelStart; i < peelEnd; i += LhsProgress) {
1121#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1122 EIGEN_IF_CONSTEXPR(nr >= 8) {
1123 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1124 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1128 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
1138 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1139 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1140 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1141 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1142 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1143 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1144 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1145 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1146 r0.prefetch(prefetch_res_offset);
1147 r1.prefetch(prefetch_res_offset);
1148 r2.prefetch(prefetch_res_offset);
1149 r3.prefetch(prefetch_res_offset);
1150 r4.prefetch(prefetch_res_offset);
1151 r5.prefetch(prefetch_res_offset);
1152 r6.prefetch(prefetch_res_offset);
1153 r7.prefetch(prefetch_res_offset);
1154 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1158 for (Index k = 0; k < peeled_kc; k += pk) {
1159 RhsPacketx4 rhs_panel;
1161#define EIGEN_GEBGP_ONESTEP(K) \
1163 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX8"); \
1164 traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], A0); \
1165 traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
1166 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1167 traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
1168 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1169 traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
1170 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1171 traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
1172 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1173 traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
1174 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
1175 traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
1176 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
1177 traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
1178 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
1179 traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
1180 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
1181 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \
1184 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 1pX8");
1186 EIGEN_GEBGP_ONESTEP(0);
1187 EIGEN_GEBGP_ONESTEP(1);
1188 EIGEN_GEBGP_ONESTEP(2);
1189 EIGEN_GEBGP_ONESTEP(3);
1190 EIGEN_GEBGP_ONESTEP(4);
1191 EIGEN_GEBGP_ONESTEP(5);
1192 EIGEN_GEBGP_ONESTEP(6);
1193 EIGEN_GEBGP_ONESTEP(7);
1195 blB += pk * 8 * RhsProgress;
1196 blA += pk * (1 * LhsProgress);
1198 EIGEN_ASM_COMMENT(
"end gebp micro kernel 1pX8");
1201 for (Index k = peeled_kc; k < depth; k++) {
1202 RhsPacketx4 rhs_panel;
1204 EIGEN_GEBGP_ONESTEP(0);
1205 blB += 8 * RhsProgress;
1206 blA += 1 * LhsProgress;
1209#undef EIGEN_GEBGP_ONESTEP
1212 ResPacket alphav = pset1<ResPacket>(alpha);
1214 R0 = r0.template loadPacket<ResPacket>(0);
1215 R1 = r1.template loadPacket<ResPacket>(0);
1216 traits.acc(C0, alphav, R0);
1217 traits.acc(C1, alphav, R1);
1218 r0.storePacket(0, R0);
1219 r1.storePacket(0, R1);
1221 R0 = r2.template loadPacket<ResPacket>(0);
1222 R1 = r3.template loadPacket<ResPacket>(0);
1223 traits.acc(C2, alphav, R0);
1224 traits.acc(C3, alphav, R1);
1225 r2.storePacket(0, R0);
1226 r3.storePacket(0, R1);
1228 R0 = r4.template loadPacket<ResPacket>(0);
1229 R1 = r5.template loadPacket<ResPacket>(0);
1230 traits.acc(C4, alphav, R0);
1231 traits.acc(C5, alphav, R1);
1232 r4.storePacket(0, R0);
1233 r5.storePacket(0, R1);
1235 R0 = r6.template loadPacket<ResPacket>(0);
1236 R1 = r7.template loadPacket<ResPacket>(0);
1237 traits.acc(C6, alphav, R0);
1238 traits.acc(C7, alphav, R1);
1239 r6.storePacket(0, R0);
1240 r7.storePacket(0, R1);
1246 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1250 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1254 AccPacket C0, C1, C2, C3;
1264 AccPacket D0, D1, D2, D3;
1270 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1271 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1272 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1273 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1275 r0.prefetch(prefetch_res_offset);
1276 r1.prefetch(prefetch_res_offset);
1277 r2.prefetch(prefetch_res_offset);
1278 r3.prefetch(prefetch_res_offset);
1281 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
1285 for (Index k = 0; k < peeled_kc; k += pk) {
1286 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 1/half/quarterX4");
1287 RhsPacketx4 rhs_panel;
1290 internal::prefetch(blB + (48 + 0));
1291 peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1292 peeled_kc_onestep(1, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1293 peeled_kc_onestep(2, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1294 peeled_kc_onestep(3, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1295 internal::prefetch(blB + (48 + 16));
1296 peeled_kc_onestep(4, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1297 peeled_kc_onestep(5, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1298 peeled_kc_onestep(6, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1299 peeled_kc_onestep(7, blA, blB, traits, &A1, &rhs_panel, &T0, &D0, &D1, &D2, &D3);
1301 blB += pk * 4 * RhsProgress;
1302 blA += pk * LhsProgress;
1304 EIGEN_ASM_COMMENT(
"end gebp micro kernel 1/half/quarterX4");
1312 for (Index k = peeled_kc; k < depth; k++) {
1313 RhsPacketx4 rhs_panel;
1315 peeled_kc_onestep(0, blA, blB, traits, &A0, &rhs_panel, &T0, &C0, &C1, &C2, &C3);
1316 blB += 4 * RhsProgress;
1321 ResPacket alphav = pset1<ResPacket>(alpha);
1323 R0 = r0.template loadPacket<ResPacket>(0);
1324 R1 = r1.template loadPacket<ResPacket>(0);
1325 traits.acc(C0, alphav, R0);
1326 traits.acc(C1, alphav, R1);
1327 r0.storePacket(0, R0);
1328 r1.storePacket(0, R1);
1330 R0 = r2.template loadPacket<ResPacket>(0);
1331 R1 = r3.template loadPacket<ResPacket>(0);
1332 traits.acc(C2, alphav, R0);
1333 traits.acc(C3, alphav, R1);
1334 r2.storePacket(0, R0);
1335 r3.storePacket(0, R1);
1339 for (Index j2 = packet_cols4; j2 < cols; j2++) {
1341 const LhsScalar* blA = &blockA[i * strideA + offsetA * (LhsProgress)];
1348 LinearMapper r0 = res.getLinearMapper(i, j2);
1351 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
1354 for (Index k = 0; k < peeled_kc; k += pk) {
1355 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 1/half/quarterX1");
1358#define EIGEN_GEBGP_ONESTEP(K) \
1360 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1/half/quarterX1"); \
1361 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1363 traits.loadLhsUnaligned(&blA[(0 + 1 * K) * LhsProgress], A0); \
1364 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
1365 traits.madd(A0, B_0, C0, B_0, fix<0>); \
1366 EIGEN_ASM_COMMENT("end step of gebp micro kernel 1/half/quarterX1"); \
1369 EIGEN_GEBGP_ONESTEP(0);
1370 EIGEN_GEBGP_ONESTEP(1);
1371 EIGEN_GEBGP_ONESTEP(2);
1372 EIGEN_GEBGP_ONESTEP(3);
1373 EIGEN_GEBGP_ONESTEP(4);
1374 EIGEN_GEBGP_ONESTEP(5);
1375 EIGEN_GEBGP_ONESTEP(6);
1376 EIGEN_GEBGP_ONESTEP(7);
1378 blB += pk * RhsProgress;
1379 blA += pk * LhsProgress;
1381 EIGEN_ASM_COMMENT(
"end gebp micro kernel 1/half/quarterX1");
1385 for (Index k = peeled_kc; k < depth; k++) {
1387 EIGEN_GEBGP_ONESTEP(0);
1391#undef EIGEN_GEBGP_ONESTEP
1393 ResPacket alphav = pset1<ResPacket>(alpha);
1394 R0 = r0.template loadPacket<ResPacket>(0);
1395 traits.acc(C0, alphav, R0);
1396 r0.storePacket(0, R0);
1402template <
int nr,
Index LhsProgress,
Index RhsProgress,
typename LhsScalar,
typename RhsScalar,
typename ResScalar,
1403 typename AccPacket,
typename LhsPacket,
typename RhsPacket,
typename ResPacket,
typename GEBPTraits,
1404 typename LinearMapper,
typename DataMapper>
1405struct lhs_process_fraction_of_packet
1406 : lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket,
1407 RhsPacket, ResPacket, GEBPTraits, LinearMapper, DataMapper> {
1408 EIGEN_STRONG_INLINE
void peeled_kc_onestep(Index K,
const LhsScalar* blA,
const RhsScalar* blB, GEBPTraits traits,
1409 LhsPacket* A0, RhsPacket* B_0, RhsPacket* B1, RhsPacket* B2, RhsPacket* B3,
1410 AccPacket* C0, AccPacket* C1, AccPacket* C2, AccPacket* C3) {
1411 EIGEN_ASM_COMMENT(
"begin step of gebp micro kernel 1X4");
1412 EIGEN_ASM_COMMENT(
"Note: these asm comments work around bug 935!");
1413 traits.loadLhsUnaligned(&blA[(0 + 1 * K) * (LhsProgress)], *A0);
1414 traits.broadcastRhs(&blB[(0 + 4 * K) * RhsProgress], *B_0, *B1, *B2, *B3);
1415 traits.madd(*A0, *B_0, *C0, *B_0);
1416 traits.madd(*A0, *B1, *C1, *B1);
1417 traits.madd(*A0, *B2, *C2, *B2);
1418 traits.madd(*A0, *B3, *C3, *B3);
1419 EIGEN_ASM_COMMENT(
"end step of gebp micro kernel 1X4");
1423template <
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
1424 bool ConjugateLhs,
bool ConjugateRhs>
1425EIGEN_DONT_INLINE
void gebp_kernel<LhsScalar, RhsScalar,
Index, DataMapper, mr, nr, ConjugateLhs,
1426 ConjugateRhs>::operator()(
const DataMapper& res,
const LhsScalar* blockA,
1427 const RhsScalar* blockB,
Index rows,
Index depth,
1431 SwappedTraits straits;
1433 if (strideA == -1) strideA = depth;
1434 if (strideB == -1) strideB = depth;
1435 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
1436 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
1437 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1438 const Index peeled_mc3 = mr >= 3 * Traits::LhsProgress ? (rows / (3 * LhsProgress)) * (3 * LhsProgress) : 0;
1439 const Index peeled_mc2 =
1440 mr >= 2 * Traits::LhsProgress ? peeled_mc3 + ((rows - peeled_mc3) / (2 * LhsProgress)) * (2 * LhsProgress) : 0;
1441 const Index peeled_mc1 =
1442 mr >= 1 * Traits::LhsProgress ? peeled_mc2 + ((rows - peeled_mc2) / (1 * LhsProgress)) * (1 * LhsProgress) : 0;
1443 const Index peeled_mc_half =
1444 mr >= LhsProgressHalf ? peeled_mc1 + ((rows - peeled_mc1) / (LhsProgressHalf)) * (LhsProgressHalf) : 0;
1445 const Index peeled_mc_quarter =
1446 mr >= LhsProgressQuarter
1447 ? peeled_mc_half + ((rows - peeled_mc_half) / (LhsProgressQuarter)) * (LhsProgressQuarter)
1450 const Index peeled_kc = depth & ~(pk - 1);
1451 const int prefetch_res_offset = 32 /
sizeof(ResScalar);
1457 if (mr >= 3 * Traits::LhsProgress) {
1462 const Index l1 = defaultL1CacheSize;
1466 const Index actual_panel_rows =
1467 (3 * LhsProgress) * std::max<Index>(1, ((l1 -
sizeof(ResScalar) * mr * nr - depth * nr *
sizeof(RhsScalar)) /
1468 (depth *
sizeof(LhsScalar) * 3 * LhsProgress)));
1469 for (
Index i1 = 0; i1 < peeled_mc3; i1 += actual_panel_rows) {
1470 const Index actual_panel_end = (std::min)(i1 + actual_panel_rows, peeled_mc3);
1471#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1472 EIGEN_IF_CONSTEXPR(nr >= 8) {
1473 for (
Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1474 for (
Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1475 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * LhsProgress)];
1478 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11, C12, C13, C14, C15, C16, C17, C18, C19, C20,
1490 traits.initAcc(C10);
1491 traits.initAcc(C11);
1492 traits.initAcc(C12);
1493 traits.initAcc(C13);
1494 traits.initAcc(C14);
1495 traits.initAcc(C15);
1496 traits.initAcc(C16);
1497 traits.initAcc(C17);
1498 traits.initAcc(C18);
1499 traits.initAcc(C19);
1500 traits.initAcc(C20);
1501 traits.initAcc(C21);
1502 traits.initAcc(C22);
1503 traits.initAcc(C23);
1505 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1506 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1507 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1508 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1509 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1510 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1511 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1512 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1524 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1527 for (
Index k = 0; k < peeled_kc; k += pk) {
1528 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 3pX8");
1530 RhsPanel27 rhs_panel;
1533#if EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0)
1537#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND __asm__("" : "+w,m"(A0), "+w,m"(A1), "+w,m"(A2));
1539#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND
1542#define EIGEN_GEBP_ONESTEP(K) \
1544 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX8"); \
1545 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1546 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1547 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1548 EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND traits.loadRhs(blB + (0 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1549 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1550 traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
1551 traits.madd(A2, rhs_panel, C16, T0, fix<0>); \
1552 traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1553 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1554 traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
1555 traits.madd(A2, rhs_panel, C17, T0, fix<1>); \
1556 traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1557 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1558 traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
1559 traits.madd(A2, rhs_panel, C18, T0, fix<2>); \
1560 traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1561 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1562 traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
1563 traits.madd(A2, rhs_panel, C19, T0, fix<3>); \
1564 traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1565 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
1566 traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
1567 traits.madd(A2, rhs_panel, C20, T0, fix<0>); \
1568 traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1569 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
1570 traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
1571 traits.madd(A2, rhs_panel, C21, T0, fix<1>); \
1572 traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1573 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
1574 traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
1575 traits.madd(A2, rhs_panel, C22, T0, fix<2>); \
1576 traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \
1577 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
1578 traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
1579 traits.madd(A2, rhs_panel, C23, T0, fix<3>); \
1580 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \
1583 EIGEN_GEBP_ONESTEP(0);
1584 EIGEN_GEBP_ONESTEP(1);
1585 EIGEN_GEBP_ONESTEP(2);
1586 EIGEN_GEBP_ONESTEP(3);
1587 EIGEN_GEBP_ONESTEP(4);
1588 EIGEN_GEBP_ONESTEP(5);
1589 EIGEN_GEBP_ONESTEP(6);
1590 EIGEN_GEBP_ONESTEP(7);
1592 blB += pk * 8 * RhsProgress;
1593 blA += pk * 3 * Traits::LhsProgress;
1594 EIGEN_ASM_COMMENT(
"end gebp micro kernel 3pX8");
1598 for (
Index k = peeled_kc; k < depth; k++) {
1599 RhsPanel27 rhs_panel;
1602 EIGEN_GEBP_ONESTEP(0);
1603 blB += 8 * RhsProgress;
1604 blA += 3 * Traits::LhsProgress;
1607#undef EIGEN_GEBP_ONESTEP
1609 ResPacket R0, R1, R2;
1610 ResPacket alphav = pset1<ResPacket>(alpha);
1612 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1613 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1614 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1615 traits.acc(C0, alphav, R0);
1616 traits.acc(C8, alphav, R1);
1617 traits.acc(C16, alphav, R2);
1618 r0.storePacket(0 * Traits::ResPacketSize, R0);
1619 r0.storePacket(1 * Traits::ResPacketSize, R1);
1620 r0.storePacket(2 * Traits::ResPacketSize, R2);
1622 R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1623 R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1624 R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1625 traits.acc(C1, alphav, R0);
1626 traits.acc(C9, alphav, R1);
1627 traits.acc(C17, alphav, R2);
1628 r1.storePacket(0 * Traits::ResPacketSize, R0);
1629 r1.storePacket(1 * Traits::ResPacketSize, R1);
1630 r1.storePacket(2 * Traits::ResPacketSize, R2);
1632 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1633 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1634 R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1635 traits.acc(C2, alphav, R0);
1636 traits.acc(C10, alphav, R1);
1637 traits.acc(C18, alphav, R2);
1638 r2.storePacket(0 * Traits::ResPacketSize, R0);
1639 r2.storePacket(1 * Traits::ResPacketSize, R1);
1640 r2.storePacket(2 * Traits::ResPacketSize, R2);
1642 R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1643 R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1644 R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1645 traits.acc(C3, alphav, R0);
1646 traits.acc(C11, alphav, R1);
1647 traits.acc(C19, alphav, R2);
1648 r3.storePacket(0 * Traits::ResPacketSize, R0);
1649 r3.storePacket(1 * Traits::ResPacketSize, R1);
1650 r3.storePacket(2 * Traits::ResPacketSize, R2);
1652 R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1653 R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1654 R2 = r4.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1655 traits.acc(C4, alphav, R0);
1656 traits.acc(C12, alphav, R1);
1657 traits.acc(C20, alphav, R2);
1658 r4.storePacket(0 * Traits::ResPacketSize, R0);
1659 r4.storePacket(1 * Traits::ResPacketSize, R1);
1660 r4.storePacket(2 * Traits::ResPacketSize, R2);
1662 R0 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1663 R1 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1664 R2 = r5.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1665 traits.acc(C5, alphav, R0);
1666 traits.acc(C13, alphav, R1);
1667 traits.acc(C21, alphav, R2);
1668 r5.storePacket(0 * Traits::ResPacketSize, R0);
1669 r5.storePacket(1 * Traits::ResPacketSize, R1);
1670 r5.storePacket(2 * Traits::ResPacketSize, R2);
1672 R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1673 R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1674 R2 = r6.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1675 traits.acc(C6, alphav, R0);
1676 traits.acc(C14, alphav, R1);
1677 traits.acc(C22, alphav, R2);
1678 r6.storePacket(0 * Traits::ResPacketSize, R0);
1679 r6.storePacket(1 * Traits::ResPacketSize, R1);
1680 r6.storePacket(2 * Traits::ResPacketSize, R2);
1682 R0 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1683 R1 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1684 R2 = r7.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1685 traits.acc(C7, alphav, R0);
1686 traits.acc(C15, alphav, R1);
1687 traits.acc(C23, alphav, R2);
1688 r7.storePacket(0 * Traits::ResPacketSize, R0);
1689 r7.storePacket(1 * Traits::ResPacketSize, R1);
1690 r7.storePacket(2 * Traits::ResPacketSize, R2);
1695 for (
Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1696 for (
Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1700 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * LhsProgress)];
1704 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11;
1715 traits.initAcc(C10);
1716 traits.initAcc(C11);
1718 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1719 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1720 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1721 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1729 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
1733 for (
Index k = 0; k < peeled_kc; k += pk) {
1734 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 3pX4");
1736 RhsPanel15 rhs_panel;
1739#if EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && EIGEN_GNUC_STRICT_LESS_THAN(9, 0, 0)
1743#define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND __asm__("" : "+w,m"(A0), "+w,m"(A1), "+w,m"(A2));
1745#define EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND
1747#define EIGEN_GEBP_ONESTEP(K) \
1749 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX4"); \
1750 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1751 internal::prefetch(blA + (3 * K + 16) * LhsProgress); \
1752 if (EIGEN_ARCH_ARM || EIGEN_ARCH_MIPS) { \
1753 internal::prefetch(blB + (4 * K + 16) * RhsProgress); \
1755 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1756 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1757 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1758 EIGEN_GEBP_3PX4_REGISTER_ALLOC_WORKAROUND \
1759 traits.loadRhs(blB + (0 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1760 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
1761 traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
1762 traits.madd(A2, rhs_panel, C8, T0, fix<0>); \
1763 traits.updateRhs(blB + (1 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1764 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
1765 traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
1766 traits.madd(A2, rhs_panel, C9, T0, fix<1>); \
1767 traits.updateRhs(blB + (2 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1768 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
1769 traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
1770 traits.madd(A2, rhs_panel, C10, T0, fix<2>); \
1771 traits.updateRhs(blB + (3 + 4 * K) * Traits::RhsProgress, rhs_panel); \
1772 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
1773 traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
1774 traits.madd(A2, rhs_panel, C11, T0, fix<3>); \
1775 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX4"); \
1778 internal::prefetch(blB);
1779 EIGEN_GEBP_ONESTEP(0);
1780 EIGEN_GEBP_ONESTEP(1);
1781 EIGEN_GEBP_ONESTEP(2);
1782 EIGEN_GEBP_ONESTEP(3);
1783 EIGEN_GEBP_ONESTEP(4);
1784 EIGEN_GEBP_ONESTEP(5);
1785 EIGEN_GEBP_ONESTEP(6);
1786 EIGEN_GEBP_ONESTEP(7);
1788 blB += pk * 4 * RhsProgress;
1789 blA += pk * 3 * Traits::LhsProgress;
1791 EIGEN_ASM_COMMENT(
"end gebp micro kernel 3pX4");
1794 for (
Index k = peeled_kc; k < depth; k++) {
1795 RhsPanel15 rhs_panel;
1798 EIGEN_GEBP_ONESTEP(0);
1799 blB += 4 * RhsProgress;
1800 blA += 3 * Traits::LhsProgress;
1803#undef EIGEN_GEBP_ONESTEP
1805 ResPacket R0, R1, R2;
1806 ResPacket alphav = pset1<ResPacket>(alpha);
1808 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1809 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1810 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1811 traits.acc(C0, alphav, R0);
1812 traits.acc(C4, alphav, R1);
1813 traits.acc(C8, alphav, R2);
1814 r0.storePacket(0 * Traits::ResPacketSize, R0);
1815 r0.storePacket(1 * Traits::ResPacketSize, R1);
1816 r0.storePacket(2 * Traits::ResPacketSize, R2);
1818 R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1819 R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1820 R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1821 traits.acc(C1, alphav, R0);
1822 traits.acc(C5, alphav, R1);
1823 traits.acc(C9, alphav, R2);
1824 r1.storePacket(0 * Traits::ResPacketSize, R0);
1825 r1.storePacket(1 * Traits::ResPacketSize, R1);
1826 r1.storePacket(2 * Traits::ResPacketSize, R2);
1828 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1829 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1830 R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1831 traits.acc(C2, alphav, R0);
1832 traits.acc(C6, alphav, R1);
1833 traits.acc(C10, alphav, R2);
1834 r2.storePacket(0 * Traits::ResPacketSize, R0);
1835 r2.storePacket(1 * Traits::ResPacketSize, R1);
1836 r2.storePacket(2 * Traits::ResPacketSize, R2);
1838 R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1839 R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1840 R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1841 traits.acc(C3, alphav, R0);
1842 traits.acc(C7, alphav, R1);
1843 traits.acc(C11, alphav, R2);
1844 r3.storePacket(0 * Traits::ResPacketSize, R0);
1845 r3.storePacket(1 * Traits::ResPacketSize, R1);
1846 r3.storePacket(2 * Traits::ResPacketSize, R2);
1851 for (
Index j2 = packet_cols4; j2 < cols; j2++) {
1852 for (
Index i = i1; i < actual_panel_end; i += 3 * LhsProgress) {
1854 const LhsScalar* blA = &blockA[i * strideA + offsetA * (3 * Traits::LhsProgress)];
1858 AccPacket C0, C4, C8;
1863 LinearMapper r0 = res.getLinearMapper(i, j2);
1867 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
1868 LhsPacket A0, A1, A2;
1870 for (
Index k = 0; k < peeled_kc; k += pk) {
1871 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 3pX1");
1873#define EIGEN_GEBGP_ONESTEP(K) \
1875 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX1"); \
1876 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
1877 traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
1878 traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
1879 traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
1880 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
1881 traits.madd(A0, B_0, C0, B_0, fix<0>); \
1882 traits.madd(A1, B_0, C4, B_0, fix<0>); \
1883 traits.madd(A2, B_0, C8, B_0, fix<0>); \
1884 EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX1"); \
1887 EIGEN_GEBGP_ONESTEP(0);
1888 EIGEN_GEBGP_ONESTEP(1);
1889 EIGEN_GEBGP_ONESTEP(2);
1890 EIGEN_GEBGP_ONESTEP(3);
1891 EIGEN_GEBGP_ONESTEP(4);
1892 EIGEN_GEBGP_ONESTEP(5);
1893 EIGEN_GEBGP_ONESTEP(6);
1894 EIGEN_GEBGP_ONESTEP(7);
1896 blB += int(pk) * int(RhsProgress);
1897 blA += int(pk) * 3 * int(Traits::LhsProgress);
1899 EIGEN_ASM_COMMENT(
"end gebp micro kernel 3pX1");
1903 for (
Index k = peeled_kc; k < depth; k++) {
1905 EIGEN_GEBGP_ONESTEP(0);
1907 blA += 3 * Traits::LhsProgress;
1909#undef EIGEN_GEBGP_ONESTEP
1910 ResPacket R0, R1, R2;
1911 ResPacket alphav = pset1<ResPacket>(alpha);
1913 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
1914 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
1915 R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
1916 traits.acc(C0, alphav, R0);
1917 traits.acc(C4, alphav, R1);
1918 traits.acc(C8, alphav, R2);
1919 r0.storePacket(0 * Traits::ResPacketSize, R0);
1920 r0.storePacket(1 * Traits::ResPacketSize, R1);
1921 r0.storePacket(2 * Traits::ResPacketSize, R2);
1928 if (mr >= 2 * Traits::LhsProgress) {
1929 const Index l1 = defaultL1CacheSize;
1933 Index actual_panel_rows =
1934 (2 * LhsProgress) * std::max<Index>(1, ((l1 -
sizeof(ResScalar) * mr * nr - depth * nr *
sizeof(RhsScalar)) /
1935 (depth *
sizeof(LhsScalar) * 2 * LhsProgress)));
1937 for (
Index i1 = peeled_mc3; i1 < peeled_mc2; i1 += actual_panel_rows) {
1938 Index actual_panel_end = (std::min)(i1 + actual_panel_rows, peeled_mc2);
1939#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
1940 EIGEN_IF_CONSTEXPR(nr >= 8) {
1941 for (
Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1942 for (
Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
1943 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
1946 AccPacket C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11, C12, C13, C14, C15;
1957 traits.initAcc(C10);
1958 traits.initAcc(C11);
1959 traits.initAcc(C12);
1960 traits.initAcc(C13);
1961 traits.initAcc(C14);
1962 traits.initAcc(C15);
1964 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
1965 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
1966 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
1967 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
1968 LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
1969 LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
1970 LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
1971 LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
1972 r0.prefetch(prefetch_res_offset);
1973 r1.prefetch(prefetch_res_offset);
1974 r2.prefetch(prefetch_res_offset);
1975 r3.prefetch(prefetch_res_offset);
1976 r4.prefetch(prefetch_res_offset);
1977 r5.prefetch(prefetch_res_offset);
1978 r6.prefetch(prefetch_res_offset);
1979 r7.prefetch(prefetch_res_offset);
1981 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
1984 for (
Index k = 0; k < peeled_kc; k += pk) {
1985 RhsPacketx4 rhs_panel;
1989#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE)
1990#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__("" : [a0] "+x,m"(A0), [a1] "+x,m"(A1));
1992#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND
1994#define EIGEN_GEBGP_ONESTEP(K) \
1996 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \
1997 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
1998 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
1999 traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
2000 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
2001 traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
2002 traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
2003 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
2004 traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
2005 traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
2006 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
2007 traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
2008 traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
2009 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
2010 traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
2011 traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
2012 traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
2013 traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
2014 traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
2015 traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
2016 traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
2017 traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
2018 traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
2019 traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
2020 traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
2021 traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
2022 traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
2023 EIGEN_GEBP_2Px8_SPILLING_WORKAROUND EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \
2026 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 2pX8");
2028 EIGEN_GEBGP_ONESTEP(0);
2029 EIGEN_GEBGP_ONESTEP(1);
2030 EIGEN_GEBGP_ONESTEP(2);
2031 EIGEN_GEBGP_ONESTEP(3);
2032 EIGEN_GEBGP_ONESTEP(4);
2033 EIGEN_GEBGP_ONESTEP(5);
2034 EIGEN_GEBGP_ONESTEP(6);
2035 EIGEN_GEBGP_ONESTEP(7);
2037 blB += pk * 8 * RhsProgress;
2038 blA += pk * (2 * Traits::LhsProgress);
2040 EIGEN_ASM_COMMENT(
"end gebp micro kernel 2pX8");
2043 for (
Index k = peeled_kc; k < depth; k++) {
2044 RhsPacketx4 rhs_panel;
2046 EIGEN_GEBGP_ONESTEP(0);
2047 blB += 8 * RhsProgress;
2048 blA += 2 * Traits::LhsProgress;
2051#undef EIGEN_GEBGP_ONESTEP
2053 ResPacket R0, R1, R2, R3;
2054 ResPacket alphav = pset1<ResPacket>(alpha);
2056 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2057 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2058 R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2059 R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2060 traits.acc(C0, alphav, R0);
2061 traits.acc(C8, alphav, R1);
2062 traits.acc(C1, alphav, R2);
2063 traits.acc(C9, alphav, R3);
2064 r0.storePacket(0 * Traits::ResPacketSize, R0);
2065 r0.storePacket(1 * Traits::ResPacketSize, R1);
2066 r1.storePacket(0 * Traits::ResPacketSize, R2);
2067 r1.storePacket(1 * Traits::ResPacketSize, R3);
2069 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2070 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2071 R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2072 R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2073 traits.acc(C2, alphav, R0);
2074 traits.acc(C10, alphav, R1);
2075 traits.acc(C3, alphav, R2);
2076 traits.acc(C11, alphav, R3);
2077 r2.storePacket(0 * Traits::ResPacketSize, R0);
2078 r2.storePacket(1 * Traits::ResPacketSize, R1);
2079 r3.storePacket(0 * Traits::ResPacketSize, R2);
2080 r3.storePacket(1 * Traits::ResPacketSize, R3);
2082 R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2083 R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2084 R2 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2085 R3 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2086 traits.acc(C4, alphav, R0);
2087 traits.acc(C12, alphav, R1);
2088 traits.acc(C5, alphav, R2);
2089 traits.acc(C13, alphav, R3);
2090 r4.storePacket(0 * Traits::ResPacketSize, R0);
2091 r4.storePacket(1 * Traits::ResPacketSize, R1);
2092 r5.storePacket(0 * Traits::ResPacketSize, R2);
2093 r5.storePacket(1 * Traits::ResPacketSize, R3);
2095 R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2096 R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2097 R2 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2098 R3 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2099 traits.acc(C6, alphav, R0);
2100 traits.acc(C14, alphav, R1);
2101 traits.acc(C7, alphav, R2);
2102 traits.acc(C15, alphav, R3);
2103 r6.storePacket(0 * Traits::ResPacketSize, R0);
2104 r6.storePacket(1 * Traits::ResPacketSize, R1);
2105 r7.storePacket(0 * Traits::ResPacketSize, R2);
2106 r7.storePacket(1 * Traits::ResPacketSize, R3);
2111 for (
Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2112 for (
Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
2116 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
2120 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
2130 LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
2131 LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
2132 LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
2133 LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
2135 r0.prefetch(prefetch_res_offset);
2136 r1.prefetch(prefetch_res_offset);
2137 r2.prefetch(prefetch_res_offset);
2138 r3.prefetch(prefetch_res_offset);
2141 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
2145 for (
Index k = 0; k < peeled_kc; k += pk) {
2146 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 2pX4");
2147 RhsPacketx4 rhs_panel;
2152#if EIGEN_GNUC_STRICT_AT_LEAST(6, 0, 0) && defined(EIGEN_VECTORIZE_SSE) && !(EIGEN_COMP_LCC)
2153#define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND __asm__("" : [a0] "+x,m"(A0), [a1] "+x,m"(A1));
2155#define EIGEN_GEBP_2PX4_SPILLING_WORKAROUND
2157#define EIGEN_GEBGP_ONESTEP(K) \
2159 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX4"); \
2160 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
2161 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
2162 traits.loadRhs(&blB[(0 + 4 * K) * RhsProgress], rhs_panel); \
2163 traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
2164 traits.madd(A1, rhs_panel, C4, T0, fix<0>); \
2165 traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
2166 traits.madd(A1, rhs_panel, C5, T0, fix<1>); \
2167 traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
2168 traits.madd(A1, rhs_panel, C6, T0, fix<2>); \
2169 traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
2170 traits.madd(A1, rhs_panel, C7, T0, fix<3>); \
2171 EIGEN_GEBP_2PX4_SPILLING_WORKAROUND \
2172 EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX4"); \
2175 internal::prefetch(blB + (48 + 0));
2176 EIGEN_GEBGP_ONESTEP(0);
2177 EIGEN_GEBGP_ONESTEP(1);
2178 EIGEN_GEBGP_ONESTEP(2);
2179 EIGEN_GEBGP_ONESTEP(3);
2180 internal::prefetch(blB + (48 + 16));
2181 EIGEN_GEBGP_ONESTEP(4);
2182 EIGEN_GEBGP_ONESTEP(5);
2183 EIGEN_GEBGP_ONESTEP(6);
2184 EIGEN_GEBGP_ONESTEP(7);
2186 blB += pk * 4 * RhsProgress;
2187 blA += pk * (2 * Traits::LhsProgress);
2189 EIGEN_ASM_COMMENT(
"end gebp micro kernel 2pX4");
2192 for (
Index k = peeled_kc; k < depth; k++) {
2193 RhsPacketx4 rhs_panel;
2195 EIGEN_GEBGP_ONESTEP(0);
2196 blB += 4 * RhsProgress;
2197 blA += 2 * Traits::LhsProgress;
2199#undef EIGEN_GEBGP_ONESTEP
2201 ResPacket R0, R1, R2, R3;
2202 ResPacket alphav = pset1<ResPacket>(alpha);
2204 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2205 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2206 R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2207 R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2208 traits.acc(C0, alphav, R0);
2209 traits.acc(C4, alphav, R1);
2210 traits.acc(C1, alphav, R2);
2211 traits.acc(C5, alphav, R3);
2212 r0.storePacket(0 * Traits::ResPacketSize, R0);
2213 r0.storePacket(1 * Traits::ResPacketSize, R1);
2214 r1.storePacket(0 * Traits::ResPacketSize, R2);
2215 r1.storePacket(1 * Traits::ResPacketSize, R3);
2217 R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2218 R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2219 R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2220 R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2221 traits.acc(C2, alphav, R0);
2222 traits.acc(C6, alphav, R1);
2223 traits.acc(C3, alphav, R2);
2224 traits.acc(C7, alphav, R3);
2225 r2.storePacket(0 * Traits::ResPacketSize, R0);
2226 r2.storePacket(1 * Traits::ResPacketSize, R1);
2227 r3.storePacket(0 * Traits::ResPacketSize, R2);
2228 r3.storePacket(1 * Traits::ResPacketSize, R3);
2233 for (
Index j2 = packet_cols4; j2 < cols; j2++) {
2234 for (
Index i = i1; i < actual_panel_end; i += 2 * LhsProgress) {
2236 const LhsScalar* blA = &blockA[i * strideA + offsetA * (2 * Traits::LhsProgress)];
2244 LinearMapper r0 = res.getLinearMapper(i, j2);
2245 r0.prefetch(prefetch_res_offset);
2248 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
2251 for (
Index k = 0; k < peeled_kc; k += pk) {
2252 EIGEN_ASM_COMMENT(
"begin gebp micro kernel 2pX1");
2255#define EIGEN_GEBGP_ONESTEP(K) \
2257 EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX1"); \
2258 EIGEN_ASM_COMMENT("Note: these asm comments work around bug 935!"); \
2259 traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
2260 traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
2261 traits.loadRhs(&blB[(0 + K) * RhsProgress], B_0); \
2262 traits.madd(A0, B_0, C0, B1, fix<0>); \
2263 traits.madd(A1, B_0, C4, B_0, fix<0>); \
2264 EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX1"); \
2267 EIGEN_GEBGP_ONESTEP(0);
2268 EIGEN_GEBGP_ONESTEP(1);
2269 EIGEN_GEBGP_ONESTEP(2);
2270 EIGEN_GEBGP_ONESTEP(3);
2271 EIGEN_GEBGP_ONESTEP(4);
2272 EIGEN_GEBGP_ONESTEP(5);
2273 EIGEN_GEBGP_ONESTEP(6);
2274 EIGEN_GEBGP_ONESTEP(7);
2276 blB += int(pk) * int(RhsProgress);
2277 blA += int(pk) * 2 * int(Traits::LhsProgress);
2279 EIGEN_ASM_COMMENT(
"end gebp micro kernel 2pX1");
2283 for (
Index k = peeled_kc; k < depth; k++) {
2285 EIGEN_GEBGP_ONESTEP(0);
2287 blA += 2 * Traits::LhsProgress;
2289#undef EIGEN_GEBGP_ONESTEP
2291 ResPacket alphav = pset1<ResPacket>(alpha);
2293 R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
2294 R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
2295 traits.acc(C0, alphav, R0);
2296 traits.acc(C4, alphav, R1);
2297 r0.storePacket(0 * Traits::ResPacketSize, R0);
2298 r0.storePacket(1 * Traits::ResPacketSize, R1);
2304 if (mr >= 1 * Traits::LhsProgress) {
2305 lhs_process_one_packet<nr, LhsProgress, RhsProgress, LhsScalar, RhsScalar, ResScalar, AccPacket, LhsPacket,
2306 RhsPacket, ResPacket, Traits, LinearMapper, DataMapper>
2308 p(res, blockA, blockB, alpha, peeled_mc2, peeled_mc1, strideA, strideB, offsetA, offsetB, prefetch_res_offset,
2309 peeled_kc, pk, cols, depth, packet_cols4);
2312 if ((LhsProgressHalf < LhsProgress) && mr >= LhsProgressHalf) {
2313 lhs_process_fraction_of_packet<nr, LhsProgressHalf, RhsProgressHalf, LhsScalar, RhsScalar, ResScalar, AccPacketHalf,
2314 LhsPacketHalf, RhsPacketHalf, ResPacketHalf, HalfTraits, LinearMapper, DataMapper>
2316 p(res, blockA, blockB, alpha, peeled_mc1, peeled_mc_half, strideA, strideB, offsetA, offsetB, prefetch_res_offset,
2317 peeled_kc, pk, cols, depth, packet_cols4);
2320 if ((LhsProgressQuarter < LhsProgressHalf) && mr >= LhsProgressQuarter) {
2321 lhs_process_fraction_of_packet<nr, LhsProgressQuarter, RhsProgressQuarter, LhsScalar, RhsScalar, ResScalar,
2322 AccPacketQuarter, LhsPacketQuarter, RhsPacketQuarter, ResPacketQuarter,
2323 QuarterTraits, LinearMapper, DataMapper>
2325 p(res, blockA, blockB, alpha, peeled_mc_half, peeled_mc_quarter, strideA, strideB, offsetA, offsetB,
2326 prefetch_res_offset, peeled_kc, pk, cols, depth, packet_cols4);
2329 if (peeled_mc_quarter < rows) {
2330#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
2331 EIGEN_IF_CONSTEXPR(nr >= 8) {
2333 for (
Index j2 = 0; j2 < packet_cols8; j2 += 8) {
2335 for (
Index i = peeled_mc_quarter; i < rows; i += 1) {
2336 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2339 ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
2340 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 8];
2341 for (
Index k = 0; k < depth; k++) {
2342 LhsScalar A0 = blA[k];
2346 C0 = cj.pmadd(A0, B_0, C0);
2349 C1 = cj.pmadd(A0, B_0, C1);
2352 C2 = cj.pmadd(A0, B_0, C2);
2355 C3 = cj.pmadd(A0, B_0, C3);
2358 C4 = cj.pmadd(A0, B_0, C4);
2361 C5 = cj.pmadd(A0, B_0, C5);
2364 C6 = cj.pmadd(A0, B_0, C6);
2367 C7 = cj.pmadd(A0, B_0, C7);
2371 res(i, j2 + 0) += alpha * C0;
2372 res(i, j2 + 1) += alpha * C1;
2373 res(i, j2 + 2) += alpha * C2;
2374 res(i, j2 + 3) += alpha * C3;
2375 res(i, j2 + 4) += alpha * C4;
2376 res(i, j2 + 5) += alpha * C5;
2377 res(i, j2 + 6) += alpha * C6;
2378 res(i, j2 + 7) += alpha * C7;
2384 for (
Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2386 for (
Index i = peeled_mc_quarter; i < rows; i += 1) {
2387 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2389 const RhsScalar* blB = &blockB[j2 * strideB + offsetB * 4];
2394 const int SResPacketHalfSize = unpacket_traits<typename unpacket_traits<SResPacket>::half>::size;
2395 const int SResPacketQuarterSize =
2396 unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
2401 constexpr bool kCanLoadSRhsQuad =
2402 (unpacket_traits<SLhsPacket>::size < 4) ||
2403 (unpacket_traits<SRhsPacket>::size % ((std::max<int>)(unpacket_traits<SLhsPacket>::size, 4) / 4)) == 0;
2404 if (kCanLoadSRhsQuad && (SwappedTraits::LhsProgress % 4) == 0 && (SwappedTraits::LhsProgress <= 16) &&
2405 (SwappedTraits::LhsProgress != 8 || SResPacketHalfSize == nr) &&
2406 (SwappedTraits::LhsProgress != 16 || SResPacketQuarterSize == nr)) {
2407 SAccPacket C0, C1, C2, C3;
2408 straits.initAcc(C0);
2409 straits.initAcc(C1);
2410 straits.initAcc(C2);
2411 straits.initAcc(C3);
2413 const Index spk = (std::max)(1, SwappedTraits::LhsProgress / 4);
2414 const Index endk = (depth / spk) * spk;
2415 const Index endk4 = (depth / (spk * 4)) * (spk * 4);
2418 for (; k < endk4; k += 4 * spk) {
2420 SRhsPacket B_0, B_1;
2422 straits.loadLhsUnaligned(blB + 0 * SwappedTraits::LhsProgress, A0);
2423 straits.loadLhsUnaligned(blB + 1 * SwappedTraits::LhsProgress, A1);
2425 straits.loadRhsQuad(blA + 0 * spk, B_0);
2426 straits.loadRhsQuad(blA + 1 * spk, B_1);
2427 straits.madd(A0, B_0, C0, B_0,
fix<0>);
2428 straits.madd(A1, B_1, C1, B_1,
fix<0>);
2430 straits.loadLhsUnaligned(blB + 2 * SwappedTraits::LhsProgress, A0);
2431 straits.loadLhsUnaligned(blB + 3 * SwappedTraits::LhsProgress, A1);
2432 straits.loadRhsQuad(blA + 2 * spk, B_0);
2433 straits.loadRhsQuad(blA + 3 * spk, B_1);
2434 straits.madd(A0, B_0, C2, B_0,
fix<0>);
2435 straits.madd(A1, B_1, C3, B_1,
fix<0>);
2437 blB += 4 * SwappedTraits::LhsProgress;
2440 C0 = padd(padd(C0, C1), padd(C2, C3));
2441 for (; k < endk; k += spk) {
2445 straits.loadLhsUnaligned(blB, A0);
2446 straits.loadRhsQuad(blA, B_0);
2447 straits.madd(A0, B_0, C0, B_0,
fix<0>);
2449 blB += SwappedTraits::LhsProgress;
2452 if (SwappedTraits::LhsProgress == 8) {
2454 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8,
typename unpacket_traits<SResPacket>::half,
2457 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8,
typename unpacket_traits<SLhsPacket>::half,
2460 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8,
typename unpacket_traits<SRhsPacket>::half,
2463 typedef std::conditional_t<SwappedTraits::LhsProgress >= 8,
typename unpacket_traits<SAccPacket>::half,
2467 SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
2468 SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
2470 if (depth - endk > 0) {
2474 straits.loadLhsUnaligned(blB, a0);
2475 straits.loadRhs(blA, b0);
2476 SAccPacketHalf c0 = predux_half_dowto4(C0);
2477 straits.madd(a0, b0, c0, b0,
fix<0>);
2478 straits.acc(c0, alphav, R);
2480 straits.acc(predux_half_dowto4(C0), alphav, R);
2482 res.scatterPacket(i, j2, R);
2483 }
else if (SwappedTraits::LhsProgress == 16) {
2488 last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
2489 p(res, straits, blA, blB, depth, endk, i, j2, alpha, C0);
2491 SResPacket R = res.template gatherPacket<SResPacket>(i, j2);
2492 SResPacket alphav = pset1<SResPacket>(alpha);
2493 straits.acc(C0, alphav, R);
2494 res.scatterPacket(i, j2, R);
2499 ResScalar C0(0), C1(0), C2(0), C3(0);
2501 for (
Index k = 0; k < depth; k++) {
2509 C0 = cj.pmadd(A0, B_0, C0);
2510 C1 = cj.pmadd(A0, B_1, C1);
2514 C2 = cj.pmadd(A0, B_0, C2);
2515 C3 = cj.pmadd(A0, B_1, C3);
2519 res(i, j2 + 0) += alpha * C0;
2520 res(i, j2 + 1) += alpha * C1;
2521 res(i, j2 + 2) += alpha * C2;
2522 res(i, j2 + 3) += alpha * C3;
2527 for (
Index j2 = packet_cols4; j2 < cols; j2++) {
2529 for (
Index i = peeled_mc_quarter; i < rows; i += 1) {
2530 const LhsScalar* blA = &blockA[i * strideA + offsetA];
2534 const RhsScalar* blB = &blockB[j2 * strideB + offsetB];
2535 for (
Index k = 0; k < depth; k++) {
2536 LhsScalar A0 = blA[k];
2537 RhsScalar B_0 = blB[k];
2538 C0 = cj.pmadd(A0, B_0, C0);
2540 res(i, j2) += alpha * C0;
// Partial specialization of gemm_pack_lhs for a column-major LHS.
// Copies a (rows x depth) panel of `lhs` into the contiguous buffer `blockA`,
// laid out so the GEBP micro-kernels can stream it with packet loads.
// NOTE(review): the tail of this declaration (trailing `offset` parameter and
// closing brace) appears truncated in this extraction — verify against the
// original header.
2560template <
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
typename Packet,
bool Conjugate,
2562struct gemm_pack_lhs<Scalar,
Index, DataMapper, Pack1, Pack2, Packet,
ColMajor, Conjugate, PanelMode> {
2563 typedef typename DataMapper::LinearMapper LinearMapper;
// Entry point: stride/offset are only meaningful in PanelMode (see definition below).
2564 EIGEN_DONT_INLINE
void operator()(Scalar* blockA,
const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
// Column-major LHS packing implementation. Rows are peeled in decreasing
// packet granularity — 3*PacketSize, 2*PacketSize, PacketSize, half packet,
// quarter packet, then a Pack2-sized residue and a scalar tail — so that each
// GEBP kernel tier finds its operands contiguous in blockA.
2568template <
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
typename Packet,
bool Conjugate,
2570EIGEN_DONT_INLINE
void gemm_pack_lhs<Scalar,
Index, DataMapper, Pack1, Pack2, Packet,
ColMajor, Conjugate,
2571 PanelMode>::operator()(Scalar* blockA,
const DataMapper& lhs,
Index depth,
2573 typedef typename unpacket_traits<Packet>::half HalfPacket;
2574 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
2576 PacketSize = unpacket_traits<Packet>::size,
2577 HalfPacketSize = unpacket_traits<HalfPacket>::size,
2578 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
// HasHalf/HasQuarter are true only when the reduced packet is strictly
// smaller, i.e. the architecture actually provides a distinct sub-packet.
2579 HasHalf = (int)HalfPacketSize < (
int)PacketSize,
2580 HasQuarter = (int)QuarterPacketSize < (
int)HalfPacketSize
2583 EIGEN_ASM_COMMENT(
"EIGEN PRODUCT PACK LHS");
2584 EIGEN_UNUSED_VARIABLE(stride);
2585 EIGEN_UNUSED_VARIABLE(offset);
// In PanelMode the caller reserves `stride` scalars per row group and the
// packed data starts `offset` scalars in; otherwise both must be zero.
2586 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
2587 eigen_assert(((Pack1 % PacketSize) == 0 && Pack1 <= 4 * PacketSize) || (Pack1 <= 4));
// cj conjugates each scalar when Conjugate is set and Scalar is complex.
2588 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
// Row boundaries of each peeling tier; each tier resumes where the previous
// one stopped (tiers disabled by Pack1 collapse to 0).
2591 const Index peeled_mc3 = Pack1 >= 3 * PacketSize ? (rows / (3 * PacketSize)) * (3 * PacketSize) : 0;
2592 const Index peeled_mc2 =
2593 Pack1 >= 2 * PacketSize ? peeled_mc3 + ((rows - peeled_mc3) / (2 * PacketSize)) * (2 * PacketSize) : 0;
2594 const Index peeled_mc1 =
2595 Pack1 >= 1 * PacketSize ? peeled_mc2 + ((rows - peeled_mc2) / (1 * PacketSize)) * (1 * PacketSize) : 0;
2596 const Index peeled_mc_half =
2597 Pack1 >= HalfPacketSize ? peeled_mc1 + ((rows - peeled_mc1) / (HalfPacketSize)) * (HalfPacketSize) : 0;
2598 const Index peeled_mc_quarter = Pack1 >= QuarterPacketSize ? (rows / (QuarterPacketSize)) * (QuarterPacketSize) : 0;
// Residue rows are handled in even-sized groups (`& ~1`) when Pack2 allows it.
2599 const Index last_lhs_progress = rows > peeled_mc_quarter ? (rows - peeled_mc_quarter) & ~1 : 0;
2600 const Index peeled_mc0 = Pack2 >= PacketSize ? peeled_mc_quarter
2601 : Pack2 > 1 && last_lhs_progress ? (rows / last_lhs_progress) * last_lhs_progress
// Tier 1: three full packets per row step (feeds the 3*LhsProgress kernel).
2607 if (Pack1 >= 3 * PacketSize) {
2608 for (; i < peeled_mc3; i += 3 * PacketSize) {
2609 if (PanelMode) count += (3 * PacketSize) * offset;
2611 for (
Index k = 0; k < depth; k++) {
2613 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2614 B = lhs.template loadPacket<Packet>(i + 1 * PacketSize, k);
2615 C = lhs.template loadPacket<Packet>(i + 2 * PacketSize, k);
// Aligned stores: blockA is packet-aligned in the full-packet tiers.
2616 pstore(blockA + count, cj.pconj(A));
2617 count += PacketSize;
2618 pstore(blockA + count, cj.pconj(B));
2619 count += PacketSize;
2620 pstore(blockA + count, cj.pconj(C));
2621 count += PacketSize;
// Skip the unused remainder of the reserved panel stride.
2623 if (PanelMode) count += (3 * PacketSize) * (stride - offset - depth);
// Tier 2: two full packets per row step.
2627 if (Pack1 >= 2 * PacketSize) {
2628 for (; i < peeled_mc2; i += 2 * PacketSize) {
2629 if (PanelMode) count += (2 * PacketSize) * offset;
2631 for (
Index k = 0; k < depth; k++) {
2633 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2634 B = lhs.template loadPacket<Packet>(i + 1 * PacketSize, k);
2635 pstore(blockA + count, cj.pconj(A));
2636 count += PacketSize;
2637 pstore(blockA + count, cj.pconj(B));
2638 count += PacketSize;
2640 if (PanelMode) count += (2 * PacketSize) * (stride - offset - depth);
// Tier 3: one full packet per row step.
2644 if (Pack1 >= 1 * PacketSize) {
2645 for (; i < peeled_mc1; i += 1 * PacketSize) {
2646 if (PanelMode) count += (1 * PacketSize) * offset;
2648 for (
Index k = 0; k < depth; k++) {
2650 A = lhs.template loadPacket<Packet>(i + 0 * PacketSize, k);
2651 pstore(blockA + count, cj.pconj(A));
2652 count += PacketSize;
2654 if (PanelMode) count += (1 * PacketSize) * (stride - offset - depth);
// Tier 4: half-packet rows. Note the unaligned store (pstoreu): count is no
// longer guaranteed to be a multiple of the full packet size here.
2658 if (HasHalf && Pack1 >= HalfPacketSize) {
2659 for (; i < peeled_mc_half; i += HalfPacketSize) {
2660 if (PanelMode) count += (HalfPacketSize)*offset;
2662 for (
Index k = 0; k < depth; k++) {
2664 A = lhs.template loadPacket<HalfPacket>(i + 0 * (HalfPacketSize), k);
2665 pstoreu(blockA + count, cj.pconj(A));
2666 count += HalfPacketSize;
2668 if (PanelMode) count += (HalfPacketSize) * (stride - offset - depth);
// Tier 5: quarter-packet rows, same unaligned-store caveat as above.
2672 if (HasQuarter && Pack1 >= QuarterPacketSize) {
2673 for (; i < peeled_mc_quarter; i += QuarterPacketSize) {
2674 if (PanelMode) count += (QuarterPacketSize)*offset;
2676 for (
Index k = 0; k < depth; k++) {
2678 A = lhs.template loadPacket<QuarterPacket>(i + 0 * (QuarterPacketSize), k);
2679 pstoreu(blockA + count, cj.pconj(A));
2680 count += QuarterPacketSize;
2682 if (PanelMode) count += (QuarterPacketSize) * (stride - offset - depth);
// Residue tier: even-sized row groups (last_lhs_progress) packed scalar-wise.
2691 if (Pack2 < PacketSize && Pack2 > 1) {
2692 for (; i < peeled_mc0; i += last_lhs_progress) {
2693 if (PanelMode) count += last_lhs_progress * offset;
2695 for (
Index k = 0; k < depth; k++)
2696 for (
Index w = 0; w < last_lhs_progress; w++) blockA[count++] = cj(lhs(i + w, k));
2698 if (PanelMode) count += last_lhs_progress * (stride - offset - depth);
// Final scalar tail: one row at a time.
2702 for (; i < rows; i++) {
2703 if (PanelMode) count += offset;
2704 for (
Index k = 0; k < depth; k++) blockA[count++] = cj(lhs(i, k));
2705 if (PanelMode) count += (stride - offset - depth);
// Partial specialization of gemm_pack_lhs for a row-major LHS. Same contract
// as the column-major variant, but packing requires transposing small
// PacketSize x PacketSize tiles (see the definition below).
// NOTE(review): the tail of this declaration appears truncated in this
// extraction — verify against the original header.
2709template <
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
typename Packet,
bool Conjugate,
2711struct gemm_pack_lhs<Scalar,
Index, DataMapper, Pack1, Pack2, Packet,
RowMajor, Conjugate, PanelMode> {
2712 typedef typename DataMapper::LinearMapper LinearMapper;
// Entry point: stride/offset are only meaningful in PanelMode.
2713 EIGEN_DONT_INLINE
void operator()(Scalar* blockA,
const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
// Row-major LHS packing implementation. Because the LHS is stored row-major
// but blockA must be column-panel ordered, square psize x psize tiles are
// loaded and transposed in registers (ptranspose) before being stored.
// The pack/psize pair shrinks across iterations (full -> half -> quarter ->
// residue), tracked by the gone_half/gone_quarter/gone_last flags.
// NOTE(review): several original lines (e.g. the outer loop header and the
// initialization of `pack`) appear to be missing from this extraction —
// verify against the original header before relying on control flow.
2717template <
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
typename Packet,
bool Conjugate,
2719EIGEN_DONT_INLINE
void gemm_pack_lhs<Scalar,
Index, DataMapper, Pack1, Pack2, Packet,
RowMajor, Conjugate,
2720 PanelMode>::operator()(Scalar* blockA,
const DataMapper& lhs,
Index depth,
2722 typedef typename unpacket_traits<Packet>::half HalfPacket;
2723 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
2725 PacketSize = unpacket_traits<Packet>::size,
2726 HalfPacketSize = unpacket_traits<HalfPacket>::size,
2727 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
// True only when the architecture provides a strictly smaller sub-packet.
2728 HasHalf = (int)HalfPacketSize < (
int)PacketSize,
2729 HasQuarter = (int)QuarterPacketSize < (
int)HalfPacketSize
2732 EIGEN_ASM_COMMENT(
"EIGEN PRODUCT PACK LHS");
2733 EIGEN_UNUSED_VARIABLE(stride);
2734 EIGEN_UNUSED_VARIABLE(offset);
2735 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
// cj conjugates each scalar when Conjugate is set and Scalar is complex.
2736 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
// Flags recording which reduced packet sizes have already been used, so each
// tier is entered at most once.
2738 bool gone_half =
false, gone_quarter =
false, gone_last =
false;
2742 Index psize = PacketSize;
2744 Index remaining_rows = rows - i;
// End of the current tier: after the last (gone_last) switch, peeled_mc is
// recomputed from the full row count; otherwise from the remaining rows.
2745 Index peeled_mc = gone_last ? Pack2 > 1 ? (rows / pack) * pack : 0 : i + (remaining_rows / pack) * pack;
2746 Index starting_pos = i;
2747 for (; i < peeled_mc; i += pack) {
2748 if (PanelMode) count += pack * offset;
// Vectorized path: transpose psize x psize tiles when the current pack is
// at least one (sub-)packet wide.
2751 if (pack >= psize && psize >= QuarterPacketSize) {
2752 const Index peeled_k = (depth / psize) * psize;
2753 for (; k < peeled_k; k += psize) {
2754 for (
Index m = 0; m < pack; m += psize) {
2755 if (psize == PacketSize) {
// Full-packet tile: load psize rows, transpose, store columns.
2756 PacketBlock<Packet> kernel;
2757 for (
Index p = 0; p < psize; ++p) kernel.packet[p] = lhs.template loadPacket<Packet>(i + p + m, k);
2759 for (
Index p = 0; p < psize; ++p) pstore(blockA + count + m + (pack)*p, cj.pconj(kernel.packet[p]));
2760 }
else if (HasHalf && psize == HalfPacketSize) {
2762 PacketBlock<HalfPacket> kernel_half;
2763 for (
Index p = 0; p < psize; ++p)
2764 kernel_half.packet[p] = lhs.template loadPacket<HalfPacket>(i + p + m, k);
2765 ptranspose(kernel_half);
2766 for (
Index p = 0; p < psize; ++p) pstore(blockA + count + m + (pack)*p, cj.pconj(kernel_half.packet[p]));
2767 }
else if (HasQuarter && psize == QuarterPacketSize) {
2768 gone_quarter =
true;
2769 PacketBlock<QuarterPacket> kernel_quarter;
2770 for (
Index p = 0; p < psize; ++p)
2771 kernel_quarter.packet[p] = lhs.template loadPacket<QuarterPacket>(i + p + m, k);
2772 ptranspose(kernel_quarter);
2773 for (
Index p = 0; p < psize; ++p)
2774 pstore(blockA + count + m + (pack)*p, cj.pconj(kernel_quarter.packet[p]));
2777 count += psize * pack;
// Scalar fallback for the depth remainder (k not divisible by psize);
// manually unrolled by 4 across the pack width.
2781 for (; k < depth; k++) {
2783 for (; w < pack - 3; w += 4) {
2784 Scalar a(cj(lhs(i + w + 0, k))), b(cj(lhs(i + w + 1, k))), c(cj(lhs(i + w + 2, k))), d(cj(lhs(i + w + 3, k)));
2785 blockA[count++] = a;
2786 blockA[count++] = b;
2787 blockA[count++] = c;
2788 blockA[count++] = d;
2791 for (; w < pack; ++w) blockA[count++] = cj(lhs(i + w, k));
2794 if (PanelMode) count += pack * (stride - offset - depth);
2798 Index left = rows - i;
// Decide whether to step down to the next smaller packet size for the rows
// that remain (half first, then quarter), each at most once.
2800 if (!gone_last && (starting_pos == i || left >= psize / 2 || left >= psize / 4) &&
2801 ((psize / 2 == HalfPacketSize && HasHalf && !gone_half) ||
2802 (psize / 2 == QuarterPacketSize && HasQuarter && !gone_quarter))) {
// Last resort: shrink to an even residue width when Pack2 permits it.
2813 if (Pack2 < PacketSize && !gone_last) {
2815 psize = pack = left & ~1;
// Final scalar tail: one row at a time.
2820 for (; i < rows; i++) {
2821 if (PanelMode) count += offset;
2822 for (
Index k = 0; k < depth; k++) blockA[count++] = cj(lhs(i, k));
2823 if (PanelMode) count += (stride - offset - depth);
// Partial specialization of gemm_pack_rhs for a column-major RHS.
// Copies a (depth x cols) panel of `rhs` into `blockB`, grouping columns in
// runs of nr (8/4/1) so the GEBP kernels can stream them.
// NOTE(review): the tail of this declaration appears truncated in this
// extraction — verify against the original header.
2834template <
typename Scalar,
typename Index,
typename DataMapper,
int nr,
bool Conjugate,
bool PanelMode>
2835struct gemm_pack_rhs<Scalar,
Index, DataMapper, nr,
ColMajor, Conjugate, PanelMode> {
2836 typedef typename packet_traits<Scalar>::type Packet;
2837 typedef typename DataMapper::LinearMapper LinearMapper;
2838 enum { PacketSize = packet_traits<Scalar>::size };
// Entry point: stride/offset are only meaningful in PanelMode.
2839 EIGEN_DONT_INLINE
void operator()(Scalar* blockB,
const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
2843template <
typename Scalar,
typename Index,
typename DataMapper,
int nr,
bool Conjugate,
bool PanelMode>
2844EIGEN_DONT_INLINE
void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
2846 EIGEN_ASM_COMMENT(
"EIGEN PRODUCT PACK RHS COLMAJOR");
2847 EIGEN_UNUSED_VARIABLE(stride);
2848 EIGEN_UNUSED_VARIABLE(offset);
2849 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
2850 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
2851 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
2852 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
2854 const Index peeled_k = (depth / PacketSize) * PacketSize;
2856#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
2857 EIGEN_IF_CONSTEXPR(nr >= 8) {
2858 for (
Index j2 = 0; j2 < packet_cols8; j2 += 8) {
2860 if (PanelMode) count += 8 * offset;
2861 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
2862 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
2863 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
2864 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
2865 const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
2866 const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
2867 const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
2868 const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
2870 if (PacketSize % 2 == 0 && PacketSize <= 8)
2872 for (; k < peeled_k; k += PacketSize) {
2873 if (PacketSize == 2) {
2874 PacketBlock<Packet, PacketSize == 2 ? 2 : PacketSize> kernel0, kernel1, kernel2, kernel3;
2875 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2876 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2877 kernel1.packet[0 % PacketSize] = dm2.template loadPacket<Packet>(k);
2878 kernel1.packet[1 % PacketSize] = dm3.template loadPacket<Packet>(k);
2879 kernel2.packet[0 % PacketSize] = dm4.template loadPacket<Packet>(k);
2880 kernel2.packet[1 % PacketSize] = dm5.template loadPacket<Packet>(k);
2881 kernel3.packet[0 % PacketSize] = dm6.template loadPacket<Packet>(k);
2882 kernel3.packet[1 % PacketSize] = dm7.template loadPacket<Packet>(k);
2883 ptranspose(kernel0);
2884 ptranspose(kernel1);
2885 ptranspose(kernel2);
2886 ptranspose(kernel3);
2888 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2889 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
2890 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize]));
2891 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize]));
2893 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2894 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
2895 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize]));
2896 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize]));
2897 count += 8 * PacketSize;
2898 }
else if (PacketSize == 4) {
2899 PacketBlock<Packet, PacketSize == 4 ? 4 : PacketSize> kernel0, kernel1;
2901 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2902 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2903 kernel0.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2904 kernel0.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2905 kernel1.packet[0 % PacketSize] = dm4.template loadPacket<Packet>(k);
2906 kernel1.packet[1 % PacketSize] = dm5.template loadPacket<Packet>(k);
2907 kernel1.packet[2 % PacketSize] = dm6.template loadPacket<Packet>(k);
2908 kernel1.packet[3 % PacketSize] = dm7.template loadPacket<Packet>(k);
2909 ptranspose(kernel0);
2910 ptranspose(kernel1);
2912 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2913 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
2914 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2915 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
2916 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[2 % PacketSize]));
2917 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[2 % PacketSize]));
2918 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel0.packet[3 % PacketSize]));
2919 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel1.packet[3 % PacketSize]));
2920 count += 8 * PacketSize;
2921 }
else if (PacketSize == 8) {
2922 PacketBlock<Packet, PacketSize == 8 ? 8 : PacketSize> kernel0;
2924 kernel0.packet[0 % PacketSize] = dm0.template loadPacket<Packet>(k);
2925 kernel0.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2926 kernel0.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2927 kernel0.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2928 kernel0.packet[4 % PacketSize] = dm4.template loadPacket<Packet>(k);
2929 kernel0.packet[5 % PacketSize] = dm5.template loadPacket<Packet>(k);
2930 kernel0.packet[6 % PacketSize] = dm6.template loadPacket<Packet>(k);
2931 kernel0.packet[7 % PacketSize] = dm7.template loadPacket<Packet>(k);
2932 ptranspose(kernel0);
2934 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
2935 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
2936 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel0.packet[2 % PacketSize]));
2937 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel0.packet[3 % PacketSize]));
2938 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[4 % PacketSize]));
2939 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel0.packet[5 % PacketSize]));
2940 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel0.packet[6 % PacketSize]));
2941 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel0.packet[7 % PacketSize]));
2942 count += 8 * PacketSize;
2947 for (; k < depth; k++) {
2948 blockB[count + 0] = cj(dm0(k));
2949 blockB[count + 1] = cj(dm1(k));
2950 blockB[count + 2] = cj(dm2(k));
2951 blockB[count + 3] = cj(dm3(k));
2952 blockB[count + 4] = cj(dm4(k));
2953 blockB[count + 5] = cj(dm5(k));
2954 blockB[count + 6] = cj(dm6(k));
2955 blockB[count + 7] = cj(dm7(k));
2959 if (PanelMode) count += 8 * (stride - offset - depth);
2964 EIGEN_IF_CONSTEXPR(nr >= 4) {
2965 for (
Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
2967 if (PanelMode) count += 4 * offset;
2968 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
2969 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
2970 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
2971 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
2974 if ((PacketSize % 4) == 0)
2976 for (; k < peeled_k; k += PacketSize) {
2977 PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
2978 kernel.packet[0] = dm0.template loadPacket<Packet>(k);
2979 kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
2980 kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
2981 kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
2983 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
2984 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
2985 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
2986 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
2987 count += 4 * PacketSize;
2990 for (; k < depth; k++) {
2991 blockB[count + 0] = cj(dm0(k));
2992 blockB[count + 1] = cj(dm1(k));
2993 blockB[count + 2] = cj(dm2(k));
2994 blockB[count + 3] = cj(dm3(k));
2998 if (PanelMode) count += 4 * (stride - offset - depth);
3003 for (
Index j2 = packet_cols4; j2 < cols; ++j2) {
3004 if (PanelMode) count += offset;
3005 const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
3006 for (
Index k = 0; k < depth; k++) {
3007 blockB[count] = cj(dm0(k));
3010 if (PanelMode) count += (stride - offset - depth);
3015template <
typename Scalar,
typename Index,
typename DataMapper,
int nr,
bool Conjugate,
bool PanelMode>
3016struct gemm_pack_rhs<Scalar,
Index, DataMapper, nr,
RowMajor, Conjugate, PanelMode> {
3017 typedef typename packet_traits<Scalar>::type Packet;
3018 typedef typename unpacket_traits<Packet>::half HalfPacket;
3019 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
3020 typedef typename DataMapper::LinearMapper LinearMapper;
3022 PacketSize = packet_traits<Scalar>::size,
3023 HalfPacketSize = unpacket_traits<HalfPacket>::size,
3024 QuarterPacketSize = unpacket_traits<QuarterPacket>::size
3026 EIGEN_DONT_INLINE
void operator()(Scalar* blockB,
const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3028 EIGEN_ASM_COMMENT(
"EIGEN PRODUCT PACK RHS ROWMAJOR");
3029 EIGEN_UNUSED_VARIABLE(stride);
3030 EIGEN_UNUSED_VARIABLE(offset);
3031 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
3032 const bool HasHalf = (int)HalfPacketSize < (
int)PacketSize;
3033 const bool HasQuarter = (int)QuarterPacketSize < (
int)HalfPacketSize;
3034 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
3035 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
3036 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
3039#if EIGEN_ARCH_ARM64 || EIGEN_ARCH_LOONGARCH64
3040 EIGEN_IF_CONSTEXPR(nr >= 8) {
3041 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
3043 if (PanelMode) count += 8 * offset;
3044 for (Index k = 0; k < depth; k++) {
3045 if (PacketSize == 8) {
3046 Packet A = rhs.template loadPacket<Packet>(k, j2);
3047 pstoreu(blockB + count, cj.pconj(A));
3048 count += PacketSize;
3049 }
else if (PacketSize == 4) {
3050 Packet A = rhs.template loadPacket<Packet>(k, j2);
3051 Packet B = rhs.template loadPacket<Packet>(k, j2 + 4);
3052 pstoreu(blockB + count, cj.pconj(A));
3053 pstoreu(blockB + count + PacketSize, cj.pconj(B));
3054 count += 2 * PacketSize;
3056 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
3057 blockB[count + 0] = cj(dm0(0));
3058 blockB[count + 1] = cj(dm0(1));
3059 blockB[count + 2] = cj(dm0(2));
3060 blockB[count + 3] = cj(dm0(3));
3061 blockB[count + 4] = cj(dm0(4));
3062 blockB[count + 5] = cj(dm0(5));
3063 blockB[count + 6] = cj(dm0(6));
3064 blockB[count + 7] = cj(dm0(7));
3069 if (PanelMode) count += 8 * (stride - offset - depth);
3075 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
3077 if (PanelMode) count += 4 * offset;
3078 for (Index k = 0; k < depth; k++) {
3079 if (PacketSize == 4) {
3080 Packet A = rhs.template loadPacket<Packet>(k, j2);
3081 pstoreu(blockB + count, cj.pconj(A));
3082 count += PacketSize;
3083 }
else if (HasHalf && HalfPacketSize == 4) {
3084 HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
3085 pstoreu(blockB + count, cj.pconj(A));
3086 count += HalfPacketSize;
3087 }
else if (HasQuarter && QuarterPacketSize == 4) {
3088 QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
3089 pstoreu(blockB + count, cj.pconj(A));
3090 count += QuarterPacketSize;
3092 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
3093 blockB[count + 0] = cj(dm0(0));
3094 blockB[count + 1] = cj(dm0(1));
3095 blockB[count + 2] = cj(dm0(2));
3096 blockB[count + 3] = cj(dm0(3));
3101 if (PanelMode) count += 4 * (stride - offset - depth);
3105 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
3106 if (PanelMode) count += offset;
3107 for (Index k = 0; k < depth; k++) {
3108 blockB[count] = cj(rhs(k, j2));
3111 if (PanelMode) count += stride - offset - depth;
3121 std::ptrdiff_t l1, l2, l3;
3122 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3129 std::ptrdiff_t l1, l2, l3;
3130 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3137 std::ptrdiff_t l1, l2, l3;
3138 internal::manage_caching_sizes(GetAction, &l1, &l2, &l3);
3148 internal::manage_caching_sizes(SetAction, &l1, &l2, &l3);
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
std::ptrdiff_t l1CacheSize()
Definition GeneralBlockPanelKernel.h:3120
std::ptrdiff_t l2CacheSize()
Definition GeneralBlockPanelKernel.h:3128
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
std::ptrdiff_t l3CacheSize()
Definition GeneralBlockPanelKernel.h:3136
void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2, std::ptrdiff_t l3)
Definition GeneralBlockPanelKernel.h:3147