TrsmUnrolls.inc
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2022 Intel Corporation
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
11#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
12
13template <bool isARowMajor = true>
14EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
15 EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
16 else return i + j * LDA;
17}
18
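// Illustrative sketch (not part of the original source): idA maps a (row, column)
// pair to a flat index for either storage order. For a matrix with leading
// dimension LDA = 4:
//
//   idA<true>(2, 1, 4);   // row-major:    2 * 4 + 1 == 9
//   idA<false>(2, 1, 4);  // column-major: 2 + 1 * 4 == 6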
57namespace unrolls {
58
59template <int64_t N>
60EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
61 EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
62 else EIGEN_IF_CONSTEXPR(N == 8) {
63 return 0xFF >> (8 - m);
64 }
65 else EIGEN_IF_CONSTEXPR(N == 4) {
66 return 0x0F >> (4 - m);
67 }
68 return 0;
69}
70
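// Illustrative sketch (not part of the original source): remMask<N>(m) builds an
// integer bitmask with the low m of N lanes set; the masked ploadu/pstoreu calls
// below use it to handle remainder rows/columns. For example:
//
//   remMask<16>(3);  // 0xFFFF >> 13 == 0x0007, lanes 0..2 active
//   remMask<8>(5);   // 0xFF >> 3    == 0x1F,   lanes 0..4 active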
71template <typename Packet>
72EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
73
74template <>
75EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
76 __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
77 __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
78 __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
79 __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
80 __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
81 __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
82 __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
83 __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
84
85 kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
86 kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
87 kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
88 kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
89 kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
90 kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
91 kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
92 kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
93
94 T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
95 T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
96 T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
97 T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
98 T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
99 T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
100 T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
101 T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
102 T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
103 T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
104 T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
105 T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
106 T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
107 T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
108 T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
109 T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
110
111 kernel.packet[0] = T0;
112 kernel.packet[1] = T1;
113 kernel.packet[2] = T2;
114 kernel.packet[3] = T3;
115 kernel.packet[4] = T4;
116 kernel.packet[5] = T5;
117 kernel.packet[6] = T6;
118 kernel.packet[7] = T7;
119}
120
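// Illustrative sketch (not part of the original source): for Packet16f each zmm
// register holds rows of two side-by-side 8x8 blocks (lanes 0-7 and lanes 8-15),
// and the specialization above transposes both blocks in place. A scalar
// reference of the effect, where in[i][j] is lane j of kernel.packet[i] on entry
// and out[i][j] is lane j on exit:
//
//   for (int i = 0; i < 8; i++)
//     for (int j = 0; j < 8; j++) {
//       out[j][i]     = in[i][j];      // transpose of the low 8x8 block
//       out[j][i + 8] = in[i][j + 8];  // transpose of the high 8x8 block
//     }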
121template <>
122EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
123 ptranspose(kernel);
124}
125
126/***
127 * Unrolls for transposed C stores
128 */
129template <typename Scalar>
130class trans {
131 public:
132 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
133 using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
134 static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
135
136 /***********************************
137 * Auxiliary Functions for:
138 * - storeC
139 ***********************************
140 */
141
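// Illustrative sketch (not part of the original source): the aux_* members in this
// file unroll loops at compile time by recursing on the `counter` template
// parameter, with the enable_if-selected overload acting as the base case.
// aux_storeC below is therefore roughly equivalent to the runtime loop
//
//   for (int64_t startN = 0; startN < endN; startN++) {
//     // accumulate into and store row startN of the transposed C tile
//   }
//
// and the storeC wrapper hides `counter` by starting it at endN.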
151 template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
152 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
153 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
154 constexpr int64_t counterReverse = endN - counter;
155 constexpr int64_t startN = counterReverse;
156
157 EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
158 EIGEN_IF_CONSTEXPR(remM) {
159 pstoreu<Scalar>(
160 C_arr + LDC * startN,
161 padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
162 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
163 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
164 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
165 }
166 else {
167 pstoreu<Scalar>(C_arr + LDC * startN,
168 padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
169 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
170 }
171 }
172 else { // This block is only needed for the fp32 case
173 // Reinterpret as __m512 for _mm512_shuffle_f32x4
174 vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
175 zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
176 // Swap the lower and upper halves of the AVX register.
177 zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
178 preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
179
180 EIGEN_IF_CONSTEXPR(remM) {
181 pstoreu<Scalar>(
182 C_arr + LDC * startN,
183 padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
184 preinterpret<vecHalf>(
185 zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
186 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
187 }
188 else {
189 pstoreu<Scalar>(
190 C_arr + LDC * startN,
191 padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
192 preinterpret<vecHalf>(
193 zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
194 }
195 }
196 aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
197 }
198
199 template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
200 static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
201 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
202 EIGEN_UNUSED_VARIABLE(C_arr);
203 EIGEN_UNUSED_VARIABLE(LDC);
204 EIGEN_UNUSED_VARIABLE(zmm);
205 EIGEN_UNUSED_VARIABLE(remM_);
206 }
207
208 template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
209 static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
210 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
211 int64_t remM_ = 0) {
212 aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
213 }
214
241 template <int64_t unrollN, int64_t packetIndexOffset>
242 static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
243 // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
244 // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
245 constexpr int64_t zmmStride = unrollN / PacketSize;
246 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
247 r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
248 r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
249 r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
250 r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
251 r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
252 r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
253 r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
254 r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
255 trans8x8blocks(r);
256 zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
257 zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
258 zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
259 zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
260 zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
261 zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
262 zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
263 zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
264 }
265};
266
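// Illustrative sketch (not part of the original source): a caller holding a tile of
// C in zmm registers could compose the class above roughly as
//
//   trans<Scalar>::template transpose<unrollN, packetIndexOffset>(zmm);
//   trans<Scalar>::template storeC<endN, unrollN, packetIndexOffset, false>(C_arr, LDC, zmm);
//
// i.e. transpose the in-register tile, then accumulate it into C with column
// stride LDC.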
281template <typename Scalar>
282class transB {
283 public:
284 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
285 using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
286 static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
287
288 /***********************************
289 * Auxiliary Functions for:
290 * - loadB
291 * - storeB
292 * - loadBBlock
293 * - storeBBlock
294 ***********************************
295 */
296
303 template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
304 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
305 Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
306 int64_t remM_ = 0) {
307 constexpr int64_t counterReverse = endN - counter;
308 constexpr int64_t startN = counterReverse;
309
310 EIGEN_IF_CONSTEXPR(remM) {
311 ymm.packet[packetIndexOffset + startN] =
312 ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
313 }
314 else {
315 EIGEN_IF_CONSTEXPR(remN_ == 0) {
316 ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
317 }
318 else ymm.packet[packetIndexOffset + startN] =
319 ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
320 }
321
322 aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
323 }
324
325 template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
326 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
327 Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
328 int64_t remM_ = 0) {
329 EIGEN_UNUSED_VARIABLE(B_arr);
330 EIGEN_UNUSED_VARIABLE(LDB);
331 EIGEN_UNUSED_VARIABLE(ymm);
332 EIGEN_UNUSED_VARIABLE(remM_);
333 }
334
341 template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
342 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
343 Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
344 constexpr int64_t counterReverse = endN - counter;
345 constexpr int64_t startN = counterReverse;
346
347 EIGEN_IF_CONSTEXPR(remK || remM) {
348 pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
349 remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
350 }
351 else {
352 pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
353 }
354
355 aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
356 }
357
358 template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
359 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
360 Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
361 EIGEN_UNUSED_VARIABLE(B_arr);
362 EIGEN_UNUSED_VARIABLE(LDB);
363 EIGEN_UNUSED_VARIABLE(ymm);
364 EIGEN_UNUSED_VARIABLE(rem_);
365 }
366
373 template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
374 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
375 Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
376 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
377 constexpr int64_t counterReverse = endN - counter;
378 constexpr int64_t startN = counterReverse;
379 transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
380 aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
381 }
382
383 template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
384 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
385 Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
386 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
387 EIGEN_UNUSED_VARIABLE(B_arr);
388 EIGEN_UNUSED_VARIABLE(LDB);
389 EIGEN_UNUSED_VARIABLE(B_temp);
390 EIGEN_UNUSED_VARIABLE(LDB_);
391 EIGEN_UNUSED_VARIABLE(ymm);
392 EIGEN_UNUSED_VARIABLE(remM_);
393 }
394
401 template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
402 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
403 Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
404 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
405 constexpr int64_t counterReverse = endN - counter;
406 constexpr int64_t startN = counterReverse;
407
408 EIGEN_IF_CONSTEXPR(toTemp) {
409 transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
410 }
411 else {
412 transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
413 ymm, remM_);
414 }
415 aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
416 }
417
418 template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
419 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
420 Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
421 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
422 EIGEN_UNUSED_VARIABLE(B_arr);
423 EIGEN_UNUSED_VARIABLE(LDB);
424 EIGEN_UNUSED_VARIABLE(B_temp);
425 EIGEN_UNUSED_VARIABLE(LDB_);
426 EIGEN_UNUSED_VARIABLE(ymm);
427 EIGEN_UNUSED_VARIABLE(remM_);
428 }
429
430 /********************************************************
431 * Wrappers for aux_XXXX to hide counter parameter
432 ********************************************************/
433
434 template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
435 static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
436 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
437 int64_t remM_ = 0) {
438 aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
439 }
440
441 template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
442 static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
443 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
444 int64_t rem_ = 0) {
445 aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
446 }
447
448 template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
449 static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
450 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
451 int64_t remM_ = 0) {
452 EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
453 else {
454 aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
455 }
456 }
457
458 template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
459 static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
460 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
461 int64_t remM_ = 0) {
462 aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
463 }
464
465 template <int64_t packetIndexOffset>
466 static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
467 // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
468 // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
469 PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
470 r.packet[0] = ymm.packet[packetIndexOffset + 0];
471 r.packet[1] = ymm.packet[packetIndexOffset + 1];
472 r.packet[2] = ymm.packet[packetIndexOffset + 2];
473 r.packet[3] = ymm.packet[packetIndexOffset + 3];
474 r.packet[4] = ymm.packet[packetIndexOffset + 4];
475 r.packet[5] = ymm.packet[packetIndexOffset + 5];
476 r.packet[6] = ymm.packet[packetIndexOffset + 6];
477 r.packet[7] = ymm.packet[packetIndexOffset + 7];
478 ptranspose(r);
479 ymm.packet[packetIndexOffset + 0] = r.packet[0];
480 ymm.packet[packetIndexOffset + 1] = r.packet[1];
481 ymm.packet[packetIndexOffset + 2] = r.packet[2];
482 ymm.packet[packetIndexOffset + 3] = r.packet[3];
483 ymm.packet[packetIndexOffset + 4] = r.packet[4];
484 ymm.packet[packetIndexOffset + 5] = r.packet[5];
485 ymm.packet[packetIndexOffset + 6] = r.packet[6];
486 ymm.packet[packetIndexOffset + 7] = r.packet[7];
487 }
488
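// Illustrative sketch (not part of the original source): transB_kernel below
// transposes an L x unrollN panel of B, where L = EIGEN_AVX_MAX_NUM_ROW. For the
// unrollN == 8, toTemp == true case its net effect is roughly
//
//   for (int64_t n = 0; n < 8; n++)
//     for (int64_t l = 0; l < 8; l++)
//       B_temp[l * LDB_ + n] = B_arr[n * LDB + l];  // B_temp receives the transposed panel
//
// with the remM/remK_ masks trimming the loads and stores at the panel edges.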
489 template <int64_t unrollN, bool toTemp, bool remM>
490 static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
491 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
492 int64_t remM_ = 0) {
493 constexpr int64_t U3 = PacketSize * 3;
494 constexpr int64_t U2 = PacketSize * 2;
495 constexpr int64_t U1 = PacketSize * 1;
503 EIGEN_IF_CONSTEXPR(unrollN == U3) {
504 // load LxU3 B col major, transpose LxU3 row major
505 constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
506 transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
507 transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
508 transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
509 transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
510 transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
511
512 EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
513 transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
514 ymm, remM_);
515 transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
516 transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
517 transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
518 transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
519 ymm, remM_);
520 }
521 }
522 else EIGEN_IF_CONSTEXPR(unrollN == U2) {
523 // load LxU2 B col major, transpose LxU2 row major
524 constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
525 transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
526 transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
527 transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
528 EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
529 transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
530
531 EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
532 transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
533 &B_temp[maxUBlock], LDB_, ymm, remM_);
534 transB::template transposeLxL<0>(ymm);
535 transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
536 &B_temp[maxUBlock], LDB_, ymm, remM_);
537 }
538 }
539 else EIGEN_IF_CONSTEXPR(unrollN == U1) {
540 // load LxU1 B col major, transpose LxU1 row major
541 transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
542 transB::template transposeLxL<0>(ymm);
543 EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
544 transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
545 }
546 else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
547 // load Lx8 B col major, transpose Lx8 row major
548 transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
549 transB::template transposeLxL<0>(ymm);
550 transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
551 }
552 else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
553 // load Lx4 B col major, transpose Lx4 row major
554 transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
555 transB::template transposeLxL<0>(ymm);
556 transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
557 }
558 else EIGEN_IF_CONSTEXPR(unrollN == 2) {
559 // load Lx2 B col major, transpose Lx2 row major
560 transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
561 transB::template transposeLxL<0>(ymm);
562 transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
563 }
564 else EIGEN_IF_CONSTEXPR(unrollN == 1) {
565 // load Lx1 B col major, transpose Lx1 row major
566 transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
567 transB::template transposeLxL<0>(ymm);
568 transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
569 }
570 }
571};
572
585template <typename Scalar>
586class trsm {
587 public:
588 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
589 static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
590
591 /***********************************
592 * Auxiliary Functions for:
593 * - loadRHS
594 * - storeRHS
595 * - divRHSByDiag
596 * - updateRHS
597 * - triSolveMicroKernel
598 ************************************/
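// Illustrative sketch (not part of the original source): triSolveMicroKernel below
// performs an in-register triangular solve of an endM x endM diagonal block of A
// against numK packets of right-hand sides. For the forward, non-unit-diagonal
// case it is roughly column-oriented forward substitution,
//
//   for (int64_t m = 0; m < endM; m++) {
//     b[m] = b[m] / A[m][m];            // divRHSByDiag, using the broadcast reciprocal
//     for (int64_t i = m + 1; i < endM; i++)
//       b[i] = b[i] - A[i][m] * b[m];   // updateRHS (pnmadd)
//   }
//
// applied to every column of the packed RHS; the backward case walks the rows in
// reverse through the negated offsets passed to idA.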
606 template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
607 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
608 Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
609 constexpr int64_t counterReverse = endM * endK - counter;
610 constexpr int64_t startM = counterReverse / (endK);
611 constexpr int64_t startK = counterReverse % endK;
612
613 constexpr int64_t packetIndex = startM * endK + startK;
614 constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
615 const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
616 EIGEN_IF_CONSTEXPR(krem) {
617 RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
618 }
619 else {
620 RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
621 }
622 aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
623 }
624
625 template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
626 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
627 Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
628 EIGEN_UNUSED_VARIABLE(B_arr);
629 EIGEN_UNUSED_VARIABLE(LDB);
630 EIGEN_UNUSED_VARIABLE(RHSInPacket);
631 EIGEN_UNUSED_VARIABLE(rem);
632 }
633
641 template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
642 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
643 Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
644 constexpr int64_t counterReverse = endM * endK - counter;
645 constexpr int64_t startM = counterReverse / (endK);
646 constexpr int64_t startK = counterReverse % endK;
647
648 constexpr int64_t packetIndex = startM * endK + startK;
649 constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
650 const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
651 EIGEN_IF_CONSTEXPR(krem) {
652 pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
653 }
654 else {
655 pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
656 }
657 aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
658 }
659
660 template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
661 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
662 Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
663 EIGEN_UNUSED_VARIABLE(B_arr);
664 EIGEN_UNUSED_VARIABLE(LDB);
665 EIGEN_UNUSED_VARIABLE(RHSInPacket);
666 EIGEN_UNUSED_VARIABLE(rem);
667 }
668
677 template <int64_t currM, int64_t endK, int64_t counter>
678 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
679 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
680 constexpr int64_t counterReverse = endK - counter;
681 constexpr int64_t startK = counterReverse;
682
683 constexpr int64_t packetIndex = currM * endK + startK;
684 RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
685 aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
686 }
687
688 template <int64_t currM, int64_t endK, int64_t counter>
689 static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
690 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
691 EIGEN_UNUSED_VARIABLE(RHSInPacket);
692 EIGEN_UNUSED_VARIABLE(AInPacket);
693 }
694
702 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
703 int64_t counter, int64_t currentM>
704 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
705 Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
706 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
707 constexpr int64_t counterReverse = (endM - initM) * endK - counter;
708 constexpr int64_t startM = initM + counterReverse / (endK);
709 constexpr int64_t startK = counterReverse % endK;
710
711 // For each row of A, first update all corresponding RHS
712 constexpr int64_t packetIndex = startM * endK + startK;
713 EIGEN_IF_CONSTEXPR(currentM > 0) {
714 RHSInPacket.packet[packetIndex] =
715 pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
716 RHSInPacket.packet[packetIndex]);
717 }
718
719 EIGEN_IF_CONSTEXPR(startK == endK - 1) {
720 // Once all RHS for the previous row of A are updated, we broadcast the next element in the column A_{i, currentM}.
721 EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
722 // If the diagonal is not unit, we broadcast the reciprocal of the diagonal into AInPacket.packet[currentM].
723 // This will be used in divRHSByDiag
724 EIGEN_IF_CONSTEXPR(isFWDSolve)
725 AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
726 else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
727 }
728 else {
729 // Broadcast the next off-diagonal element of A
730 EIGEN_IF_CONSTEXPR(isFWDSolve)
731 AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
732 else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
733 }
734 }
735
736 aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
737 A_arr, LDA, RHSInPacket, AInPacket);
738 }
739
740 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
741 int64_t counter, int64_t currentM>
742 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
743 Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
744 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
745 EIGEN_UNUSED_VARIABLE(A_arr);
746 EIGEN_UNUSED_VARIABLE(LDA);
747 EIGEN_UNUSED_VARIABLE(RHSInPacket);
748 EIGEN_UNUSED_VARIABLE(AInPacket);
749 }
750
757 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
758 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
759 Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
760 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
761 constexpr int64_t counterReverse = endM - counter;
762 constexpr int64_t startM = counterReverse;
763
764 constexpr int64_t currentM = startM;
765 // Divide the right-hand side in row startM by the diagonal value of A
766 // broadcast to AInPacket.packet[startM-1] in the previous iteration.
767 //
768 // Without "if constexpr" the compiler would also instantiate the case <-1, numK>;
769 // this is handled with enable_if to prevent out-of-bounds warnings
770 // from the compiler.
771 EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
772 trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
773
774 // After the division, the RHS corresponding to subsequent rows of A can be partially updated.
775 // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
776 // to be used in the next iteration.
777 trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
778 AInPacket);
779
780 // Handle division for the RHS corresponding to the final row of A.
781 EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
782 trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
783
784 aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
785 AInPacket);
786 }
787
788 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
789 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
790 Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
791 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
792 EIGEN_UNUSED_VARIABLE(A_arr);
793 EIGEN_UNUSED_VARIABLE(LDA);
794 EIGEN_UNUSED_VARIABLE(RHSInPacket);
795 EIGEN_UNUSED_VARIABLE(AInPacket);
796 }
797
798 /********************************************************
799 * Wrappers for aux_XXXX to hide counter parameter
800 ********************************************************/
801
806 template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
807 static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
808 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
809 aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
810 }
811
816 template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
817 static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
818 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
819 aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
820 }
821
825 template <int64_t currM, int64_t endK>
826 static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
827 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
828 aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
829 }
830
835 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
836 int64_t currentM>
837 static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
838 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
839 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
840 aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
841 A_arr, LDA, RHSInPacket, AInPacket);
842 }
843
850 template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
851 static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
852 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
853 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
854 static_assert(numK >= 1 && numK <= 3, "numK out of range");
855 aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
856 }
857};
858
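// Illustrative sketch (not part of the original source): a driver would use the
// class above roughly as follows for a forward, row-major, non-unit-diagonal
// solve of an 8-row block against 3 packets of right-hand sides:
//
//   PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
//   PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
//   trsm<Scalar>::template loadRHS<true, 8, 3>(B_arr, LDB, RHSInPacket);
//   trsm<Scalar>::template triSolveMicroKernel<true, true, false, 8, 3>(A_arr, LDA, RHSInPacket, AInPacket);
//   trsm<Scalar>::template storeRHS<true, 8, 3>(B_arr, LDB, RHSInPacket);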
864template <typename Scalar, bool isAdd>
865class gemm {
866 public:
867 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
868 static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
869
870 /***********************************
871 * Auxiliary Functions for:
872 * - setzero
873 * - updateC
874 * - storeC
875 * - startLoadB
876 * - microKernel
877 ************************************/
878
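// Illustrative sketch (not part of the original source): the micro-kernel below
// keeps an endM x endN grid of accumulators in zmm.packet[0 .. endM*endN), followed
// by numLoad packets of B and numBCast broadcast values of A. Each k-step is
// roughly
//
//   for (int64_t n = 0; n < endN; n++)
//     for (int64_t m = 0; m < endM; m++)
//       acc[n][m] += bcast(A_t(n, k)) * B_t(k, m * PacketSize ...)  // pmadd, or pnmadd when !isAdd
//
// with the loads of the next B packets and the broadcasts of the next A values
// interleaved between the FMAs to hide their latency.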
886 template <int64_t endM, int64_t endN, int64_t counter>
887 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
888 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
889 constexpr int64_t counterReverse = endM * endN - counter;
890 constexpr int64_t startM = counterReverse / (endN);
891 constexpr int64_t startN = counterReverse % endN;
892
893 zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
894 aux_setzero<endM, endN, counter - 1>(zmm);
895 }
896
897 template <int64_t endM, int64_t endN, int64_t counter>
898 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
899 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
900 EIGEN_UNUSED_VARIABLE(zmm);
901 }
902
910 template <int64_t endM, int64_t endN, int64_t counter, bool rem>
911 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
912 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
913 EIGEN_UNUSED_VARIABLE(rem_);
914 constexpr int64_t counterReverse = endM * endN - counter;
915 constexpr int64_t startM = counterReverse / (endN);
916 constexpr int64_t startN = counterReverse % endN;
917
918 EIGEN_IF_CONSTEXPR(rem)
919 zmm.packet[startN * endM + startM] =
920 padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
921 zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
922 else zmm.packet[startN * endM + startM] =
923 padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
924 aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
925 }
926
927 template <int64_t endM, int64_t endN, int64_t counter, bool rem>
928 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
929 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
930 EIGEN_UNUSED_VARIABLE(C_arr);
931 EIGEN_UNUSED_VARIABLE(LDC);
932 EIGEN_UNUSED_VARIABLE(zmm);
933 EIGEN_UNUSED_VARIABLE(rem_);
934 }
935
943 template <int64_t endM, int64_t endN, int64_t counter, bool rem>
944 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
945 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
946 EIGEN_UNUSED_VARIABLE(rem_);
947 constexpr int64_t counterReverse = endM * endN - counter;
948 constexpr int64_t startM = counterReverse / (endN);
949 constexpr int64_t startN = counterReverse % endN;
950
951 EIGEN_IF_CONSTEXPR(rem)
952 pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
953 remMask<PacketSize>(rem_));
954 else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
955 aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
956 }
957
958 template <int64_t endM, int64_t endN, int64_t counter, bool rem>
959 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
960 Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
961 EIGEN_UNUSED_VARIABLE(C_arr);
962 EIGEN_UNUSED_VARIABLE(LDC);
963 EIGEN_UNUSED_VARIABLE(zmm);
964 EIGEN_UNUSED_VARIABLE(rem_);
965 }
966
973 template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
974 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
975 Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
976 EIGEN_UNUSED_VARIABLE(rem_);
977 constexpr int64_t counterReverse = endL - counter;
978 constexpr int64_t startL = counterReverse;
979
980 EIGEN_IF_CONSTEXPR(rem)
981 zmm.packet[unrollM * unrollN + startL] =
982 ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
983 else zmm.packet[unrollM * unrollN + startL] =
984 ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);
985
986 aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
987 }
988
989 template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
990 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
991 Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
992 EIGEN_UNUSED_VARIABLE(B_t);
993 EIGEN_UNUSED_VARIABLE(LDB);
994 EIGEN_UNUSED_VARIABLE(zmm);
995 EIGEN_UNUSED_VARIABLE(rem_);
996 }
997
1004 template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
1005 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
1006 Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1007 constexpr int64_t counterReverse = endB - counter;
1008 constexpr int64_t startB = counterReverse;
1009
1010 zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);
1011
1012 aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
1013 }
1014
1015 template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
1016 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
1017 Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1018 EIGEN_UNUSED_VARIABLE(A_t);
1019 EIGEN_UNUSED_VARIABLE(LDA);
1020 EIGEN_UNUSED_VARIABLE(zmm);
1021 }
1022
1030 template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
1031 int64_t numBCast, bool rem>
1032 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
1033 Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
1034 EIGEN_UNUSED_VARIABLE(rem_);
1035 if ((numLoad / endM + currK < unrollK)) {
1036 constexpr int64_t counterReverse = endM - counter;
1037 constexpr int64_t startM = counterReverse;
1038
1039 EIGEN_IF_CONSTEXPR(rem) {
1040 zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
1041 ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
1042 }
1043 else {
1044 zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
1045 ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
1046 }
1047
1048 aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1049 }
1050 }
1051
1052 template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
1053 int64_t numBCast, bool rem>
1054 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
1055 Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
1056 EIGEN_UNUSED_VARIABLE(B_t);
1057 EIGEN_UNUSED_VARIABLE(LDB);
1058 EIGEN_UNUSED_VARIABLE(zmm);
1059 EIGEN_UNUSED_VARIABLE(rem_);
1060 }
1061
1070 template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
1071 int64_t numBCast, bool rem>
1072 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
1073 Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1074 int64_t rem_ = 0) {
1075 EIGEN_UNUSED_VARIABLE(rem_);
1076 constexpr int64_t counterReverse = endM * endN * endK - counter;
1077 constexpr int startK = counterReverse / (endM * endN);
1078 constexpr int startN = (counterReverse / (endM)) % endN;
1079 constexpr int startM = counterReverse % endM;
1080
1081 EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
1082 gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
1083 gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
1084 }
1085
1086 {
1087 // Interleave FMA and Bcast
1088 EIGEN_IF_CONSTEXPR(isAdd) {
1089 zmm.packet[startN * endM + startM] =
1090 pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
1091 zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
1092 }
1093 else {
1094 zmm.packet[startN * endM + startM] =
1095 pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
1096 zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
1097 }
1098 // Bcast
1099 EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
1100 zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
1101 (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
1102 }
1103 }
1104
1105 // We have updated all accumulators; time to load the next set of B's
1106 EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
1107 gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1108 }
1109 aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
1110 }
1111
1112 template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
1113 int64_t numBCast, bool rem>
1114 static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
1115 Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1116 int64_t rem_ = 0) {
1117 EIGEN_UNUSED_VARIABLE(B_t);
1118 EIGEN_UNUSED_VARIABLE(A_t);
1119 EIGEN_UNUSED_VARIABLE(LDB);
1120 EIGEN_UNUSED_VARIABLE(LDA);
1121 EIGEN_UNUSED_VARIABLE(zmm);
1122 EIGEN_UNUSED_VARIABLE(rem_);
1123 }
1124
1125 /********************************************************
1126 * Wrappers for aux_XXXX to hide counter parameter
1127 ********************************************************/
1128
1129 template <int64_t endM, int64_t endN>
1130 static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1131 aux_setzero<endM, endN, endM * endN>(zmm);
1132 }
1133
1137 template <int64_t endM, int64_t endN, bool rem = false>
1138 static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
1139 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1140 int64_t rem_ = 0) {
1141 EIGEN_UNUSED_VARIABLE(rem_);
1142 aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
1143 }
1144
1145 template <int64_t endM, int64_t endN, bool rem = false>
1146 static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
1147 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1148 int64_t rem_ = 0) {
1149 EIGEN_UNUSED_VARIABLE(rem_);
1150 aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
1151 }
1152
1156 template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
1157 static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
1158 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1159 int64_t rem_ = 0) {
1160 EIGEN_UNUSED_VARIABLE(rem_);
1161 aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
1162 }
1163
1167 template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
1168 static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
1169 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
1170 aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
1171 }
1172
1176 template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
1177 static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
1178 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1179 int64_t rem_ = 0) {
1180 EIGEN_UNUSED_VARIABLE(rem_);
1181 aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
1182 }
1183
1207 template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
1208 bool rem = false>
1209 static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
1210 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
1211 int64_t rem_ = 0) {
1212 EIGEN_UNUSED_VARIABLE(rem_);
1213 aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
1214 rem_);
1215 }
1216};
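// Illustrative sketch (not part of the original source): an update C = C - A*B, as
// used by the triangular-solve driver, composes the wrappers above roughly as
//
//   PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
//   gemm<Scalar, /*isAdd=*/false>::template setzero<endM, endN>(zmm);
//   gemm<Scalar, false>::template microKernel<isARowMajor, endM, endN, unrollK, numLoad, numBCast>(
//       B_t, A_t, LDB, LDA, zmm);                        // repeated over the K dimension
//   gemm<Scalar, false>::template updateC<endM, endN>(C_arr, LDC, zmm);
//   gemm<Scalar, false>::template storeC<endM, endN>(C_arr, LDC, zmm);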
1217} // namespace unrolls
1218
1219#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H