Eigen-unsupported  5.0.1-dev+284dcc12
 
TensorContractionSycl.h
1// This file is part of Eigen, a lightweight C++ template library for linear algebra.
2//
3// Mehdi Goli Codeplay Software Ltd.
4// Ralph Potter Codeplay Software Ltd.
5// Luke Iwanski Codeplay Software Ltd.
6// Contact: <eigen@codeplay.com>
7//
8// This Source Code Form is subject to the terms of the Mozilla Public License v. 2.0. If a copy of the MPL was not
9// distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11/*****************************************************************
12 * TensorContractionSycl.h
13 *
14 * \brief:
15 * TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
16 *
17 *****************************************************************/
18
19#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
20#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
21
22// IWYU pragma: private
23#include "./InternalHeaderCheck.h"
24
25namespace Eigen {
26
27namespace TensorSycl {
28namespace internal {
29
30#ifndef EIGEN_SYCL_DISABLE_GEMV
45template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
46struct TVPanelSize {
47 // LocalThreadSizeC: determines total number of threads per workgroup for the contracting dimension
48 static constexpr StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0;
49 // LocalThreadSizeNC: determines total number of threads per workgroup for the non-contracting dimension
50 static constexpr StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1;
51 // TileSizeDimNC: determines the tile size for the non-contracting dimension
52 static constexpr StorageIndex TileSizeDimNC = NCWindow / NCFactor;
53 // TileSizeDimC: determines the tile size for the contracting dimension
54 static constexpr StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC;
55 // WorkLoadPerThreadNC : determines workload per thread for loading the non-contracting dimension
56 static constexpr StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC;
57 // WorkLoadPerThreadC: determines workload per thread for loading the contracting dimension
58 static constexpr StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC;
59 // BC: determines whether padding to avoid local-memory bank conflicts is required
60 static constexpr bool BC = false;
61};
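// Worked example: with the default 16x16 work-group (EIGEN_SYCL_LOCAL_THREAD_DIM0/1 == 16, an assumption) and the
// TVPanelSize<Scalar, StorageIndex, /*NCWindow=*/16, /*CFactor=*/1, /*NCFactor=*/1> instantiation used by LaunchVT
// further down, the derived sizes are TileSizeDimNC = 16/1 = 16, TileSizeDimC = 1*16*16 = 256,
// WorkLoadPerThreadNC = 16/16 = 1 and WorkLoadPerThreadC = 256/16 = 16: each work-group covers a 16-wide slice of the
// non-contracting dimension and a 256-deep slice of the contracting dimension, with every thread accumulating 16
// contracting elements.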
62#endif
63
80
81template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
82struct TTPanelSize {
83 // TileSizeDimK: determines the tile size for dimension K. The packet size is assumed to have been taken into account
84 static constexpr StorageIndex TileSizeDimK = TSDK;
85 // WorkLoadPerThreadM: determines workload per thread for loading the M dimension. This can be varied based on the
86 // available registers on the chosen device (it can be controlled by the EIGEN_SYCL_REG_M macro).
87#ifndef EIGEN_SYCL_REG_M
88 static constexpr StorageIndex WorkLoadPerThreadM = REG_SIZE_M;
89#else
90 static constexpr StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M;
91#endif
92// WorkLoadPerThreadN: determines workload per thread for loading the N dimension. This can be varied based on the
93// available registers on the chosen device (it can be controlled by the EIGEN_SYCL_REG_N macro).
94#ifndef EIGEN_SYCL_REG_N
95 static constexpr StorageIndex WorkLoadPerThreadN = REG_SIZE_N;
96#else
97 static constexpr StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N;
98#endif
99 // LocalThreadSizeM: determines total number of threads per workgroup for the m dimension
100 static constexpr StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0;
101 // LocalThreadSizeN: determines total number of threads per workgroup for the n dimension
102 static constexpr StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1;
103 // TileSizeDimM: determines the tile size for the m dimension
104 static constexpr StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM;
105 // TileSizeDimN: determines the tile size for the n dimension
106 static constexpr StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN;
107 // LoadPerThreadLhs: determines workload per thread for loading Lhs Tensor. This must be divisible by packetsize
108 static constexpr StorageIndex LoadPerThreadLhs =
109 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN));
110 // LoadPerThreadRhs: determines workload per thread for loading Rhs Tensor. This must be divisible by packetsize
111 static constexpr StorageIndex LoadPerThreadRhs =
112 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM));
113 // BC: determines whether padding to avoid local-memory bank conflicts is required
114 static constexpr bool BC = true;
115 // DoubleBuffer: determines if double buffering technique should be used (This can be disabled by
116 // EIGEN_SYCL_DISABLE_DOUBLE_BUFFER macro when the device does not have sufficient local memory)
117 static constexpr bool DoubleBuffer =
118#ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
119 false;
120#else
121 true;
122#endif
123};
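// Worked example: for the TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> instantiation used by adjustTT further
// down (EIGEN_SYCL_REG_M/N unset, 16x16 work-groups assumed), WorkLoadPerThreadM = WorkLoadPerThreadN = 4,
// TileSizeDimK = 16, TileSizeDimM = TileSizeDimN = 16*4 = 64 and LoadPerThreadLhs = LoadPerThreadRhs = (16*4*4)/64 = 4:
// each thread owns a 4x4 register block of the 64x64 output tile and loads 4 LHS and 4 RHS elements per K-tile.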
124
125/*!
126 * \brief contraction_type: an enum class representing the Tensor Contraction implementation algorithm. This is used to
127 * specialize the contraction algorithm based on device support for dedicated local memory.
128 */
129enum class contraction_type { local, no_local };
130/*!
131 * \brief data_source an enum class determining the location of the data in a memory hierarchy (global, local, private).
132 */
133enum class data_source { global_mem, local_mem, private_mem };
134
160template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
161 typename StorageIndex>
162static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<PacketLoad, PacketType> read(
163 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &ld) {
164 const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
165 const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
166 return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
167}
168
191template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
192static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!PacketLoad, PacketType> read(
193 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &) {
194 const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
195 const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
196 return tensorMapper(row, col);
197}
198
219
220template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
221static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<dt != data_source::global_mem, void> write(
222 PacketType &packet_data, DataScalar ptr) {
223 constexpr int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size;
224 EIGEN_UNROLL_LOOP
225 for (int i = 0; i < PacketSize; i++) {
226 *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data);
227 ptr += ld;
228 }
229}
230
245
246template <data_source dt, typename PacketType, typename DataScalar>
247static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
248 typename std::enable_if_t<Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem,
249 void>
250 write(PacketType &packet_data, DataScalar *ptr) {
251 ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
252}
253
267template <data_source dt, typename PacketType, typename DataScalar>
268static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
269 typename std::enable_if_t<Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem,
270 void>
271 write(PacketType &packet_data, DataScalar *ptr) {
272 *ptr = packet_data;
273}
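// The three write overloads above are selected at compile time by the data_source template parameter and the packet
// width. The call sites later in this file follow this pattern (ld is the stride used when scalarizing a packet into
// local or private memory):
//   write<StorageIndex, ld, data_source::private_mem>(val, private_ptr);  // scalarize a packet into private registers
//   write<StorageIndex, ld, data_source::local_mem>(val, local_ptr);      // scalarize a packet into local memory
//   write<data_source::global_mem>(packet, global_ptr);                   // pstoreu, or a plain store when PacketSize == 1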
274
280template <bool is_internal>
281EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) {
282 return true;
283}
284
290template <>
291EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary<false>(bool cond) {
292 return cond;
293}
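// Usage sketch: calls such as check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex)) below reduce to a constant
// true for fully interior tiles, letting the compiler drop the boundary branch, and only perform the real bounds test
// for edge tiles.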
294
321template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
322struct BlockProperties {
323 static constexpr bool packet_load = packet_load_;
324 typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar;
325 static constexpr bool is_rhs = is_rhs_;
326 typedef std::conditional_t<packet_load, PacketType, OutScalar> OutType;
327 static constexpr int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size;
328 static constexpr bool is_coalesced_layout = !(is_transposed ^ is_rhs);
329 static constexpr int nc_stride = (is_coalesced_layout ? elements_per_access : 1);
330 static constexpr int c_stride = (is_coalesced_layout ? 1 : elements_per_access);
331};
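// Example instantiation (mirroring the LHSBlockProperties typedef in TensorContractionKernel below): for a
// non-transposed, packet-loaded LHS with a 4-wide packet, is_coalesced_layout = !(false ^ false) = true, so
// nc_stride = 4 and c_stride = 1; a transposed LHS flips this to nc_stride = 1 and c_stride = 4, i.e. packets then run
// along the contracting dimension.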
332
372template <typename StorageIndex>
373struct ThreadProperties {
374 const StorageIndex linearLocalThreadId;
375 const StorageIndex kGroupId;
376 const StorageIndex mGroupOffset;
377 const StorageIndex nGroupOffset;
378 const StorageIndex kGroupOffset;
379 const StorageIndex mLocalOffset;
380 const StorageIndex nLocalOffset;
381 const StorageIndex mGlobalOffset;
382 const StorageIndex nGlobalOffset;
383 StorageIndex kSize;
384 const bool is_internal;
385 // this is used to adjust the last block
386 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties(
387 const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
388 const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
389 const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
390 StorageIndex kSize_, const bool is_internal_)
391 : linearLocalThreadId(linearLocalThreadId_),
392 kGroupId(kGroupId_),
393 mGroupOffset(mGroupOffset_),
394 nGroupOffset(nGroupOffset_),
395 kGroupOffset(kGroupOffset_),
396 mLocalOffset(mLocalOffset_),
397 nLocalOffset(nLocalOffset_),
398 mGlobalOffset(mGlobalOffset_),
399 nGlobalOffset(nGlobalOffset_),
400 kSize(kSize_),
401 is_internal(is_internal_) {}
402};
403
454template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
455 typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
456 typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
457class TensorContractionKernel {
458 public:
459 typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
460 PacketReturnType;
461 static constexpr int PacketSize =
462 Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
463 static constexpr bool is_lhs_transposed =
464 !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous;
465 static constexpr bool is_rhs_transposed =
466 !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous;
467
468 typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable,
469 PacketReturnType>
470 LHSBlockProperties;
471
472 typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable,
473 PacketReturnType>
474 RHSBlockProperties;
475
476 static constexpr StorageIndex NStride =
477 contraction_tp == contraction_type::local ? Properties::WorkLoadPerThreadN : RHSBlockProperties::nc_stride;
478
479 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
480 typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
481 typedef OutScalar * /*cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::private_space>*/ private_ptr;
482 typedef std::conditional_t<contraction_tp == contraction_type::local, local_ptr, private_ptr> tile_ptr;
483 static constexpr StorageIndex LSDL = contraction_tp == contraction_type::local
484 ? Properties::TileSizeDimM + Properties::BC
485 : Properties::WorkLoadPerThreadM;
486 static constexpr StorageIndex LSDR = contraction_tp == contraction_type::local
487 ? Properties::TileSizeDimN + Properties::BC
488 : Properties::WorkLoadPerThreadN;
489 static constexpr StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
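  // With the default TTPanelSize<_, _, 4, 4, 16> and 16x16 work-groups this gives, for the local-memory kernel,
  // LSDL = LSDR = 64 + 1 = 65 (the +1 BC padding avoids local-memory bank conflicts) and LocalOffset = 256; for the
  // no-local kernel LSDL = LSDR = 4 and the tiles live entirely in private memory.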
490
503 template <contraction_type, StorageIndex>
504 struct MemHolder {
505 tile_ptr ptr;
506 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {}
507 };
508
511 template <StorageIndex MemSize>
512 struct MemHolder<contraction_type::no_local, MemSize> {
513 OutScalar ptr[MemSize] = {OutScalar{0}};
514 };
515
537 struct TiledMemory {
538 MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
539 MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
540 tile_ptr lhs_scratch_ptr_compute;
541 tile_ptr rhs_scratch_ptr_compute;
542 const std::pair<StorageIndex, StorageIndex> lhs_extract_index;
543 const std::pair<StorageIndex, StorageIndex> rhs_extract_index;
544 template <contraction_type tp = contraction_tp>
545 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr,
546 std::enable_if_t<tp == contraction_type::no_local> * = 0)
547 : lhs_scratch_extract{},
548 rhs_scratch_extract{},
549 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr),
550 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr),
551 lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
552 rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}
553
554 template <contraction_type tp = contraction_tp>
555 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TiledMemory(const ThreadProperties<StorageIndex> &thread_properties,
556 local_ptr block_start_ptr,
557 std::enable_if_t<tp == contraction_type::local> * = 0)
558 : lhs_scratch_extract{block_start_ptr},
559 rhs_scratch_extract{lhs_scratch_extract.ptr +
560 ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
561 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset),
562 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset),
563 lhs_extract_index(
564 local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
565 rhs_extract_index(
566 local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
567 };
568
569 Scratch scratch;
570 const LhsMapper lhs;
571 const RhsMapper rhs;
572 OutAccessor out_res;
573 const StorageIndex groupSizeM;
574 const StorageIndex groupSizeN;
575 const StorageIndex numTiles;
576 const TripleDim triple_dim;
577
578 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
579 const RhsMapper rhs_, OutAccessor out_res_,
580 const StorageIndex groupSizeM_,
581 const StorageIndex groupSizeN_,
582 const StorageIndex numTiles_,
583 const TripleDim triple_dim_)
584 : scratch(scratch_),
585 lhs(lhs_),
586 rhs(rhs_),
587 out_res(out_res_),
588 groupSizeM(groupSizeM_),
589 groupSizeN(groupSizeN_),
590 numTiles(numTiles_),
591 triple_dim(triple_dim_) {}
592
593 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
594 const RhsMapper rhs_, OutAccessor out_res_,
595 const StorageIndex groupSizeM_,
596 const StorageIndex numTiles_,
597 const TripleDim triple_dim_)
598 : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}
599
600 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) const {
601 const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
602 const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
603 const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
604 const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
605 const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
606 const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
607 const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
608 const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
609 const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
610 const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
611 const StorageIndex nLocalOffset = NStride * nLocalThreadId;
612 const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
613 const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;
614
615 const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
616 StorageIndex kGroupOffset = kGroupId * kSizePerWG;
617 const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
618 triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
619 triple_dim.K - kGroupOffset >= kSizePerWG;
620 // this is used to adjust the last block
621 StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
622 // This is used to find the last K offset so that kGroupOffset - kSize can compute the k offset for loading into the
623 // tile
624 kGroupOffset += kSize;
625
626 auto thread_properties =
627 ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
628 mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);
629
630 auto out_ptr = out_res + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);
631
632 (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
633 : compute_panel<false>(itemID, thread_properties, out_ptr);
634 }
635 // The compute block computes the contraction for each thread's private block and stores the result in the
636 // privateRes memory of each computation. The compute block function is independent of the local and no-local concepts, as
637 // it only computes the block in each thread's private memory space.
638 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
639 PacketReturnType *privateRes) const {
640 StorageIndex idx = 0;
641 constexpr StorageIndex lhs_stride =
642 contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
643 EIGEN_UNROLL_LOOP
644 for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
645 auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
646 StorageIndex lhs_index = 0;
647 EIGEN_UNROLL_LOOP
648 for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
649 PacketReturnType lhsPack{};
650 Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
651 lhs_block_ptr + lhs_index);
652 privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);
653
654 lhs_index += lhs_stride;
655 idx++;
656 }
657 }
658 }
659 // The store function writes the computed contraction result from the private memory of each thread to global
660 // memory. The store function is independent of the local and no-local concepts, so that it can be abstracted out in the
661 // base class.
662 template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
663 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes,
664 StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) const {
665 auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
666 return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N);
667 };
668 // When local memory is not used, M and N are both accessed in a coalesced way. However, when local memory is
669 // available, the K*N block is transposed in local memory to N*K; therefore, each block operates on a blockId *
670 // WorkLoadPerThreadN slice of N.
671 constexpr StorageIndex GlobalNStride = contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN;
672 EIGEN_UNROLL_LOOP
673 for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
674 // output leading dimension
675 StorageIndex outputLD = 0;
676 // When local memory is used, PrivateNStride is always 1 because the coalesced access on N is loaded into local
677 // memory, and extracting from local to global is the same as in the non-transposed version. However, when local memory
678 // is not used and RHS is transposed, we packetize the load for RHS.
679 EIGEN_UNROLL_LOOP
680 for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
681 StorageIndex globalRow = mGlobalOffset;
682 EIGEN_UNROLL_LOOP
683 for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
684 PacketReturnType privetOut = privateRes[wLPTM];
685 if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
686 // Store the final results in C. The C matrix always has M as the first StorageIndex and N as the second
687 // StorageIndex; therefore it always has a coalesced layout.
688 write<data_source::global_mem>(privetOut, out_ptr + outputLD + globalRow);
689 } else {
690 EIGEN_UNROLL_LOOP
691 for (StorageIndex mId = 0; mId < PacketSize; mId++) {
692 StorageIndex mOffset = globalRow + mId;
693 if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) {
694 out_ptr[mOffset + outputLD] =
695 Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privetOut);
696 }
697 }
698 }
699 globalRow += (PacketSize * Properties::LocalThreadSizeM);
700 }
701 outputLD += triple_dim.M;
702 privateRes += Properties::WorkLoadPerThreadM / PacketSize;
703 }
704 out_ptr += (GlobalNStride * outputLD);
705
706 nGlobalOffset += (PrivateNStride * GlobalNStride);
707 }
708 }
709 // when no local memory is used the following extract_block will be enabled
710 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
711 contraction_type contract_tp = contraction_tp>
712 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<contract_tp == contraction_type::no_local> extract_block(
713 const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &,
714 const StorageIndex &ncOffset, const StorageIndex cOffset) const {
715 constexpr StorageIndex LocalThreadSizeNC =
716 InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
717 constexpr StorageIndex WorkLoadPerThreadNC =
718 InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
719 const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
720
721 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
722 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
723 (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
724 };
725 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
726 StorageIndex cIndex = cOffset;
727
728 EIGEN_UNROLL_LOOP
729 for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
730 StorageIndex ncIndex = ncOffset;
731 EIGEN_UNROLL_LOOP
732 for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
733 if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
734 auto val =
735 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
736 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
737
738 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
739 data_source::private_mem>(val, private_ptr);
740 } else {
741 EIGEN_UNROLL_LOOP
742 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
743 const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
744 const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
745 OutScalar val =
746 (ncInd < NC && cInd < triple_dim.K)
747 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
748 inpt, ncInd, cInd, ld)
749 : OutScalar(0);
750 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
751 data_source::private_mem>(
752 val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
753 ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
754 }
755 }
756
757 // If it is LHS, we have to load it packetised when the packet size is > 1, because the output is coalesced. So
758 // even if M is not accessed in coalesced mode, we have to load packet_size m values per thread.
759 ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
760 ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC
761 : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
762 private_ptr += InputBlockProperties::nc_stride;
763 }
764 // The previous for loop (private_ptr += (ncId * nc_stride)) has already advanced the pointer by one WorkLoadPerThreadNC.
765 private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
766 cIndex += InputBlockProperties::c_stride;
767 }
768 }
769 template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
770 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract(
771 const StorageIndex &linearLocalThreadId) {
772 const StorageIndex localThreadNC =
773 (InputBlockProperties::is_coalesced_layout)
774 ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
775 : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
776 const StorageIndex localThreadC =
777 (InputBlockProperties::is_coalesced_layout)
778 ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
779 : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
780 return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
781 }
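  // For example, with TileSizeDimNC = 64, nc_stride = 4 and TileSizeDimK = 16 with c_stride = 1 (the coalesced case),
  // the 256 threads of a work-group are laid out as a 16 x 16 grid over the (NC/4, K) block, so consecutive linear
  // thread ids load consecutive packets along the coalesced dimension.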
782
783 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
784 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<db && ctp == contraction_type::local> sync_mem(
785 const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
786 db_offset = !db_offset;
787 }
788
789 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
790 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!db && ctp == contraction_type::local> sync_mem(
791 const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
792 itemID.barrier(cl::sycl::access::fence_space::local_space);
793 }
794
795 template <contraction_type ctp = contraction_tp>
796 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<ctp == contraction_type::no_local> sync_mem(
797 const cl::sycl::nd_item<1> &, bool &) noexcept {
798 return;
799 }
800
801 template <bool need_sync, contraction_type ctp = contraction_tp>
802 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<need_sync && ctp == contraction_type::no_local>
803 sync_thread(const cl::sycl::nd_item<1> &
804#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
805 itemID
806#endif
807 ) noexcept {
808#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
809 itemID.barrier(cl::sycl::access::fence_space::local_space);
810#else
811 return;
812#endif
813 }
814 template <bool need_sync, contraction_type ctp = contraction_tp>
815 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<need_sync && ctp == contraction_type::local>
816 sync_thread(const cl::sycl::nd_item<1> &itemID) {
817 itemID.barrier(cl::sycl::access::fence_space::local_space);
818 }
819 template <bool need_sync>
820 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!need_sync> sync_thread(const cl::sycl::nd_item<1> &) {
821 return;
822 }
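  // Summary of the synchronisation strategy above: with double buffering, sync_mem only flips db_offset between
  // K-tiles because the next tile is extracted into the other half of the scratch; without double buffering a real
  // barrier is required before the scratch is overwritten; the no-local kernel only needs the optional ARM
  // cache-optimisation barrier in sync_thread.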
823
824 template <bool is_internal_block>
825 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID,
826 ThreadProperties<StorageIndex> &thread_properties,
827 TiledMemory &tiled_input_block,
828 PacketReturnType *privateRes,
829 bool &db_offset) const {
830 // Tiling the Rhs block from global to local memory
831 extract_block<RHSBlockProperties, is_internal_block>(
832 rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
833 tiled_input_block.rhs_extract_index,
834 contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset,
835 thread_properties.kGroupOffset - thread_properties.kSize);
836
837 sync_thread<contraction_tp == contraction_type::no_local>(itemID);
838
839 // Tiling the Lhs block from global to local memory
840 extract_block<LHSBlockProperties, is_internal_block>(
841 lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
842 tiled_input_block.lhs_extract_index,
843 contraction_tp == contraction_type::local ? thread_properties.mGroupOffset : thread_properties.mGlobalOffset,
844 thread_properties.kGroupOffset - thread_properties.kSize);
845
846 // itemID.barrier(cl::sycl::access::fence_space::local_space);
847 sync_thread<contraction_tp == contraction_type::local>(itemID);
848 // switch to compute mode
849 StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
850 StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
851 // Loop over the values of a single tile
852 for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
853 compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset,
854 tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
855 lhs_offset += LSDL;
856 rhs_offset += LSDR;
857 }
858 // computing the K index for the next tile
859 thread_properties.kSize -= Properties::TileSizeDimK;
860 sync_mem(itemID, db_offset);
861 }
862
863 // when local memory is available the following compute_panel will be enabled
864 template <bool is_internal_block, typename OutPtr>
865 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID,
866 ThreadProperties<StorageIndex> &thread_properties,
867 OutPtr out_ptr) const {
868 auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
869 // Allocate register space
870 PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = {
871 PacketReturnType{0}};
872 bool db_offset = 0;
873
874 while (thread_properties.kSize >= Properties::TileSizeDimK) {
875 compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
876 }
877 if (thread_properties.kSize > 0) {
878 compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
879 }
880
881 // Storing the final results in the output
882 store<is_internal_block,
883 contraction_tp == contraction_type::local ? static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>(
884 out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
885 thread_properties.nGlobalOffset);
886 }
887 // When local memory is available the following extract_block will be enabled
888 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
889 contraction_type contract_tp = contraction_tp>
890 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<contract_tp == contraction_type::local> extract_block(
891 const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex> &local_index,
892 const StorageIndex &ncOffset, const StorageIndex cOffset) const {
893 constexpr StorageIndex TileSizeDimNC =
894 InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
895 constexpr StorageIndex LoadPerThread =
896 InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
897 constexpr StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL;
898 static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
899 (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
900 " LocalOffset must be divisible by stride");
901 const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
902 StorageIndex localThreadNC = local_index.first;
903 StorageIndex localThreadC = local_index.second;
904 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
905 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
906 (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
907 };
908 EIGEN_UNROLL_LOOP
909 for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
910 const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
911 const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
912 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
913 if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
914 auto val =
915 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
916 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
917 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
918 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
919 (InputBlockProperties::c_stride * localThreadC * LSD));
920 } else {
921 EIGEN_UNROLL_LOOP
922 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
923 const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
924 const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
925 OutScalar val =
926 (nCInd < NC && cInd < triple_dim.K)
927 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
928 inpt, nCInd, cInd, ld)
929 : OutScalar(0);
930
931 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
932 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
933 (InputBlockProperties::is_coalesced_layout ? i : 0) +
934 ((InputBlockProperties::c_stride * localThreadC +
935 (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
936 LSD));
937 }
938 }
939 localThreadNC += (InputBlockProperties::is_coalesced_layout)
940 ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
941 : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
942 localThreadC += (InputBlockProperties::is_coalesced_layout)
943 ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
944 : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
945 }
946 }
947};
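// In outline, each work-item of TensorContractionKernel repeatedly (1) extracts an LHS and an RHS tile into local or
// private memory via extract_block, (2) accumulates a WorkLoadPerThreadM x WorkLoadPerThreadN register block with
// compute_block_per_tile while walking over TileSizeDimK, and (3) once the K loop is exhausted writes the register
// block back via store, either to the final output or to a per-group partial buffer when IsFinal is false.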
948
949#ifndef EIGEN_SYCL_DISABLE_GEMV
950
992template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
993 typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
994struct GeneralVectorTensor {
995 typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
996 PacketReturnType;
997 static constexpr int PacketSize =
998 Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
999 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1000
1001 static constexpr StorageIndex OutScratchOffset =
1002 KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1003
1004 // Since the access layout for a vector can always be coalesced, when LHS is a vector, we pass false and false to make
1005 // sure that the !^ is true. When RHS is a vector, we pass true and true to make sure that the !^ is true.
1006 typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
1007 VecBlockProperties;
1008
1009 Scratch scratch;
1010 const VectorMapper vec;
1011 const TensorMapper mat;
1012 OutAccessor out_res;
1013 const StorageIndex nonContractGroupSize;
1014 const StorageIndex nonContractDim;
1015 const StorageIndex contractDim;
1016
1017 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
1018 const TensorMapper mat_, OutAccessor out_res_,
1019 const StorageIndex nonContractGroupSize_,
1020 const StorageIndex nonContractDim_,
1021 const StorageIndex contractDim_)
1022 : scratch(scratch_),
1023 vec(vec_),
1024 mat(mat_),
1025 out_res(out_res_),
1026 nonContractGroupSize(nonContractGroupSize_),
1027 nonContractDim(nonContractDim_),
1028 contractDim(contractDim_) {}
1029
1030 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) const {
1031 auto scratch_ptr = scratch.get_pointer();
1032 const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
1033 StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
1034 : linearLocalThreadId % Properties::LocalThreadSizeNC;
1035 StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
1036 : linearLocalThreadId / Properties::LocalThreadSizeNC;
1037 const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
1038 const StorageIndex nonContractGroupId =
1039 is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
1040 const StorageIndex contractGroupId =
1041 is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
1042 auto out_ptr = out_res + (IsFinal ? 0 : contractGroupId * nonContractDim);
1043
1044 const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
1045 const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
1046 auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1047 const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
1048 const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
1049 auto local_output = scratch_ptr + OutScratchOffset;
1050 const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
1051 contractDim - contractGroupOffset >= Properties::TileSizeDimC;
1052 is_internal
1053 ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
1054#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1055 scratch_ptr, contractGroupOffset,
1056#endif
1057 nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1058 nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
1059 : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
1060#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1061 scratch_ptr, contractGroupOffset,
1062#endif
1063 nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1064 nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
1065 }
1066 template <bool is_internal_block, typename OutPtr>
1067 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
1068 const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
1069 OutPtr out_ptr,
1070#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1071 OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
1072#endif
1073 const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
1074 StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
1075 StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
1076 OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
1077 // Reading the vector
1078#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1079 const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
1080 extract_block<VecBlockProperties, is_internal_block, KFactor,
1081 Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
1082 vectorOffset, contractDim);
1083
1084 itemID.barrier(cl::sycl::access::fence_space::local_space);
1085 auto in_scratch_ptr = scratch_ptr + contractId;
1086#endif
1087
1088 StorageIndex privateOffsetC = 0;
1089 EIGEN_UNROLL_LOOP
1090 for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
1091 StorageIndex privateOffsetNC = 0;
1092 bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
1093#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1094 auto vecScalar = *in_scratch_ptr;
1095#else
1096 auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
1097 ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
1098 is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
1099 : OutScalar(0);
1100#endif
1101 EIGEN_UNROLL_LOOP
1102 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1103 auto matScalar = (check_boundary<is_internal_block>(
1104 contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
1105 ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
1106 : globalNonContractDimOffset + privateOffsetNC,
1107 is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
1108 : globalContractDimOffset + privateOffsetC)
1109 : OutScalar(0);
1110
1111 outScalar[j] = ::Eigen::internal::pmadd(matScalar, vecScalar, outScalar[j]);
1112 privateOffsetNC += Properties::LocalThreadSizeNC;
1113 }
1114 privateOffsetC += Properties::LocalThreadSizeC;
1115#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1116 in_scratch_ptr += Properties::LocalThreadSizeC;
1117#endif
1118 }
1119
1120 auto out_scratch_ptr = local_output + outScratchIndex;
1121 // Each block of 16*16 elements in shared memory should reduce to 16*1
1122 EIGEN_UNROLL_LOOP
1123 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1124 *out_scratch_ptr = outScalar[j];
1125
1126 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1127 }
1128 if (is_lhs_vec) {
1129 nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
1130 contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
1131 outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1132 }
1133
1134 out_scratch_ptr = local_output + outScratchIndex;
1135 EIGEN_UNROLL_LOOP
1136 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1137 EIGEN_UNROLL_LOOP
1138 for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
1139 itemID.barrier(cl::sycl::access::fence_space::local_space);
1140 if (contractId < offset) {
1141 StorageIndex myNeigbourId = (Properties::LocalThreadSizeNC * offset);
1142 *out_scratch_ptr += out_scratch_ptr[myNeigbourId];
1143 }
1144 }
1145 // moving to next 16 by 16 block
1146 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1147 }
1148
1149 if (contractId == 0) {
1150 out_scratch_ptr = local_output + nonContractId;
1151 StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
1152 out_ptr += global_final_offset;
1153 EIGEN_UNROLL_LOOP
1154 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1155 if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
1156 auto res = *out_scratch_ptr;
1157
1158 *out_ptr = res;
1159 out_ptr += Properties::LocalThreadSizeNC;
1160 }
1161 // moving to next 16 by 16 block to ge the next 16 reduced elements
1162 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1163 if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
1164 }
1165 }
1166 }
1167
1168 template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
1169 typename Local>
1170 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
1171 const StorageIndex &linearLocalThreadId,
1172 const StorageIndex &cOffset, const StorageIndex &C) {
1173 local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
1174 StorageIndex cIndex = cOffset;
1175 for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
1176 if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
1177 auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
1178 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
1179 cIndex, StorageIndex(1));
1180 write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
1181 } else {
1182 EIGEN_UNROLL_LOOP
1183 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
1184 OutScalar val =
1185 (cIndex + i < C)
1186 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
1187 inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
1188 : OutScalar(0);
1189 write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
1190 }
1191 }
1192 local_ptr += InputBlockProperties::c_stride * GroupSize;
1193 cIndex += InputBlockProperties::c_stride * GroupSize;
1194 }
1195 }
1196};
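// GeneralVectorTensor handles matrix-vector contractions: the vector slice owned by the work-group is optionally
// staged in local memory, each thread accumulates WorkLoadPerThreadNC partial outputs over its WorkLoadPerThreadC
// contracting elements, and a local tree reduction over LocalThreadSizeC collapses them to one value per
// non-contracting index before the results are written out.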
1197#endif
1198
1199#ifndef EIGEN_SYCL_DISABLE_SCALAR
1200
1232template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
1233 typename RhsMapper, typename StorageIndex, bool Vectorizable>
1234struct GeneralScalarContraction {
1235 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1236 Scratch scratch;
1237 const LhsMapper lhs;
1238 const RhsMapper rhs;
1239 OutAccessor out_res;
1240 const StorageIndex rng;
1241
1242 EIGEN_DEVICE_FUNC GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_, const RhsMapper rhs_,
1243 OutAccessor out_res_, const StorageIndex rng_)
1244 : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {}
1245
1246 EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) const {
1247 auto out_ptr = out_res;
1248 OutScalar *scratch_ptr = scratch.get_pointer();
1249
1250 StorageIndex globalid = itemID.get_global_id(0);
1251 StorageIndex localid = itemID.get_local_id(0);
1252 OutScalar accumulator = OutScalar(0);
1253 for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
1254 accumulator = Eigen::internal::pmadd(lhs(0, i), rhs(i, 0), accumulator);
1255 }
1256 auto out_scratch_ptr = scratch_ptr + localid;
1257 *out_scratch_ptr = accumulator;
1258 for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
1259 itemID.barrier(cl::sycl::access::fence_space::local_space);
1260 if (localid < offset) {
1261 *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
1262 }
1263 }
1264 if (localid == 0) {
1265 out_ptr[itemID.get_group(0)] = accumulator;
1266 }
1267 }
1268};
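// GeneralScalarContraction reduces a pure dot product (M == N == 1): each work-item accumulates a grid-strided slice
// of the K range, the work-group tree-reduces those partial sums in local memory, and one partial result per
// work-group is written out; the caller (launchSC in the evaluator below) is expected to finish the reduction across
// work-groups.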
1269#endif
1270
1271} // namespace internal
1272} // namespace TensorSycl
1273
1274template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1275struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
1276 Eigen::SyclDevice>
1277 : public TensorContractionEvaluatorBase<TensorEvaluator<
1278 const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
1279 static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
1280 "SYCL tensor contraction does not support output kernels.");
1281
1282 typedef Eigen::SyclDevice Device;
1283
1284 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1285 typedef TensorContractionEvaluatorBase<Self> Base;
1286 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1287 typedef std::remove_const_t<typename XprType::Scalar> Scalar;
1288 typedef typename XprType::Index StorageIndex;
1289 typedef typename XprType::CoeffReturnType CoeffReturnType;
1290 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
1291 typedef typename Base::Storage Storage;
1292 typedef typename Base::EvaluatorPointerType EvaluatorPointerType;
1293 struct TripleDim {
1294 const StorageIndex M;
1295 const StorageIndex N;
1296 const StorageIndex K;
1297 TripleDim(const StorageIndex M_, const StorageIndex N_, const StorageIndex K_) : M(M_), N(N_), K(K_) {}
1298 };
1299 enum {
1300 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
1301 BlockAccess = false,
1302 };
1303
1304 static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
1305 static constexpr int LDims = Base::LDims;
1306 static constexpr int RDims = Base::RDims;
1307 static constexpr int ContractDims = Base::ContractDims;
1308
1309 typedef array<StorageIndex, LDims> left_dim_mapper_t;
1310 typedef array<StorageIndex, RDims> right_dim_mapper_t;
1311
1312 typedef array<StorageIndex, ContractDims> contract_t;
1313 typedef array<StorageIndex, LDims - ContractDims> left_nocontract_t;
1314 typedef array<StorageIndex, RDims - ContractDims> right_nocontract_t;
1315
1316 static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
1317
1318 typedef DSizes<StorageIndex, NumDims> Dimensions;
1319
1320 typedef TensorEvaluator<typename Base::EvalLeftArgType, Device> LeftEvaluator;
1321 typedef TensorEvaluator<typename Base::EvalRightArgType, Device> RightEvaluator;
1322 typedef std::remove_const_t<typename LeftEvaluator::CoeffReturnType> LhsScalar;
1323 typedef std::remove_const_t<typename RightEvaluator::CoeffReturnType> RhsScalar;
1324
1325 typedef typename LeftEvaluator::Dimensions LeftDimensions;
1326 typedef typename RightEvaluator::Dimensions RightDimensions;
1327
1328 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
1329 struct input_mapper_propertis {
1330 static constexpr bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
1331 static constexpr bool is_rhs_matrix =
1332 (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
1333 };
1334
1335 TensorEvaluator(const XprType &op, const Device &device) : Base(op, device) {}
1336
1337 // We need to redefine this method to make nvcc happy
1338 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
1339 this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1340 this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1341 if (!data) {
1342 this->m_result = this->m_device.get(
1343 static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
1344 data = this->m_result;
1345 }
1346 evalToSycl(data);
1347 return (this->m_result != NULL);
1348 }
1349 const Eigen::SyclDevice &device() const { return this->m_device; }
1350 void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
1351 if (this->m_lhs_inner_dim_contiguous) {
1352 if (this->m_rhs_inner_dim_contiguous) {
1353 if (this->m_rhs_inner_dim_reordered) {
1354 evalTyped<true, true, true, Unaligned>(buffer);
1355 } else {
1356 evalTyped<true, true, false, Unaligned>(buffer);
1357 }
1358 } else {
1359 if (this->m_rhs_inner_dim_reordered) {
1360 evalTyped<true, false, true, Unaligned>(buffer);
1361 } else {
1362 evalTyped<true, false, false, Unaligned>(buffer);
1363 }
1364 }
1365 } else {
1366 if (this->m_rhs_inner_dim_contiguous) {
1367 if (this->m_rhs_inner_dim_reordered) {
1368 evalTyped<false, true, true, Unaligned>(buffer);
1369 } else {
1370 evalTyped<false, true, false, Unaligned>(buffer);
1371 }
1372 } else {
1373 if (this->m_rhs_inner_dim_reordered) {
1374 evalTyped<false, false, true, Unaligned>(buffer);
1375 } else {
1376 evalTyped<false, false, false, Unaligned>(buffer);
1377 }
1378 }
1379 }
1380 }
1381
1382 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1383 void evalTyped(typename Base::EvaluatorPointerType buffer) const {
1384 const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
1385 typedef internal::TensorContractionInputMapper<
1386 LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t,
1387 PacketType<CoeffReturnType, Device>::size, lhs_inner_dim_contiguous, false, Unaligned, MakePointer>
1388 LhsMapper;
1389
1390 typedef internal::TensorContractionInputMapper<RhsScalar, StorageIndex, internal::Rhs, RightEvaluator,
1391 right_nocontract_t, contract_t,
1392 PacketType<CoeffReturnType, Device>::size, rhs_inner_dim_contiguous,
1393 rhs_inner_dim_reordered, Unaligned, MakePointer>
1394 RhsMapper;
1395
1396 // initialize data mappers
1397 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1398 this->m_left_contracting_strides, this->m_k_strides);
1399
1400 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1401 this->m_right_contracting_strides, this->m_k_strides);
1402
1403#ifndef EIGEN_SYCL_DISABLE_SCALAR
1404 if (triple_dim.M == 1 && triple_dim.N == 1) {
1405 launchSC(buffer, lhs, rhs, triple_dim.K);
1406 } else
1407#endif
1408#ifndef EIGEN_SYCL_DISABLE_GEMV
1409 if (triple_dim.M != 1 && triple_dim.N == 1) {
1410 LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
1411 } else if (triple_dim.M == 1 && triple_dim.N != 1) {
1412 LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
1413 } else // This is equivalent to if (m != 1 && n != 1)
1414#endif
1415 {
1416 typedef input_mapper_propertis<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
1417 inpt_mapper_properties;
1418#ifndef EIGEN_SYCL_DISABLE_SKINNY
1419 bool skinny = false;
1420 auto platform_name = this->device().getPlatformName();
1421 // This is based on empirical calculation for AMD r9-nano and Fiji
1422 if (platform_name.find("AMD") == 0) {
1423 skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
1424 ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
1425 (uint64_t(triple_dim.M * triple_dim.N) < uint64_t(triple_dim.K)));
1426 } else {
1427 skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
1428 ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
1429 ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
1430 }
1431 if (skinny)
1432 adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1433 else
1434#endif // EIGEN_SYCL_DISABLE_SKINNY
1435 adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1436 }
1437 }
1438
1439 template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
1440 void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1441 const TripleDim &triple_dim) const {
1442#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1443 if (device().has_local_memory()) {
1444 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
1445 launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
1446 buffer, lhs, rhs, triple_dim);
1447 }
1448#endif
1449#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
1450 if (!(device().has_local_memory())) {
1451 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters;
1452 launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
1453 buffer, lhs, rhs, triple_dim);
1454 }
1455#endif
1456 }
1457
1458 template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
1459 typename Properties, typename LhsMapper, typename RhsMapper>
1460 void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1461 const TripleDim &triple_dim) const {
1462 const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
1463 const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
1464 const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
1465 const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;
1466
1467 const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
1468 StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
1469 StorageIndex groupSizeK =
1470 skinny
1471 ? std::max(std::min(totalTilesK,
1472 (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
1473 (groupSizeM * groupSizeN)),
1474 StorageIndex(1))
1475 : StorageIndex(1);
1476
1477 const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;
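    // Grid shape: the output is covered by groupSizeM x groupSizeN tiles. For skinny problems the
    // K dimension is additionally split over groupSizeK workgroups, capped so the total grid stays
    // near four times the device's compute-unit count (rounded to a power of two) and never
    // exceeds totalTilesK; each K-group then walks numTilesPerGroup K-tiles and emits a partial
    // result. For non-skinny problems groupSizeK stays 1 and one kernel writes the output directly.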
1478
1479 const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;
1480
1481 const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
1482 const StorageIndex globalRange = totalGroupSize * localRange;
1483
1484 const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
1485 ? ((Properties::DoubleBuffer + 1) *
1486 (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
1487 ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
1488 (Properties::TileSizeDimN + Properties::BC))
1489 : StorageIndex(1);
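    // Local scratch: for the shared-local-memory variant this holds one LHS panel of
    // (TileSizeDimM + BC) x TileSizeDimK elements plus one RHS panel of
    // TileSizeDimK x (TileSizeDimN + BC) elements, doubled when DoubleBuffer is enabled; BC adds
    // one element of padding intended to avoid local-memory bank conflicts. The no-local variant
    // only needs the dummy single-element scratch.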
1490
1491 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
1492 if (groupSizeK == 1) {
1493 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1494 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1495 PacketAccess, input_mapper_properties, true, ct>
1496 ContractKernelName;
1497 device()
1498 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1499 lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim)
1500 .wait();
1501 } else {
1502 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1503 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1504 PacketAccess, input_mapper_properties, false, ct>
1505 ContractKernelName;
1506 CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
1507 device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
1508 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
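      // groupSizeK > 1: each K-group writes its M x N partial tile into this temporary buffer
      // (M * N elements per K-group). SecondStepPartialReduction then sums the groupSizeK partial
      // values of every output coefficient into `buffer`, using one work-item per output element,
      // rounded up to a multiple of localRange.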
1509
1510 device()
1511 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1512 lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
1513 triple_dim)
1514 .wait();
1515
1516 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1517 auto op = Op();
1518 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1519 EvaluatorPointerType, Op>
1520 ReductionKernel;
1521
1522 device()
1523 .template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1524 tmp_global_accessor, buffer,
1525 cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
1526 Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
1527 cl::sycl::range<1>(localRange)),
1528 StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK)
1529 .wait();
1530 device().deallocate_temp(temp_pointer);
1531 }
1532 }
1533
1534#ifndef EIGEN_SYCL_DISABLE_GEMV
1535 template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
1536 void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat,
1537 StorageIndex NC, StorageIndex C) const {
1538 const StorageIndex nonContractDim = NC;
1539 constexpr StorageIndex NCFactor = 1;
1540 constexpr StorageIndex CFactor = 1;
1541 constexpr StorageIndex NCWindow = 16;
1542 typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor>
1543 Properties;
1544 const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
1545 const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
1546 const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
1547 const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
1548 const StorageIndex globalRange =
1549 (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
1550 const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
1551 const StorageIndex scratchSize =
1552 (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1553 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
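    // The vector-tensor launch flattens a (non-contracting x contracting) tile grid into a 1-D
    // nd_range. When the contracting dimension spans more than one group of tiles
    // (cNumGroups > 1), each group writes a partial dot product into a temporary buffer of
    // nonContractDim * cNumGroups values and SecondStepPartialReduction folds those partials into
    // the final result; otherwise the single-pass kernel variant writes `buffer` directly.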
1554 if (cNumGroups > 1) {
1555 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1556 TensorMapper, StorageIndex, Properties, CFactor, false,
1557 is_lhs_vec, false>
1558 ContractKernelName;
1559 CoeffReturnType *temp_pointer =
1560 static_cast<CoeffReturnType *>(device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType)));
1561 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1562
1563 device()
1564 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1565 vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C)
1566 .wait();
1567
1568 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1569 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1570 EvaluatorPointerType, Op>
1571 ReductionKernel;
1572
1573 device()
1574 .template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1575 tmp_global_accessor, buffer,
1576 cl::sycl::nd_range<1>(
1577 cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
1578 cl::sycl::range<1>(localRange)),
1579 StorageIndex(1), Op(), nonContractDim, cNumGroups)
1580 .wait();
1581 device().deallocate_temp(temp_pointer);
1582 } else {
1583 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1584 TensorMapper, StorageIndex, Properties, CFactor, false,
1585 is_lhs_vec, true>
1586 ContractKernelName;
1587 device()
1588 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1589 vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C)
1590 .wait();
1591 }
1592 }
1593#endif
1594
1595#ifndef EIGEN_SYCL_DISABLE_SCALAR
1596 template <typename LhsMapper, typename RhsMapper>
1597 EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1598 StorageIndex K) const {
1599 EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) &
1600 (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
1601 "The Local thread size must be a power of 2 for the reduction "
1602 "operation");
1603 constexpr StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;
1604
1605    // Here we cap the reduction at two steps: empirically, we see better performance when each
1606    // thread reduces at least 512 elements on its own.
1607 const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
1608 const StorageIndex global_range = num_work_group * local_range;
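    // num_work_group is either 1 (one workgroup reduces all of K in a single pass) or local_range
    // (taken when K > 512 * local_range): in the latter case each of the local_range workgroups
    // writes one partial sum into a temporary buffer and SecondStepFullReducer combines those
    // partials within a single workgroup of the same size.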
1609
1610 typedef Eigen::TensorSycl::internal::GeneralScalarContraction<
1611 CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false>
1612 ContractKernelName;
1613 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
1614 if (num_work_group > 1) {
1615 CoeffReturnType *temp_pointer =
1616 static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType)));
1617 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1618 device()
1619 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
1620 thread_range, local_range, K)
1621 .wait();
1622 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1623 typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType,
1624 EvaluatorPointerType, StorageIndex, local_range>
1625 GenericRKernel;
1626 device()
1627 .template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
1628 tmp_global_accessor, buffer,
1629 cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range,
1630 Op())
1631 .wait();
1632 device().deallocate_temp(temp_pointer);
1633 } else {
1634 device()
1635 .template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
1636 local_range, K)
1637 .wait();
1638 }
1639 }
1640#endif
1641
1642 EIGEN_STRONG_INLINE void cleanup() {
1643 this->m_leftImpl.cleanup();
1644 this->m_rightImpl.cleanup();
1645
1646 if (this->m_result) {
1647 this->m_device.deallocate_temp(this->m_result);
1648      this->m_result = nullptr;
1649 }
1650 }
1651};
1652} // namespace Eigen
1653#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H