#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
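// Compile-time unrolled helpers for the AVX512 triangular-solve (TRSM) kernels:
// indexing and remainder-mask utilities, in-register 8x8 block transposes, and
// unrolled load/store/update routines for B panels, the RHS block, and the GEMM update.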
template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
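// e.g. for a row-major A, idA<true>(2, 3, LDA) == 2 * LDA + 3, while for a
// col-major A, idA<false>(2, 3, LDA) == 2 + 3 * LDA.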
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}
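// remMask<N>(m) builds an m-bit load/store mask for an N-wide packet,
// e.g. remMask<16>(3) == 0x0007 and remMask<8>(5) == 0x1F.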
template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}
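// The Packet16f specialization above keeps the transpose entirely in registers:
// 32-bit unpacks, then 64-bit unpacks (via pd casts), then 128-bit lane swaps
// (_mm512_permutex_pd + masked blends) to exchange the remaining quadrants.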
template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
  // For double, an 8x8 block is exactly one PacketBlock<Packet8d, 8>, so the
  // generic ptranspose covers it.
  ptranspose(kernel);
}
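// Unrolled helpers for storing transposed register blocks into C, with optional
// masked handling of a row remainder (remM).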
template <typename Scalar>
class trans {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
                 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(C_arr + LDC * startN,
                        padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
      }
    }
    else {  // startN >= EIGEN_AVX_MAX_NUM_ROW, only reached in the float case.
      // Reinterpret as __m512 to use _mm512_shuffle_f32x4 and swap the lower/upper halves of the register.
      vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
          zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
      zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
          preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));

      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
      }
    }
    aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t remM_ = 0) {
    aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }
  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW == 8.
    constexpr int64_t zmmStride = unrollN / PacketSize;
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
    r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
    r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
    r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
    r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
    r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
    r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
    r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
    trans8x8blocks(r);
    zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
    zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
    zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
    zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
    zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
    zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
    zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
    zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
  }
};
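// Unrolled helpers for loading, transposing (via trans8x8blocks) and storing
// panels of B, either directly or through a small temporary buffer B_temp (toTemp).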
template <typename Scalar>
class transB {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      EIGEN_IF_CONSTEXPR(remN_ == 0) {
        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
      }
      else ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
    }

    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
    }

    aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }
  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }
  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }
  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW == 8.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = ymm.packet[packetIndexOffset + 0];
    r.packet[1] = ymm.packet[packetIndexOffset + 1];
    r.packet[2] = ymm.packet[packetIndexOffset + 2];
    r.packet[3] = ymm.packet[packetIndexOffset + 3];
    r.packet[4] = ymm.packet[packetIndexOffset + 4];
    r.packet[5] = ymm.packet[packetIndexOffset + 5];
    r.packet[6] = ymm.packet[packetIndexOffset + 6];
    r.packet[7] = ymm.packet[packetIndexOffset + 7];
    trans8x8blocks(r);
    ymm.packet[packetIndexOffset + 0] = r.packet[0];
    ymm.packet[packetIndexOffset + 1] = r.packet[1];
    ymm.packet[packetIndexOffset + 2] = r.packet[2];
    ymm.packet[packetIndexOffset + 3] = r.packet[3];
    ymm.packet[packetIndexOffset + 4] = r.packet[4];
    ymm.packet[packetIndexOffset + 5] = r.packet[5];
    ymm.packet[packetIndexOffset + 6] = r.packet[6];
    ymm.packet[packetIndexOffset + 7] = r.packet[7];
  }
  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;
    EIGEN_IF_CONSTEXPR(unrollN == U3) {
      // Load an LxU3 block of B (col-major), transpose it, and store it back (row-major).
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                                 ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
      transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
      transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 2) {
      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 1) {
      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }
};
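// Unrolled triangular-solve micro-kernel: the RHS block is kept in registers
// (RHSInPacket), updated row by row with broadcast elements of A (AInPacket),
// and divided by the diagonal unless isUnitDiag.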
template <typename Scalar>
class trsm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
    }
    else {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
    }
    aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
    }
    else {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
    }
    aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }
  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

    constexpr int64_t packetIndex = currM * endK + startK;
    RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
    aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    // Subtract the contribution of the previously solved row (currentM - 1), scaled by the broadcast element of A.
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

    EIGEN_IF_CONSTEXPR(startK == endK - 1) {
      // Once all RHS packets of this row are updated, broadcast the next element of the A column.
      EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
        // Broadcast the reciprocal of the diagonal; it is consumed later by divRHSByDiag.
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast the next off-diagonal element of A.
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

    constexpr int64_t currentM = startM;
    // Divide the previous row of the RHS by its diagonal once all of its updates have been applied.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
    trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);

    // Update all RHS packets using the current row of A.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                                AInPacket);

    // For the last row, the final divide happens here.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};
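// Unrolls for the register-blocked GEMM update used by the TRSM kernels; isAdd
// selects between accumulating (pmadd) and subtracting (pnmadd) the A*B product.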
template <typename Scalar, bool isAdd>
class gemm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
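  // Register layout used by the micro-kernel below: zmm.packet[0 .. endM*endN-1]
  // hold the C accumulators, the next numLoad packets hold loads from B, and the
  // following numBCast packets hold broadcast elements of A.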
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    // Either accumulate or subtract the A*B contribution for this (startM, startN, startK).
    EIGEN_IF_CONSTEXPR(isAdd) {
      zmm.packet[startN * endM + startM] =
          pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
    }
    else {
      zmm.packet[startN * endM + startM] =
          pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                 zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
    }
    // Broadcast the next element of A once the last accumulator of the current column is updated.
    EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
      zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
          (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
    }

    // Load the next portion of B at the end of each K-step.
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }
  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }
  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast,
            bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                               rem_);
  }