#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {
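
// GPU kernels for tensor contraction. The generic EigenContractionKernel
// computes one 64x64 block of the output per thread block (8x8x8 threads),
// streaming the contracting dimension through shared memory in steps of 64.
// The specialized float kernels further down use float4 vector loads instead.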
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output,
                               Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
  // registers that hold the values prefetched from global memory while the
  // previous iteration is still being computed
  Scalar lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3, lhs_pf4, lhs_pf5, lhs_pf6, lhs_pf7;
  Scalar rhs_pf0, rhs_pf1, rhs_pf2, rhs_pf3, rhs_pf4, rhs_pf5, rhs_pf6, rhs_pf7;

  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;
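
  // prefetchIntoRegisters pulls the next 64-wide slice of the contracting
  // dimension into the lhs_pf*/rhs_pf* registers. When needs_edge_check is
  // set, the descending ladder of else-if branches below loads only the
  // coefficients that are still inside the tensor; out-of-range registers
  // keep their zero initialization.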
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = lhs_pf1 = lhs_pf2 = lhs_pf3 = conv(0); \
    lhs_pf4 = lhs_pf5 = lhs_pf6 = lhs_pf7 = conv(0); \
    rhs_pf0 = rhs_pf1 = rhs_pf2 = rhs_pf3 = conv(0); \
    rhs_pf4 = rhs_pf5 = rhs_pf6 = rhs_pf7 = conv(0); \
 \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
 \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
 \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  }
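
  // writeRegToShmem copies the 16 prefetched registers into the transposed
  // shared-memory tiles, using the padded store indices computed above.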
#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
 \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
 \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
 \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
 \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
 \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
 \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
 \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;
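
  // Each thread accumulates an 8x8 sub-block of the output in 64 registers,
  // named res(i, j) via token pasting.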
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0);

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow
  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for the previous iteration to finish computing from the shared memory tiles
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // wait for the shared-memory packing to be done before starting the computation
    __syncthreads();
    // registers holding the current column of the LHS tile and row of the RHS tile
#define lcol(i) _lcol##i
    Scalar lcol(0); Scalar lcol(1); Scalar lcol(2); Scalar lcol(3);
    Scalar lcol(4); Scalar lcol(5); Scalar lcol(6); Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0); Scalar rrow(1); Scalar rrow(2); Scalar rrow(3);
    Scalar rrow(4); Scalar rrow(5); Scalar rrow(6); Scalar rrow(7);
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
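
    // The non-power-of-two strides (9 and 72 = 8 * 9) match the padded layout
    // used by the store indices above; this is consistent with the usual
    // padding trick for limiting shared-memory bank conflicts.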
#define loadData(i, j) \
  lcol(0) = lhs_element(0, j); \
  rrow(0) = rhs_element(i, 0); \
  lcol(1) = lhs_element(1, j); \
  rrow(1) = rhs_element(i, 1); \
  lcol(2) = lhs_element(2, j); \
  rrow(2) = rhs_element(i, 2); \
  lcol(3) = lhs_element(3, j); \
  rrow(3) = rhs_element(i, 3); \
  lcol(4) = lhs_element(4, j); \
  rrow(4) = rhs_element(i, 4); \
  lcol(5) = lhs_element(5, j); \
  rrow(5) = rhs_element(i, 5); \
  lcol(6) = lhs_element(6, j); \
  rrow(6) = rhs_element(i, 6); \
  lcol(7) = lhs_element(7, j); \
  rrow(7) = rhs_element(i, 7);
#define computeCol(j) \
  res(0, j) += lcol(0) * rrow(j); \
  res(1, j) += lcol(1) * rrow(j); \
  res(2, j) += lcol(2) * rrow(j); \
  res(3, j) += lcol(3) * rrow(j); \
  res(4, j) += lcol(4) * rrow(j); \
  res(5, j) += lcol(5) * rrow(j); \
  res(6, j) += lcol(6) * rrow(j); \
  res(7, j) += lcol(7) * rrow(j);
// computePass(i) performs one rank-1 update: it loads the i-th column of the
// LHS tile and the i-th row of the RHS tile into registers, then accumulates
// all 8 output columns.
#define computePass(i) \
  loadData(i, i); \
 \
  computeCol(0); \
  computeCol(1); \
  computeCol(2); \
  computeCol(3); \
  computeCol(4); \
  computeCol(5); \
  computeCol(6); \
  computeCol(7);

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  }  // end of the loop over base_k
#if EIGEN_CUDA_SDK_VER < 90000
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif
#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);
#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);
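
  // XOR-butterfly reduction: after masks 1, 2 and 4 every thread of index
  // (x, y, z) holds the full sum over the 8 x-lanes for its res(i, j) values.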
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // Copy the 64 values into main memory through shared memory: one thread
  // per (y, z) stages its 8x8 block there, and then 8 threads each perform
  // 8 writes into global memory.
#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);
  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow
#undef res

  // make the staged results visible to the threads that write them out
  __syncthreads();
  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
}
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
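
// Specialization path for float: each thread computes a 4x4 block of outputs
// held in four float4 accumulators, and operands are fetched with float4
// vector loads. This 16x16-thread variant computes a 64x64 output tile.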
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
#define prefetch_lhs(reg, row, col) \
    if (!CHECK_LHS_BOUNDARY) { \
      if (col < k_size) { \
        reg = lhs.template loadPacket<Unaligned>(row, col); \
      } \
    } else { \
      if (col < k_size) { \
        if (row + 3 < m_size) { \
          reg = lhs.template loadPacket<Unaligned>(row, col); \
        } else if (row + 2 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
          reg.z = lhs(row + 2, col); \
        } else if (row + 1 < m_size) { \
          reg.x = lhs(row + 0, col); \
          reg.y = lhs(row + 1, col); \
        } else if (row < m_size) { \
          reg.x = lhs(row + 0, col); \
        } \
      } \
    }
  Index lhs_vert = base_m + threadIdx.x * 4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
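
    // Each thread fetches four consecutive k-values of one rhs column as a
    // float4; the boundary ladder below falls back to scalar loads near the
    // edges of the tensor.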
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following can be a bitwise operation..... some day.
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
#if EIGEN_CUDA_SDK_VER < 90000
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
 \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
 \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
 \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y;
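
    // add_vals accumulates one rank-1 update: the four lhs values held in
    // fl1/fl2 times the four rhs values held in fr1/fr2, filling the 4x4
    // block spread across results[0..3].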
    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();  // all reads done before the next iteration overwrites the tiles
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
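
// Wider float kernel: a 128x64 output tile per block of 8x32 threads, with
// the contracting dimension streamed in steps of 32. Each thread again owns
// 4 output rows, here for 8 output columns (eight float4 accumulators).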
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][32],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);
    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }

    Index rhs_vert = k + threadIdx.x * 4;
    Index rhs_horiz0 = threadIdx.y * 2 + base_n;
    Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (rhs_vert + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();

    // Loaded. Do computation.
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
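
    // Each float2 pairs the same k-row of the two adjacent output columns
    // (rhs_horiz0, rhs_horiz1); the .x/.y/.z/.w components of the loads land
    // in four 32-row bands of the tile.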
#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
      results[0].x += a_feat1.x * f1.x; \
      results[1].x += a_feat1.x * f1.y; \
      results[2].x += a_feat1.x * f2.x; \
      results[3].x += a_feat1.x * f2.y; \
      results[4].x += a_feat1.x * f3.x; \
      results[5].x += a_feat1.x * f3.y; \
      results[6].x += a_feat1.x * f4.x; \
      results[7].x += a_feat1.x * f4.y; \
 \
      results[0].y += a_feat1.y * f1.x; \
      results[1].y += a_feat1.y * f1.y; \
      results[2].y += a_feat1.y * f2.x; \
      results[3].y += a_feat1.y * f2.y; \
      results[4].y += a_feat1.y * f3.x; \
      results[5].y += a_feat1.y * f3.y; \
      results[6].y += a_feat1.y * f4.x; \
      results[7].y += a_feat1.y * f4.y; \
 \
      results[0].z += a_feat2.x * f1.x; \
      results[1].z += a_feat2.x * f1.y; \
      results[2].z += a_feat2.x * f2.x; \
      results[3].z += a_feat2.x * f2.y; \
      results[4].z += a_feat2.x * f3.x; \
      results[5].z += a_feat2.x * f3.y; \
      results[6].z += a_feat2.x * f4.x; \
      results[7].z += a_feat2.x * f4.y; \
 \
      results[0].w += a_feat2.y * f1.x; \
      results[1].w += a_feat2.y * f1.y; \
      results[2].w += a_feat2.y * f2.x; \
      results[3].w += a_feat2.y * f2.y; \
      results[4].w += a_feat2.y * f3.x; \
      results[5].w += a_feat2.y * f3.y; \
      results[6].w += a_feat2.y * f4.x; \
      results[7].w += a_feat2.y * f4.y;
    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 32; koff ++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();  // make sure all reads are done before the tiles are overwritten
  }

#undef add_vals

  Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
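
// Device-side evaluator for contraction expressions. A minimal host-side
// usage sketch (the tensor names and the contraction pair are illustrative):
//
//   Eigen::CudaStreamDevice stream;
//   Eigen::GpuDevice gpu_device(&stream);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   c.device(gpu_device) = a.contract(b, dims);  // a, b, c: GPU-mapped tensors
//
// Assigning on the GpuDevice routes the evaluation through the LaunchKernels
// helpers defined below.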
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {
  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code assumes both input tensors are ColMajor. For RowMajor
  // inputs the LHS and RHS are swapped below, since A * B = (B' * A')'.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
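
  // Every contracted index pair removes one dimension from each side, e.g.
  // contracting one pair of a rank-3 LHS with a rank-2 RHS yields
  // NumDims = 3 + 2 - 2 * 1 = 3.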
  typedef DSizes<Index, NumDims> Dimensions;

  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;
  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}
  // We need to redefine this method to make nvcc happy
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
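
  // evalTyped is instantiated for all eight combinations of the three layout
  // flags, so the input mappers are fully specialized at compile time.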
  template <typename LhsScalar, typename RhsScalar, typename Index,
            typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
  template <typename Index, typename LhsMapper,
            typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    // initialize the data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);
    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H