Eigen-unsupported  5.0.1-dev+284dcc12
 
TensorContractionGpu.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
6// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
7//
8// This Source Code Form is subject to the terms of the Mozilla
9// Public License v. 2.0. If a copy of the MPL was not distributed
10// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
13#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
14
15#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
16
17// IWYU pragma: private
18#include "./InternalHeaderCheck.h"
19
20namespace Eigen {
21
22template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper,
23 bool needs_edge_check>
24__device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
25 const OutputMapper output, Scalar* lhs_shmem,
26 Scalar* rhs_shmem, const Index m_size,
27 const Index n_size, const Index k_size) {
28 const Index m_block_idx = blockIdx.x;
29 const Index n_block_idx = blockIdx.y;
30
31 const Index base_m = 64 * m_block_idx;
32 const Index base_n = 64 * n_block_idx;
33
34 // declare and initialize 64 registers for output 8x8 block
35
36 // prefetch registers
37 Scalar lhs_pf0;
38 Scalar lhs_pf1;
39 Scalar lhs_pf2;
40 Scalar lhs_pf3;
41 Scalar lhs_pf4;
42 Scalar lhs_pf5;
43 Scalar lhs_pf6;
44 Scalar lhs_pf7;
45
46 Scalar rhs_pf0;
47 Scalar rhs_pf1;
48 Scalar rhs_pf2;
49 Scalar rhs_pf3;
50 Scalar rhs_pf4;
51 Scalar rhs_pf5;
52 Scalar rhs_pf6;
53 Scalar rhs_pf7;
54
55 // shared memory is formatted
56 // (contract idx in block, nocontract idx in block, block idx)
57 // where block idx is column major. This transposition limits the number of
58 // bank conflicts when reading the LHS. The core idea is that, since the contracting
59 // index is shared by both sides, it should live in threadIdx.x.
60
61 // On the LHS, we pad each row inside of each block with an extra element. This makes
62 // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
63 // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
64
65 // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
66 // conflicts on writes and also none on reads.
67
68 // storage indices
69 const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
70 const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
71
72 const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
73 const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
74 const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
75 const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
76 const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
77 const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
78 const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
79 const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
80
81 const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
82 const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
83 const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
84 const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
85 const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
86 const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
87 const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
88 const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
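// Worked example of the store-index math above: thread (x=3, y=2, z=5) gets
//   lhs_store_idx_base = 2*72 + 3*9 + 5 = 176
//   rhs_store_idx_base = 2*72 + 5*8 + 3 = 187
// and the eight prefetched values land at those offsets plus 576 * i, since each
// 64x8 slab of the tile occupies 8 * 72 = 576 padded shared-memory elements.
// With the 9-element LHS row stride, threads differing only in threadIdx.x write
// addresses 9 apart (distinct banks); on the RHS they write consecutive addresses.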
89
90 // in the loading code, the following variables are important:
91 // threadIdx.x: the vertical position in an 8x8 block
92 // threadIdx.y: the vertical index of the 8x8 block in the grid
93 // threadIdx.z: the horizontal position in an 8x8 block
94 // k: the horizontal index of the 8x8 block in the grid
95 //
96 // The k parameter is implicit (it was the loop counter for a loop that went
97 // from 0 to 7), but that loop is now unrolled in the code below.
98
99 const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
100 const Index lhs_vert = base_m + load_idx_vert;
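// For example, thread (x=3, y=2, z=5) has load_idx_vert = 3 + 8*2 = 19, so it
// prefetches row base_m + 19 of the LHS tile; its threadIdx.z = 5 plus the
// unrolled k index (0..7) select the eight columns base_k + 5, 13, ..., 61.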
101
102#define prefetchIntoRegisters(base_k) \
103 { \
104 lhs_pf0 = conv(0); \
105 lhs_pf1 = conv(0); \
106 lhs_pf2 = conv(0); \
107 lhs_pf3 = conv(0); \
108 lhs_pf4 = conv(0); \
109 lhs_pf5 = conv(0); \
110 lhs_pf6 = conv(0); \
111 lhs_pf7 = conv(0); \
112 \
113 rhs_pf0 = conv(0); \
114 rhs_pf1 = conv(0); \
115 rhs_pf2 = conv(0); \
116 rhs_pf3 = conv(0); \
117 rhs_pf4 = conv(0); \
118 rhs_pf5 = conv(0); \
119 rhs_pf6 = conv(0); \
120 rhs_pf7 = conv(0); \
121 \
122 if (!needs_edge_check || lhs_vert < m_size) { \
123 const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
124 const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
125 const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
126 const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
127 const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
128 const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
129 const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
130 const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
131 \
132 if (!needs_edge_check || lhs_horiz_7 < k_size) { \
133 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
134 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
135 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
136 lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
137 lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
138 lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
139 lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
140 lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
141 } else if (lhs_horiz_6 < k_size) { \
142 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
143 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
144 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
145 lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
146 lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
147 lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
148 lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
149 } else if (lhs_horiz_5 < k_size) { \
150 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
151 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
152 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
153 lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
154 lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
155 lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
156 } else if (lhs_horiz_4 < k_size) { \
157 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
158 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
159 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
160 lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
161 lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
162 } else if (lhs_horiz_3 < k_size) { \
163 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
164 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
165 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
166 lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
167 } else if (lhs_horiz_2 < k_size) { \
168 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
169 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
170 lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
171 } else if (lhs_horiz_1 < k_size) { \
172 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
173 lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
174 } else if (lhs_horiz_0 < k_size) { \
175 lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
176 } \
177 } \
178 \
179 const Index rhs_vert = base_k + load_idx_vert; \
180 if (!needs_edge_check || rhs_vert < k_size) { \
181 const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
182 const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
183 const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
184 const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
185 const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
186 const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
187 const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
188 const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
189 \
190 if (rhs_horiz_7 < n_size) { \
191 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
192 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
193 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
194 rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
195 rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
196 rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
197 rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
198 rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
199 } else if (rhs_horiz_6 < n_size) { \
200 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
201 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
202 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
203 rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
204 rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
205 rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
206 rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
207 } else if (rhs_horiz_5 < n_size) { \
208 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
209 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
210 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
211 rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
212 rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
213 rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
214 } else if (rhs_horiz_4 < n_size) { \
215 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
216 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
217 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
218 rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
219 rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
220 } else if (rhs_horiz_3 < n_size) { \
221 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
222 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
223 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
224 rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
225 } else if (rhs_horiz_2 < n_size) { \
226 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
227 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
228 rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
229 } else if (rhs_horiz_1 < n_size) { \
230 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
231 rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
232 } else if (rhs_horiz_0 < n_size) { \
233 rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
234 } \
235 } \
236 }
237
238#define writeRegToShmem() \
239 lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
240 rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
241 \
242 lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
243 rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
244 \
245 lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
246 rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
247 \
248 lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
249 rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
250 \
251 lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
252 rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
253 \
254 lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
255 rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
256 \
257 lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
258 rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
259 \
260 lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
261 rhs_shmem[rhs_store_idx_7] = rhs_pf7;
262
263 // declare and initialize result array
264#define res(i, j) _res_##i##j
265#define initResultRow(i) \
266 Scalar res(i, 0) = conv(0); \
267 Scalar res(i, 1) = conv(0); \
268 Scalar res(i, 2) = conv(0); \
269 Scalar res(i, 3) = conv(0); \
270 Scalar res(i, 4) = conv(0); \
271 Scalar res(i, 5) = conv(0); \
272 Scalar res(i, 6) = conv(0); \
273 Scalar res(i, 7) = conv(0);
274
275 internal::scalar_cast_op<int, Scalar> conv;
276 initResultRow(0);
277 initResultRow(1);
278 initResultRow(2);
279 initResultRow(3);
280 initResultRow(4);
281 initResultRow(5);
282 initResultRow(6);
283 initResultRow(7);
284#undef initResultRow
285
286 for (Index base_k = 0; base_k < k_size; base_k += 64) {
287 // Wait for the previous iteration to finish with shmem. Counterintuitively,
288 // the code is a bit faster with the sync here than at the bottom of the loop.
289 __syncthreads();
290
291 prefetchIntoRegisters(base_k);
292 writeRegToShmem();
293
294#undef prefetchIntoRegisters
295#undef writeRegToShmem
296
297 // wait for shared mem packing to be done before starting computation
298 __syncthreads();
299
300 // compute 8x8 matrix product by outer product. This involves packing one column
301 // of LHS and one row of RHS into registers (takes 16 registers).
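// In effect, pass i performs a rank-1 update of this thread's 8x8 accumulator:
// lcol(r) holds the LHS entries of the 8 row blocks at little k block i, rrow(j)
// holds the RHS entries of the 8 column blocks at the same k block, and
//   res(r, j) += lcol(r) * rrow(j)   for r, j in 0..7
// accumulates the tile. Eight passes cover the eight little k blocks.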
302
303#define lcol(i) _lcol##i
304 Scalar lcol(0);
305 Scalar lcol(1);
306 Scalar lcol(2);
307 Scalar lcol(3);
308 Scalar lcol(4);
309 Scalar lcol(5);
310 Scalar lcol(6);
311 Scalar lcol(7);
312
313#define rrow(j) _rrow##j
314 Scalar rrow(0);
315 Scalar rrow(1);
316 Scalar rrow(2);
317 Scalar rrow(3);
318 Scalar rrow(4);
319 Scalar rrow(5);
320 Scalar rrow(6);
321 Scalar rrow(7);
322
323 // Now x corresponds to k, y to m, and z to n
324 const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
325 const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
326
327#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
328#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
329
330#define loadData(i, j) \
331 lcol(0) = lhs_element(0, j); \
332 rrow(0) = rhs_element(i, 0); \
333 lcol(1) = lhs_element(1, j); \
334 rrow(1) = rhs_element(i, 1); \
335 lcol(2) = lhs_element(2, j); \
336 rrow(2) = rhs_element(i, 2); \
337 lcol(3) = lhs_element(3, j); \
338 rrow(3) = rhs_element(i, 3); \
339 lcol(4) = lhs_element(4, j); \
340 rrow(4) = rhs_element(i, 4); \
341 lcol(5) = lhs_element(5, j); \
342 rrow(5) = rhs_element(i, 5); \
343 lcol(6) = lhs_element(6, j); \
344 rrow(6) = rhs_element(i, 6); \
345 lcol(7) = lhs_element(7, j); \
346 rrow(7) = rhs_element(i, 7);
347
348#define computeCol(j) \
349 res(0, j) += lcol(0) * rrow(j); \
350 res(1, j) += lcol(1) * rrow(j); \
351 res(2, j) += lcol(2) * rrow(j); \
352 res(3, j) += lcol(3) * rrow(j); \
353 res(4, j) += lcol(4) * rrow(j); \
354 res(5, j) += lcol(5) * rrow(j); \
355 res(6, j) += lcol(6) * rrow(j); \
356 res(7, j) += lcol(7) * rrow(j);
357
358#define computePass(i) \
359 loadData(i, i); \
360 \
361 computeCol(0); \
362 computeCol(1); \
363 computeCol(2); \
364 computeCol(3); \
365 computeCol(4); \
366 computeCol(5); \
367 computeCol(6); \
368 computeCol(7);
369
370 computePass(0);
371 computePass(1);
372 computePass(2);
373 computePass(3);
374 computePass(4);
375 computePass(5);
376 computePass(6);
377 computePass(7);
378
379#undef lcol
380#undef rrow
381#undef lhs_element
382#undef rhs_element
383#undef loadData
384#undef computeCol
385#undef computePass
386 } // end loop over k
387
388 // We have now iterated over all of the large (i.e., width-64) k blocks and
389 // accumulated results in registers. At this point thread (x, y, z) contains
390 // the sum, across all big k blocks, of the products for its within-block
391 // contraction index x. To compute the final output, we need to reduce
392 // the 8 threads over x by summation.
393#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
394#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
395#else
396#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
397#endif
398
399#define reduceRow(i, mask) \
400 shuffleInc(i, 0, mask); \
401 shuffleInc(i, 1, mask); \
402 shuffleInc(i, 2, mask); \
403 shuffleInc(i, 3, mask); \
404 shuffleInc(i, 4, mask); \
405 shuffleInc(i, 5, mask); \
406 shuffleInc(i, 6, mask); \
407 shuffleInc(i, 7, mask);
408
409#define reduceMatrix(mask) \
410 reduceRow(0, mask); \
411 reduceRow(1, mask); \
412 reduceRow(2, mask); \
413 reduceRow(3, mask); \
414 reduceRow(4, mask); \
415 reduceRow(5, mask); \
416 reduceRow(6, mask); \
417 reduceRow(7, mask);
418
419 // actually perform the reduction, now each thread of index (_, y, z)
420 // contains the correct values in its registers that belong in the output
421 // block
422 reduceMatrix(1);
423 reduceMatrix(2);
424 reduceMatrix(4);
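// The three XOR-shuffle passes with masks 1, 2, and 4 form a butterfly over the
// low three lane bits, which encode threadIdx.x: mask 1 pairs lanes (0,1), (2,3),
// ...; mask 2 pairs (0,2), (1,3), ...; mask 4 pairs (0,4), (1,5), .... After the
// three passes, every lane that differs only in threadIdx.x holds the full sum,
// which is why only the threadIdx.x == 0 threads write the result below.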
425
426#undef shuffleInc
427#undef reduceRow
428#undef reduceMatrix
429
430 // Now we need to copy the 64 values into main memory. We can't split the work
431 // among threads because all of the values are in registers. There are two ways
432 // to do this:
433 // (1) have 1 thread do 64 writes from registers into global memory
434 // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
435 // each do 8 writes into global memory. We can just overwrite the shared
436 // memory from the problem we just solved.
437 // (2) is slightly faster than (1) due to less branching and more ILP
438
439 // TODO: won't yield much gain, but could just use currently unused shared mem
440 // and then we won't have to sync
441 // wait for shared mem to be out of use
442 __syncthreads();
443
444#define writeResultShmem(i, j) lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);
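// Layout note: the thread with threadIdx.x == 0 in each (y, z) pair stores
// res(i, j) at offset i + 8*y + 64*z + 512*j; the readback below indexes the same
// slot with threadIdx.x playing the role of i, so eight threads then each issue
// eight global-memory writes.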
445
446#define writeRow(i) \
447 writeResultShmem(i, 0); \
448 writeResultShmem(i, 1); \
449 writeResultShmem(i, 2); \
450 writeResultShmem(i, 3); \
451 writeResultShmem(i, 4); \
452 writeResultShmem(i, 5); \
453 writeResultShmem(i, 6); \
454 writeResultShmem(i, 7);
455
456 if (threadIdx.x == 0) {
457 writeRow(0);
458 writeRow(1);
459 writeRow(2);
460 writeRow(3);
461 writeRow(4);
462 writeRow(5);
463 writeRow(6);
464 writeRow(7);
465 }
466#undef writeResultShmem
467#undef writeRow
468
469 const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
470 const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
471
472 if (threadIdx.x < max_i_write) {
473 if (max_j_write == 8) {
474 // TODO: can I trade bank conflicts for coalesced writes?
475 Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
476 Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
477 Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
478 Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
479 Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
480 Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
481 Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
482 Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
483
484 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
485 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
486 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
487 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
488 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
489 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
490 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
491 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
492 } else {
493#pragma unroll 7
494 for (int j = 0; j < max_j_write; j++) {
495 Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
496 output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
497 }
498 }
499 }
500#undef res
501}
502
503template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
504__global__ void
505#if defined(EIGEN_HIPCC)
506__launch_bounds__(512, 1)
507#else
508__launch_bounds__(512)
509#endif
510 EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
511 const Index n_size, const Index k_size) {
512 __shared__ Scalar lhs_shmem[72 * 64];
513 __shared__ Scalar rhs_shmem[72 * 64];
514
515 const Index m_block_idx = blockIdx.x;
516 const Index n_block_idx = blockIdx.y;
517
518 const Index base_m = 64 * m_block_idx;
519 const Index base_n = 64 * n_block_idx;
520
521 if (base_m + 63 < m_size && base_n + 63 < n_size) {
522 EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(
523 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
524 } else {
525 EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(
526 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
527 }
528}
529
530template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
531 bool CHECK_RHS_BOUNDARY>
532__device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
533 const OutputMapper output,
534 float2 lhs_shmem2[][16],
535 float2 rhs_shmem2[][8], const Index m_size,
536 const Index n_size, const Index k_size,
537 const Index base_m, const Index base_n) {
538 // prefetch registers
539 float4 lhs_pf0, rhs_pf0;
540
541 float4 results[4];
542 for (int i = 0; i < 4; i++) {
543 results[i].x = results[i].y = results[i].z = results[i].w = 0;
544 }
545
546#define prefetch_lhs(reg, row, col) \
547 if (!CHECK_LHS_BOUNDARY) { \
548 if (col < k_size) { \
549 reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
550 } \
551 } else { \
552 if (col < k_size) { \
553 if (row + 3 < m_size) { \
554 reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
555 } else if (row + 2 < m_size) { \
556 reg.x = lhs(row + 0, col); \
557 reg.y = lhs(row + 1, col); \
558 reg.z = lhs(row + 2, col); \
559 } else if (row + 1 < m_size) { \
560 reg.x = lhs(row + 0, col); \
561 reg.y = lhs(row + 1, col); \
562 } else if (row < m_size) { \
563 reg.x = lhs(row + 0, col); \
564 } \
565 } \
566 }
567
568 Index lhs_vert = base_m + threadIdx.x * 4;
569
570 for (Index k = 0; k < k_size; k += 16) {
571 lhs_pf0 = internal::pset1<float4>(0);
572 rhs_pf0 = internal::pset1<float4>(0);
573
574 Index lhs_horiz = threadIdx.y + k;
575 prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
576
577 Index rhs_vert = k + (threadIdx.x % 4) * 4;
578 Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
579
580 if (!CHECK_RHS_BOUNDARY) {
581 if ((rhs_vert + 3) < k_size) {
582 // only the k bound on rhs_vert needs checking here
583 rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
584 } else if (rhs_vert + 2 < k_size) {
586 // only the k bound on rhs_vert needs checking here
586 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
587 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
588 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
589 } else if (rhs_vert + 1 < k_size) {
590 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
591 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
592 } else if (rhs_vert < k_size) {
593 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
594 }
595 } else {
596 if (rhs_horiz0 < n_size) {
597 if ((rhs_vert + 3) < k_size) {
598 rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
599 } else if ((rhs_vert + 2) < k_size) {
600 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
601 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
602 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
603 } else if ((rhs_vert + 1) < k_size) {
604 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
605 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
606 } else if (rhs_vert < k_size) {
607 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
608 }
609 }
610 }
611 float x1, x2;
612 // The following could be done as a bitwise operation some day.
613 if ((threadIdx.x % 8) < 4) {
614 x1 = rhs_pf0.y;
615 x2 = rhs_pf0.w;
616 } else {
617 x1 = rhs_pf0.x;
618 x2 = rhs_pf0.z;
619 }
620#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
621 x1 = __shfl_xor(x1, 4);
622 x2 = __shfl_xor(x2, 4);
623#else
624 x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
625 x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
626#endif
627 if ((threadIdx.x % 8) < 4) {
628 rhs_pf0.y = x1;
629 rhs_pf0.w = x2;
630 } else {
631 rhs_pf0.x = x1;
632 rhs_pf0.z = x2;
633 }
634
635 // We have 64 features.
636 // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
637 // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
638 // ...
639 // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
640 // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
641 // ...
642 rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
643 rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
644
645 // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
646 // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
647 // ...
648 // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
649 // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
650 // ...
651
652 lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
653 lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
654
655#define add_vals(fl1, fl2, fr1, fr2) \
656 results[0].x += fl1.x * fr1.x; \
657 results[0].y += fl1.y * fr1.x; \
658 results[0].z += fl2.x * fr1.x; \
659 results[0].w += fl2.y * fr1.x; \
660 \
661 results[1].x += fl1.x * fr1.y; \
662 results[1].y += fl1.y * fr1.y; \
663 results[1].z += fl2.x * fr1.y; \
664 results[1].w += fl2.y * fr1.y; \
665 \
666 results[2].x += fl1.x * fr2.x; \
667 results[2].y += fl1.y * fr2.x; \
668 results[2].z += fl2.x * fr2.x; \
669 results[2].w += fl2.y * fr2.x; \
670 \
671 results[3].x += fl1.x * fr2.y; \
672 results[3].y += fl1.y * fr2.y; \
673 results[3].z += fl2.x * fr2.y; \
674 results[3].w += fl2.y * fr2.y;
675
676 __syncthreads();
677
678// Do the multiplies.
679#pragma unroll
680 for (int koff = 0; koff < 16; koff++) {
681 // 32 x threads.
682 float2 fl1 = lhs_shmem2[koff][threadIdx.x];
683 float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
684
685 int start_feature = threadIdx.y * 4;
686 float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
687 float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
688
689 add_vals(fl1, fl2, fr1, fr2)
690 }
691 __syncthreads();
692 }
693
694#undef prefetch_lhs
695#undef add_vals
696
697 Index horiz_base = threadIdx.y * 4 + base_n;
698 if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
699 for (int i = 0; i < 4; i++) {
700 output(lhs_vert, horiz_base + i) = results[i].x;
701 output(lhs_vert + 1, horiz_base + i) = results[i].y;
702 output(lhs_vert + 2, horiz_base + i) = results[i].z;
703 output(lhs_vert + 3, horiz_base + i) = results[i].w;
704 }
705 } else if (!CHECK_RHS_BOUNDARY) {
706 // CHECK LHS
707 if (lhs_vert + 3 < m_size) {
708 for (int i = 0; i < 4; i++) {
709 output(lhs_vert, horiz_base + i) = results[i].x;
710 output(lhs_vert + 1, horiz_base + i) = results[i].y;
711 output(lhs_vert + 2, horiz_base + i) = results[i].z;
712 output(lhs_vert + 3, horiz_base + i) = results[i].w;
713 }
714 } else if (lhs_vert + 2 < m_size) {
715 for (int i = 0; i < 4; i++) {
716 output(lhs_vert, horiz_base + i) = results[i].x;
717 output(lhs_vert + 1, horiz_base + i) = results[i].y;
718 output(lhs_vert + 2, horiz_base + i) = results[i].z;
719 }
720 } else if (lhs_vert + 1 < m_size) {
721 for (int i = 0; i < 4; i++) {
722 output(lhs_vert, horiz_base + i) = results[i].x;
723 output(lhs_vert + 1, horiz_base + i) = results[i].y;
724 }
725 } else if (lhs_vert < m_size) {
726 for (int i = 0; i < 4; i++) {
727 output(lhs_vert, horiz_base + i) = results[i].x;
728 }
729 }
730 } else if (!CHECK_LHS_BOUNDARY) {
731 // CHECK RHS
732 /*
733 int ncols_rem = fminf(n_size- horiz_base, 4);
734 for (int i = 0; i < ncols_rem; i++) {
735 output(lhs_vert, horiz_base + i) = results[i].x;
736 output(lhs_vert + 1, horiz_base + i) = results[i].y;
737 output(lhs_vert + 2, horiz_base + i) = results[i].z;
738 output(lhs_vert + 3, horiz_base + i) = results[i].w;
739 }*/
740 for (int i = 0; i < 4; i++) {
741 if (horiz_base + i < n_size) {
742 output(lhs_vert, horiz_base + i) = results[i].x;
743 output(lhs_vert + 1, horiz_base + i) = results[i].y;
744 output(lhs_vert + 2, horiz_base + i) = results[i].z;
745 output(lhs_vert + 3, horiz_base + i) = results[i].w;
746 }
747 }
748 } else {
749 // CHECK both boundaries.
750 for (int i = 0; i < 4; i++) {
751 if (horiz_base + i < n_size) {
752 if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
753 if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
754 if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
755 if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
756 }
757 }
758 }
759}
760
761template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
762 bool CHECK_RHS_BOUNDARY>
763__device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
764 const OutputMapper output, float2 lhs_shmem2[][32],
765 float2 rhs_shmem2[][8], const Index m_size,
766 const Index n_size, const Index k_size,
767 const Index base_m, const Index base_n) {
768 // prefetch registers
769 float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
770 float4 rhs_pf0, rhs_pf1;
771
772 float4 results[8];
773 for (int i = 0; i < 8; i++) {
774 results[i].x = results[i].y = results[i].z = results[i].w = 0;
775 }
776
777 Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
778 for (Index k = 0; k < k_size; k += 32) {
779 lhs_pf0 = internal::pset1<float4>(0);
780 lhs_pf1 = internal::pset1<float4>(0);
781 lhs_pf2 = internal::pset1<float4>(0);
782 lhs_pf3 = internal::pset1<float4>(0);
783
784 rhs_pf0 = internal::pset1<float4>(0);
785 rhs_pf1 = internal::pset1<float4>(0);
786
787 if (!CHECK_LHS_BOUNDARY) {
788 if ((threadIdx.y / 4 + k + 24) < k_size) {
789 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
790 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
791 lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
792 lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
793 } else if ((threadIdx.y / 4 + k + 16) < k_size) {
794 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
795 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
796 lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
797 } else if ((threadIdx.y / 4 + k + 8) < k_size) {
798 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
799 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
800 } else if ((threadIdx.y / 4 + k) < k_size) {
801 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
802 }
803 } else {
804 // just CHECK_LHS_BOUNDARY
805 if (lhs_vert + 3 < m_size) {
806 if ((threadIdx.y / 4 + k + 24) < k_size) {
807 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
808 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
809 lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
810 lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
811 } else if ((threadIdx.y / 4 + k + 16) < k_size) {
812 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
813 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
814 lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
815 } else if ((threadIdx.y / 4 + k + 8) < k_size) {
816 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
817 lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
818 } else if ((threadIdx.y / 4 + k) < k_size) {
819 lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
820 }
821 } else if (lhs_vert + 2 < m_size) {
822 if ((threadIdx.y / 4 + k + 24) < k_size) {
823 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
824 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
825 lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
826 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
827 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
828 lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
829 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
830 lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
831 lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
832 lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
833 lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
834 lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
835 } else if ((threadIdx.y / 4 + k + 16) < k_size) {
836 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
837 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
838 lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
839 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
840 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
841 lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
842 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
843 lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
844 lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
845 } else if ((threadIdx.y / 4 + k + 8) < k_size) {
846 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
847 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
848 lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
849 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
850 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
851 lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
852 } else if ((threadIdx.y / 4 + k) < k_size) {
853 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
854 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
855 lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
856 }
857 } else if (lhs_vert + 1 < m_size) {
858 if ((threadIdx.y / 4 + k + 24) < k_size) {
859 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
860 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
861 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
862 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
863 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
864 lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
865 lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
866 lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
867 } else if ((threadIdx.y / 4 + k + 16) < k_size) {
868 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
869 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
870 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
871 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
872 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
873 lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
874 } else if ((threadIdx.y / 4 + k + 8) < k_size) {
875 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
876 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
877 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
878 lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
879 } else if ((threadIdx.y / 4 + k) < k_size) {
880 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
881 lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
882 }
883 } else if (lhs_vert < m_size) {
884 if ((threadIdx.y / 4 + k + 24) < k_size) {
885 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
886 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
887 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
888 lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
889 } else if ((threadIdx.y / 4 + k + 16) < k_size) {
890 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
891 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
892 lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
893 } else if ((threadIdx.y / 4 + k + 8) < k_size) {
894 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
895 lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
896 } else if ((threadIdx.y / 4 + k) < k_size) {
897 lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
898 }
899 }
900 }
901 __syncthreads();
902 Index rhs_vert = k + threadIdx.x * 4;
903 Index rhs_horiz0 = threadIdx.y * 2 + base_n;
904 Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
905 if (!CHECK_RHS_BOUNDARY) {
906 if ((rhs_vert + 3) < k_size) {
907 // only the k bound on rhs_vert needs checking here
908 rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
909 rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
910 } else if (rhs_vert + 2 < k_size) {
911 // only the k bound on rhs_vert needs checking here
912 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
913 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
914 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
915 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
916 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
917 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
918 } else if (rhs_vert + 1 < k_size) {
919 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
920 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
921 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
922 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
923 } else if (rhs_vert < k_size) {
924 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
925 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
926 }
927 } else {
928 if (rhs_horiz1 < n_size) {
929 if ((rhs_vert + 3) < k_size) {
930 // just CHECK_RHS_BOUNDARY
931 rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
932 rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
933 } else if (rhs_vert + 2 < k_size) {
934 // just CHECK_RHS_BOUNDARY
935 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
936 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
937 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
938 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
939 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
940 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
941 } else if (k + threadIdx.x * 4 + 1 < k_size) {
942 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
943 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
944 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
945 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
946 } else if (k + threadIdx.x * 4 < k_size) {
947 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
948 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
949 }
950 } else if (rhs_horiz0 < n_size) {
951 if ((rhs_vert + 3) < k_size) {
952 // just CHECK_RHS_BOUNDARY
953 rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
954 } else if ((rhs_vert + 2) < k_size) {
955 // just CHECK_RHS_BOUNDARY
956 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
957 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
958 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
959 } else if ((rhs_vert + 1) < k_size) {
960 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
961 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
962 } else if (rhs_vert < k_size) {
963 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
964 }
965 }
966 }
967 __syncthreads();
968 // Loaded. Do computation
969 // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
970 // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
971 // ..
972 // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
973 rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
974 // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
975 // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
976 // ..
977 rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
978 // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
979 // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
980 rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
981 // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
982 // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
983 rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
984
985 // LHS.
986 // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
987 // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
988 // ...
989 // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
990 // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
991
992#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
993 results[0].x += a_feat1.x * f1.x; \
994 results[1].x += a_feat1.x * f1.y; \
995 results[2].x += a_feat1.x * f2.x; \
996 results[3].x += a_feat1.x * f2.y; \
997 results[4].x += a_feat1.x * f3.x; \
998 results[5].x += a_feat1.x * f3.y; \
999 results[6].x += a_feat1.x * f4.x; \
1000 results[7].x += a_feat1.x * f4.y; \
1001 \
1002 results[0].y += a_feat1.y * f1.x; \
1003 results[1].y += a_feat1.y * f1.y; \
1004 results[2].y += a_feat1.y * f2.x; \
1005 results[3].y += a_feat1.y * f2.y; \
1006 results[4].y += a_feat1.y * f3.x; \
1007 results[5].y += a_feat1.y * f3.y; \
1008 results[6].y += a_feat1.y * f4.x; \
1009 results[7].y += a_feat1.y * f4.y; \
1010 \
1011 results[0].z += a_feat2.x * f1.x; \
1012 results[1].z += a_feat2.x * f1.y; \
1013 results[2].z += a_feat2.x * f2.x; \
1014 results[3].z += a_feat2.x * f2.y; \
1015 results[4].z += a_feat2.x * f3.x; \
1016 results[5].z += a_feat2.x * f3.y; \
1017 results[6].z += a_feat2.x * f4.x; \
1018 results[7].z += a_feat2.x * f4.y; \
1019 \
1020 results[0].w += a_feat2.y * f1.x; \
1021 results[1].w += a_feat2.y * f1.y; \
1022 results[2].w += a_feat2.y * f2.x; \
1023 results[3].w += a_feat2.y * f2.y; \
1024 results[4].w += a_feat2.y * f3.x; \
1025 results[5].w += a_feat2.y * f3.y; \
1026 results[6].w += a_feat2.y * f4.x; \
1027 results[7].w += a_feat2.y * f4.y;
1028
1029 lhs_shmem2[threadIdx.y / 4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1030 lhs_shmem2[threadIdx.y / 4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1031 lhs_shmem2[threadIdx.y / 4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1032 lhs_shmem2[threadIdx.y / 4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1033
1034 lhs_shmem2[threadIdx.y / 4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1035 lhs_shmem2[threadIdx.y / 4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1036 lhs_shmem2[threadIdx.y / 4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1037 lhs_shmem2[threadIdx.y / 4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1038
1039 __syncthreads();
1040
1041// Do the multiplies.
1042#pragma unroll
1043 for (int koff = 0; koff < 32; koff++) {
1044 float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1045 float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1046
1047 // The first feature is at (threadIdx.y / 4) * 8; this thread covers features up to start + 8 (exclusive).
1048 int start_feature = (threadIdx.y / 4) * 8;
1049
1050 float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
1051 float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
1052 float2 br3 = rhs_shmem2[start_feature / 2 + 2 + (koff % 4) * 32][koff / 4];
1053 float2 br4 = rhs_shmem2[start_feature / 2 + 3 + (koff % 4) * 32][koff / 4];
1054
1055 add_vals(a3, a4, br1, br2, br3, br4)
1056 }
1057 __syncthreads();
1058 } // end loop over k
1059
1060#undef add_vals
1061
1062 __syncthreads();
1063 Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
1064 if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1065 for (int i = 0; i < 8; i++) {
1066 output(lhs_vert, horiz_base + i) = results[i].x;
1067 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1068 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1069 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1070 }
1071 } else if (!CHECK_RHS_BOUNDARY) {
1072 if (lhs_vert + 3 < m_size) {
1073 for (int i = 0; i < 8; i++) {
1074 output(lhs_vert, horiz_base + i) = results[i].x;
1075 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1076 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1077 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1078 }
1079 } else if (lhs_vert + 2 < m_size) {
1080 for (int i = 0; i < 8; i++) {
1081 output(lhs_vert, horiz_base + i) = results[i].x;
1082 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1083 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1084 }
1085 } else if (lhs_vert + 1 < m_size) {
1086 for (int i = 0; i < 8; i++) {
1087 output(lhs_vert, horiz_base + i) = results[i].x;
1088 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1089 }
1090 } else if (lhs_vert < m_size) {
1091 for (int i = 0; i < 8; i++) {
1092 output(lhs_vert, horiz_base + i) = results[i].x;
1093 }
1094 }
1095 } else if (!CHECK_LHS_BOUNDARY) {
1096 // CHECK RHS boundary
1097 for (int i = 0; i < 8; i++) {
1098 if (horiz_base + i < n_size) {
1099 output(lhs_vert, horiz_base + i) = results[i].x;
1100 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1102 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1103 }
1104 }
1105 } else {
1106 // CHECK both boundaries.
1107 for (int i = 0; i < 8; i++) {
1108 if (horiz_base + i < n_size) {
1109 if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
1110 if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
1111 if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
1112 if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
1113 }
1114 }
1115 }
1116}
1117
1118template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1119__global__ void
1120#if defined(EIGEN_HIPCC)
1121__launch_bounds__(256, 1)
1122#else
1123__launch_bounds__(256)
1124#endif
1125 EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
1126 const Index n_size, const Index k_size) {
1127 __shared__ float2 lhs_shmem[64 * 32];
1128 __shared__ float2 rhs_shmem[128 * 8];
1129
1130 typedef float2 LHS_MEM[64][32];
1131 typedef float2 RHS_MEM[128][8];
1132
1133 const Index m_block_idx = blockIdx.x;
1134 const Index n_block_idx = blockIdx.y;
1135
1136 const Index base_m = 128 * m_block_idx;
1137 const Index base_n = 64 * n_block_idx;
1138
1139 bool check_rhs = (base_n + 63) >= n_size;
1140 bool check_lhs128 = (base_m + 127) >= m_size;
1141
1142 if (!check_rhs) {
1143 if (!check_lhs128) {
1144 // >= 128 rows left
1145 EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
1146 lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1147 } else {
1148 EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
1149 lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1150 }
1151 } else {
1152 if (!check_lhs128) {
1153 // >= 128 rows left
1154 EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
1155 lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1156 } else {
1157 EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
1158 lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1159 }
1160 }
1161}
1162
1163template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1164__global__ void
1165#if defined(EIGEN_HIPCC)
1166__launch_bounds__(256, 1)
1167#else
1168__launch_bounds__(256)
1169#endif
1170 EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output,
1171 const Index m_size, const Index n_size, const Index k_size) {
1172 __shared__ float2 lhs_shmem[32][16];
1173 __shared__ float2 rhs_shmem[64][8];
1174
1175 const Index m_block_idx = blockIdx.x;
1176 const Index n_block_idx = blockIdx.y;
1177
1178 const Index base_m = 64 * m_block_idx;
1179 const Index base_n = 64 * n_block_idx;
1180
1181 if (base_m + 63 < m_size) {
1182 if (base_n + 63 < n_size) {
1183 EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
1184 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1185 } else {
1186 EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
1187 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1188 }
1189 } else {
1190 if (base_n + 63 < n_size) {
1191 EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
1192 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1193 } else {
1194 EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
1195 lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1196 }
1197 }
1198}
1199
1200template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1201struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice>
1202 : public TensorContractionEvaluatorBase<TensorEvaluator<
1203 const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
1204 typedef GpuDevice Device;
1205
1206 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1207 typedef TensorContractionEvaluatorBase<Self> Base;
1208
1209 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1210 typedef std::remove_const_t<typename XprType::Scalar> Scalar;
1211 typedef typename XprType::Index Index;
1212 typedef typename XprType::CoeffReturnType CoeffReturnType;
1213 typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
1214
1215 static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
1216
1217 // Most of the code is assuming that both input tensors are ColMajor. If the
1218 // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
1219 // If we want to compute A * B = C, where A is LHS and B is RHS, the code
1220 // will pretend B is LHS and A is RHS.
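// (Swapping is valid because, for column-major storage, computing C = A * B on
// row-major inputs is the same as computing C^T = B^T * A^T on the reinterpreted
// column-major data, so one column-major kernel covers both layouts.)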
1221 typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
1222 typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;
1223
1224 static constexpr int LDims =
1225 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
1226 static constexpr int RDims =
1227 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
1228 static constexpr int ContractDims = internal::array_size<Indices>::value;
1229
1230 typedef array<Index, LDims> left_dim_mapper_t;
1231 typedef array<Index, RDims> right_dim_mapper_t;
1232
1233 typedef array<Index, ContractDims> contract_t;
1234 typedef array<Index, LDims - ContractDims> left_nocontract_t;
1235 typedef array<Index, RDims - ContractDims> right_nocontract_t;
1236
1237 static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
1238
1239 typedef DSizes<Index, NumDims> Dimensions;
1240
1241 // typedefs needed in evalTo
1242 typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
1243 typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
1244
1245 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
1246 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
1247
1248 typedef typename LeftEvaluator::Dimensions LeftDimensions;
1249 typedef typename RightEvaluator::Dimensions RightDimensions;
1250
1251 TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {
1252 EIGEN_STATIC_ASSERT((internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
1253 GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
1254 }
1255
1256 // We need to redefine this method to make nvcc happy
1257 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
1258 this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1259 this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1260 if (data) {
1261 evalTo(data);
1262 return false;
1263 } else {
1264 this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
1265 evalTo(this->m_result);
1266 return true;
1267 }
1268 }
1269
1270 void evalTo(Scalar* buffer) const {
1271 if (this->m_lhs_inner_dim_contiguous) {
1272 if (this->m_rhs_inner_dim_contiguous) {
1273 if (this->m_rhs_inner_dim_reordered) {
1274 evalTyped<true, true, true, Unaligned>(buffer);
1275 } else {
1276 evalTyped<true, true, false, Unaligned>(buffer);
1277 }
1278 } else {
1279 if (this->m_rhs_inner_dim_reordered) {
1280 evalTyped<true, false, true, Unaligned>(buffer);
1281 } else {
1282 evalTyped<true, false, false, Unaligned>(buffer);
1283 }
1284 }
1285 } else {
1286 if (this->m_rhs_inner_dim_contiguous) {
1287 if (this->m_rhs_inner_dim_reordered) {
1288 evalTyped<false, true, true, Unaligned>(buffer);
1289 } else {
1290 evalTyped<false, true, false, Unaligned>(buffer);
1291 }
1292 } else {
1293 if (this->m_rhs_inner_dim_reordered) {
1294 evalTyped<false, false, true, Unaligned>(buffer);
1295 } else {
1296 evalTyped<false, false, false, Unaligned>(buffer);
1297 }
1298 }
1299 }
1300 }
1301
1302 template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper,
1303 typename OutputMapper>
1304 struct LaunchKernels {
1305 static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
1306 const GpuDevice& device) {
1307 const Index m_blocks = (m + 63) / 64;
1308 const Index n_blocks = (n + 63) / 64;
1309 const dim3 num_blocks(m_blocks, n_blocks, 1);
1310 const dim3 block_size(8, 8, 8);
1311 LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1312 block_size, 0, device, lhs, rhs, output, m, n, k);
1313 }
1314 };
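// Launch-shape example for the generic kernel above: each 8x8x8 block (512
// threads) computes one 64x64 output tile, so m = 1000, n = 500 gives
// m_blocks = 16 and n_blocks = 8, i.e. a 16x8 grid of 128 blocks.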
1315
1316 template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1317 struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
1318 static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
1319 const GpuDevice& device) {
1320 if (m < 768 || n < 768) {
1321 const Index m_blocks = (m + 63) / 64;
1322 const Index n_blocks = (n + 63) / 64;
1323 const dim3 num_blocks(m_blocks, n_blocks, 1);
1324 const dim3 block_size(16, 16, 1);
1325 LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1326 block_size, 0, device, lhs, rhs, output, m, n, k);
1327 } else {
1328 const Index m_blocks = (m + 127) / 128;
1329 const Index n_blocks = (n + 63) / 64;
1330 const dim3 num_blocks(m_blocks, n_blocks, 1);
1331 const dim3 block_size(8, 32, 1);
1332 LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1333 block_size, 0, device, lhs, rhs, output, m, n, k);
1334 }
1335 }
1336 };
1337
1338 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1339 void evalTyped(Scalar* buffer) const {
1340 // columns in left side, rows in right side
1341 const Index k = this->m_k_size;
1342 EIGEN_UNUSED_VARIABLE(k)
1343
1344 // rows in left side
1345 const Index m = this->m_i_size;
1346
1347 // columns in right side
1348 const Index n = this->m_j_size;
1349
1350 // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
1351 this->m_device.fill(buffer, buffer + m * n, Scalar(0));
1352
1353 typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
1354 contract_t, 4, lhs_inner_dim_contiguous, false, Unaligned>
1355 LhsMapper;
1356
1357 typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
1358 contract_t, 4, rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
1359 Unaligned>
1360 RhsMapper;
1361
1362 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1363
1364 // initialize data mappers
1365 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1366 this->m_left_contracting_strides, this->m_k_strides);
1367
1368 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1369 this->m_right_contracting_strides, this->m_k_strides);
1370
1371 OutputMapper output(buffer, m);
1372
1373#if defined(EIGEN_USE_HIP)
1374 setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
1375#else
1376 setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
1377#endif
1378
1379 LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k,
1380 this->m_device);
1381 }
1382};
1383
1384} // end namespace Eigen
1385
1386#endif // EIGEN_USE_GPU and EIGEN_GPUCC
1387#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
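// ---------------------------------------------------------------------------
// Minimal usage sketch (illustrative only, not part of this header): the kernels
// above are reached by evaluating a tensor contraction on a GpuDevice. The
// snippet assumes the standard Tensor GPU device API (GpuStreamDevice/GpuDevice)
// and must be compiled with nvcc, or hipcc for ROCm.
//
//   #define EIGEN_USE_GPU
//   #include <unsupported/Eigen/CXX11/Tensor>
//
//   void contract_on_gpu(const float* host_a, const float* host_b, float* host_c,
//                        int m, int k, int n) {
//     Eigen::GpuStreamDevice stream;
//     Eigen::GpuDevice gpu_device(&stream);
//
//     // Allocate device buffers and copy the inputs over.
//     float* d_a = static_cast<float*>(gpu_device.allocate(m * k * sizeof(float)));
//     float* d_b = static_cast<float*>(gpu_device.allocate(k * n * sizeof(float)));
//     float* d_c = static_cast<float*>(gpu_device.allocate(m * n * sizeof(float)));
//     gpu_device.memcpyHostToDevice(d_a, host_a, m * k * sizeof(float));
//     gpu_device.memcpyHostToDevice(d_b, host_b, k * n * sizeof(float));
//
//     Eigen::TensorMap<Eigen::Tensor<float, 2>> a(d_a, m, k);
//     Eigen::TensorMap<Eigen::Tensor<float, 2>> b(d_b, k, n);
//     Eigen::TensorMap<Eigen::Tensor<float, 2>> c(d_c, m, n);
//
//     // Contract the second dimension of a with the first dimension of b
//     // (an ordinary matrix product); this dispatches to the kernels above.
//     Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//     c.device(gpu_device) = a.contract(b, dims);
//
//     gpu_device.memcpyDeviceToHost(host_c, d_c, m * n * sizeof(float));
//     gpu_device.synchronize();
//
//     gpu_device.deallocate(d_a);
//     gpu_device.deallocate(d_b);
//     gpu_device.deallocate(d_c);
//   }
// ---------------------------------------------------------------------------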