#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

// evaluator for thread pool device
#ifdef EIGEN_USE_THREADS

#include "./InternalHeaderCheck.h"

namespace Eigen {
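// Illustrative usage (a sketch, not part of this header): this evaluator is
// selected automatically when a contraction expression is assigned through a
// ThreadPoolDevice, assuming EIGEN_USE_THREADS is defined before including the
// Tensor module:
//
//   Eigen::ThreadPool pool(4);
//   Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);
//   Eigen::Tensor<float, 2> lhs(64, 32), rhs(32, 48), out(64, 48);
//   lhs.setRandom();
//   rhs.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   out.device(device) = lhs.contract(rhs, dims);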
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
                       ThreadPoolDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice>> {
  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>
      Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
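  // Most of the code below assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, the LHS and RHS are swapped so that the same kernels
  // can be reused: to compute A * B = C the code pretends B is the LHS and A is
  // the RHS.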
  typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>
      EvalLeftArgType;
  typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>
      EvalRightArgType;
  static constexpr int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static constexpr int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static constexpr int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
  typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
  typedef internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
  TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}
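  // Entry points used by the evaluator base class: evalProduct blocks the
  // calling thread until the contraction is done, while evalProductAsync
  // returns immediately and invokes `done` once the result buffer is ready.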
  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
  }

  template <typename EvalToCallback, int Alignment>
  void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
    evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
  }
  template <typename DoneCallback, int Alignment>
  void evalProductImpl(Scalar* buffer, DoneCallback done) const {
    static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
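    // Compute a set of algorithm parameters:
    //  - kernel block sizes (bm, bn, bk),
    //  - task grain sizes (number of kernels executed per task: gm, gn),
    //  - whether to shard the m x n output space by rows or by columns,
    //  - whether to pack lhs and rhs in parallel,
    // and the number of threads to use, then dispatch either to the
    // shard-by-inner-dimension path or to the parallel packing/kernel state
    // machine below.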
    bool shard_by_col = shardByCol(m, n, 2);

    // Compute a rough estimate of the block sizes first.
    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n,
                                                                                                              2);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n,
                                                                                                              2);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    }

    // Compute the optimal number of threads.
    const TensorOpCost cost = contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads =
        TensorCostModel<ThreadPoolDevice>::numThreads(static_cast<double>(n) * m, cost, this->m_device.numThreads());
    int num_threads_by_k = numThreadsInnerDim(m, n, k);
    if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
      // It is more effective to shard by the inner (contraction) dimension.
      if (IsEvalInSyncMode) {
        EvalShardedByInnerDimContext<DoneCallback> ctx(this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx.template run<Alignment>();
      } else {
        auto* ctx =
            new EvalShardedByInnerDimContext<DoneCallback>(this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx->template runAsync<Alignment>();
      }
      return;
    }
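    // With a trivial amount of parallelizable work (or n == 1) fall back to the
    // sequential single-threaded implementation.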
    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
      if (!IsEvalInSyncMode) done();
      return;
    }
    // Now that we know the number of threads, recompute sharding and blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(
          k, m, n, num_threads);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(
          k, m, n, num_threads);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    }

    // Number of kernels for each dimension.
    Index nm0 = numext::div_ceil(m, bm);
    Index nn0 = numext::div_ceil(n, bn);
    Index nk = numext::div_ceil(k, bk);

    // Calculate the task grain size (number of kernels executed per task).
    Index gm = 1;
    Index gn = 1;
    // If we are sharding by column, then we prefer to coarsen rows first.
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }
    // Number of tasks in each dimension.
    Index nm = numext::div_ceil(nm0, gm);
    Index nn = numext::div_ceil(nn0, gn);
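    // If there is enough concurrency in the sharding dimension alone, we choose
    // not to parallelize over the other dimension and execute all kernels for a
    // given shard synchronously on the packing thread. This reduces parallelism
    // from nm x nn tasks down to nn (shard_by_col) or nm tasks, but improves
    // locality of the packed blocks.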
    const Index sharding_dim_tasks = shard_by_col ? nn : nm;
    const int num_worker_threads = this->m_device.numThreadsInPool();

    // With a small number of threads we want to make sure we do not reduce
    // parallelism too much. With a large number of threads we trade maximum
    // parallelism for better memory locality.
    const float oversharding_factor = num_worker_threads <= 4    ? 8.0
                                      : num_worker_threads <= 8  ? 4.0
                                      : num_worker_threads <= 16 ? 2.0
                                      : num_worker_threads <= 32 ? 1.0
                                      : num_worker_threads <= 64 ? 0.8
                                                                 : 0.6;

    const bool parallelize_by_sharding_dim_only = sharding_dim_tasks >= oversharding_factor * num_worker_threads;
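    // Last but not least, decide whether we want to issue lhs and rhs packing
    // in parallel, or issue lhs packing first and rhs packing only when lhs
    // packing completes. Parallel packing exposes more parallelism, while
    // sequential packing gives better locality.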
    bool parallel_pack = num_threads >= nm * nn;
    // Also do parallel packing if all data fits into L2$.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <= l2CacheSize() * num_threads)
      parallel_pack = true;
    // But don't do it if we will use each rhs only once. Locality seems to be
    // more important in this case.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
    // Also don't get in the way of the parallelize_by_sharding_dim_only optimization.
    if (parallelize_by_sharding_dim_only) parallel_pack = false;
    if (IsEvalInSyncMode) {
#define CONTEXT_ARGS                                                                                           \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
   parallelize_by_sharding_dim_only, NoCallback())                                                            \
      .run()
      TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
    } else {
#define CONTEXT_ARGS                                                                                           \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
   parallelize_by_sharding_dim_only, std::move(done))
      TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
    }
  }

  // Dummy struct representing an empty DoneCallback (synchronous evaluation).
  struct NoCallback {
    void operator()() { eigen_assert(false && "NoCallback should never be called"); }
  };
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification;

  // Synchronous evaluation notification that blocks the caller thread in Wait().
  template <typename Context>
  class EvalParallelNotification<NoCallback, Context> {
   public:
    EvalParallelNotification(Context*, NoCallback) {}
    void Notify() { done_.Notify(); }
    void Wait() { done_.Wait(); }

   private:
    Eigen::Notification done_;
  };

  // Asynchronous evaluation notification that does not block in Wait().
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification {
   public:
    EvalParallelNotification(Context* ctx, DoneCallback done) : ctx_(ctx), done_(std::move(done)) {}

    void Notify() {
      // Make a copy of the done callback: it is a member of the context, which
      // is deleted on the next line.
      DoneCallback done_copy = std::move(done_);
      delete ctx_;
      done_copy();
    }

    void Wait() {}

   private:
    Context* ctx_;
    DoneCallback done_;
  };
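  // Context orchestrating sync/async parallel contraction evaluation. When it
  // is executed in asynchronous mode, it owns all the shared state that might
  // be accessible by block packing and kernel tasks. Work proceeds in k-slices:
  // lhs/rhs blocks of a slice are packed, the corresponding gebp kernels run,
  // and counters in signal_packing/signal_kernel/signal_switch decide when the
  // next slice may start.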
  template <typename DoneCallback, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  class EvalParallelContext {
   public:
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
                                                   contract_t, internal::packet_traits<LhsScalar>::size,
                                                   lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
                                                   contract_t, internal::packet_traits<RhsScalar>::size,
                                                   rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    EvalParallelContext(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
                        Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0,
                        bool shard_by_col, bool parallel_pack, bool parallelize_by_sharding_dim_only,
                        DoneCallback done)
        : created_by_thread_id_(std::this_thread::get_id()),
          done_(this, std::move(done)),
          device_(self->m_device),
          lhs_(self->m_leftImpl, self->m_left_nocontract_strides, self->m_i_strides, self->m_left_contracting_strides,
               self->m_k_strides),
          rhs_(self->m_rightImpl, self->m_right_nocontract_strides, self->m_j_strides,
               self->m_right_contracting_strides, self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
          tensor_contraction_params_(self->m_tensor_contraction_params),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
          m_(tm),
          n_(tn),
          k_(tk),
          bm_(bm),
          bn_(bn),
          bk_(bk),
          nm_(nm),
          nn_(nn),
          nk_(nk),
          gm_(gm),
          gn_(gn),
          nm0_(nm0),
          nn0_(nn0),
          kernel_(m_, k_, n_, bm_, bk_, bn_),
          num_thread_local_allocations_(0),
          // Reserve 2X more capacity for thread-local values than the number of
          // threads in the pool, to use thread-local storage efficiently when
          // the pool is small.
          thread_local_capacity(2 * (parallelize_by_sharding_dim_only_ ? device_.numThreadsInPool() : 0)),
          // Only one of the Lhs/Rhs thread-local storages is used, depending on
          // shard_by_col, and only when parallelizing by the sharding dimension
          // only.
          lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity, {*this}, {*this}),
          rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0, {*this}, {*this}) {
      // These two options are mutually exclusive.
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));

      for (Index x = 0; x < P; x++) {
        // The normal number of notifications for a k-slice switch is
        // nm_ + nn_ + nm_ * nn_; the first P - 1 slices do not receive
        // notifications from preceding kernels.
        state_switch_[x] =
            x == 0 ? 1 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) + (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] = parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // Kernels generally receive 3 notifications (previous kernel + 2
          // packings), but the first slice gets none from previous kernels.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store((x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1), std::memory_order_relaxed);
        }
      }

      // Allocate memory for the packed rhs/lhs matrices.
      packed_mem_ = kernel_.allocateSlices(device_, /*num_lhs=*/nm_, /*num_rhs=*/nn_,
                                           /*num_slices=*/std::min<Index>(nk_, P - 1), packed_lhs_, packed_rhs_);
      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();

        if (shard_by_col) {
          can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
          for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(device_, /*num_lhs=*/0, /*num_rhs=*/num_blocks,
                                                                  /*num_slices=*/1,
                                                                  /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
        } else {
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(device_, /*num_lhs=*/num_blocks, /*num_rhs=*/0,
                                                                  /*num_slices=*/1, &lhs_thread_local_pre_allocated_,
                                                                  /*rhs_blocks=*/nullptr);
        }
      }
    }
    ~EvalParallelContext() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
        delete[] can_use_thread_local_packed_;
      }
    }
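    // The run() entry point of this context (omitted from this excerpt) kicks
    // off packing of the first k-slice via signal_switch(0, 1) and then waits
    // for overall completion on done_; in asynchronous mode done_.Wait()
    // returns immediately and done_.Notify() eventually deletes the context and
    // invokes the user callback.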
   private:
    std::thread::id created_by_thread_id_;

    // Specialized on the DoneCallback type: blocking for sync evaluation,
    // non-blocking for async evaluation.
    EvalParallelNotification<DoneCallback, EvalParallelContext> done_;

    const Device& device_;
    LhsMapper lhs_;
    RhsMapper rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    OutputKernelType output_kernel_;
    TensorContractionParams tensor_contraction_params_;
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const bool parallelize_by_sharding_dim_only_;
    // Problem, block and grain sizes (see evalProductImpl).
    const Index m_, n_, k_;
    const Index bm_, bn_, bk_;
    const Index nm_, nn_, nk_;
    const Index gm_, gn_;
    const Index nm0_, nn0_;
    TensorContractionKernel kernel_;

    // We pack and process P k-slices ahead (pipelined packing and kernels).
    static constexpr Index P = 3;

    // Temporary storage handle and per-slice packed lhs/rhs blocks.
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];

    // Pre-allocated thread-local memory, used only when parallelizing by the
    // sharding dimension only.
    BlockMemHandle thread_local_pre_alocated_mem_;
    std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
    std::vector<RhsBlock> rhs_thread_local_pre_allocated_;

    std::atomic<int> num_thread_local_allocations_;
    const int thread_local_capacity;
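    // ThreadLocalBlocks is a container for Lhs or Rhs thread-local buffers: it
    // either wraps a grain-sized region of the pre-allocated memory above, or
    // owns blocks allocated on demand for threads that exceed the reserved
    // capacity.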
    template <typename BlockType>
    class ThreadLocalBlocks {
     public:
      ThreadLocalBlocks() = default;

      ThreadLocalBlocks(BlockType* base, size_t grain_size)
          : is_pre_allocated_(true), thread_local_pre_allocated_base_(base), grain_size_(grain_size) {}

      ThreadLocalBlocks(BlockMemHandle mem_handle, std::vector<BlockType> blocks)
          : is_pre_allocated_(false), mem_handle_(std::move(mem_handle)), blocks_(std::move(blocks)) {}

      BlockType& block(int grain_index) {
        eigen_assert(grain_index >= 0);
        eigen_assert(static_cast<size_t>(grain_index) < size());
        return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index] : blocks_[grain_index];
      }

      void Release(EvalParallelContext& ctx) const {
        if (!is_pre_allocated_) {
          ctx.kernel_.deallocate(ctx.device_, mem_handle_);
        }
      }

      size_t size() const { return is_pre_allocated_ ? grain_size_ : blocks_.size(); }

     private:
      bool is_pre_allocated_;

      // Reused pre-allocated thread-local buffers.
      BlockType* thread_local_pre_allocated_base_ = nullptr;
      size_t grain_size_ = 0;

      // Blocks allocated on demand, owned through mem_handle_.
      BlockMemHandle mem_handle_{};
      std::vector<BlockType> blocks_;
    };
    // ThreadLocalBlocks initialization callable: initializes thread-local
    // blocks either from the pre-allocated pool, or by allocating a new block
    // of memory for threads that exceed the reserved capacity.
    template <typename BlockType, bool is_rhs>
    class ThreadLocalBlocksInitialize {
      static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
      static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
      static_assert(kIsLhs || kIsRhs, "Unknown block type");

      using Blocks = ThreadLocalBlocks<BlockType>;

     public:
      ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
          : ctx_(ctx), num_worker_threads_(ctx_.device_.numThreadsInPool()) {}

      void operator()(Blocks& blocks) {
        const int n = ctx_.num_thread_local_allocations_.fetch_add(1, std::memory_order_relaxed);

        if (n >= num_worker_threads_) {
          ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
        } else {
          ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
        }
      }

     private:
      // Without `if constexpr` the calls to TensorContractionKernel::allocateSlices
      // have to live in template specializations; EvalCtx is a template
      // parameter only to make these member specializations legal.
      template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
      struct ThreadLocalBlocksAllocator;

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<RhsBlock> rhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_, /*num_lhs=*/0, /*num_rhs=*/ctx.gn_,
                                                                 /*num_slices=*/1,
                                                                 /*lhs_blocks=*/nullptr, &rhs_blocks);

          blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
          blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
        }
      };

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<LhsBlock> lhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_, /*num_lhs=*/ctx.gm_, /*num_rhs=*/0,
                                                                 /*num_slices=*/1, &lhs_blocks,
                                                                 /*rhs_blocks=*/nullptr);

          blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
          blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
        }
      };

      EvalParallelContext& ctx_;
      const int num_worker_threads_;
    };
    // ThreadLocalBlocks release callable: deallocates blocks that were
    // allocated on demand (pre-allocated blocks are owned by the context).
    template <typename BlockType>
    class ThreadLocalBlocksRelease {
     public:
      using Blocks = ThreadLocalBlocks<BlockType>;
      ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
      void operator()(Blocks& blocks) { blocks.Release(ctx_); }

     private:
      EvalParallelContext& ctx_;
    };

    // ThreadLocalBlocks initialization and release callables.
    using ThreadLocalLhsInit = ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
    using ThreadLocalRhsInit = ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;
    using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
    using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;

    // Thread-local containers for Lhs/Rhs block packs. In practice only one of
    // them is used, depending on the shard_by_col value.
    Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit, ThreadLocalLhsRelease> lhs_thread_local_blocks_;
    Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit, ThreadLocalRhsRelease> rhs_thread_local_blocks_;

    // After a shard of the k-th slice missed its thread-local execution
    // opportunity, thread-local packing can no longer be used for the following
    // slices of that shard.
    std::atomic<bool>* can_use_thread_local_packed_;

    std::atomic<uint8_t>** state_kernel_[P];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
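    // packed_lhs/packed_rhs return the storage for a packed block: either the
    // per-slice shared storage (indexed by k % (P - 1)), or the calling
    // thread's thread-local blocks when the block is packed and consumed by the
    // same task.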
    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(!shard_by_col_);
        ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();

        Index grain_index = m1 - m * gm_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_lhs_[k % (P - 1)][m1];
      }
    }

    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(shard_by_col_);
        ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();

        Index grain_index = n1 - n * gn_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_rhs_[k % (P - 1)][n1];
      }
    }
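    // In pack_lhs/pack_rhs below, thread-local memory may be used for the
    // packed data only when we know for sure that the kernels consuming it will
    // run synchronously in the same task (parallelize_by_sharding_dim_only
    // mode, and the kernel state shows no other pending notifications). Once a
    // slice misses that opportunity, thread-local packing is disabled for the
    // remaining slices of that shard.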
    void pack_lhs(Index m, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
          can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // If we can't guarantee that all kernels in the k-th slice will be
          // executed sequentially in the current thread, it is no longer safe
          // to use thread-local memory in the following slices along k.
          eigen_assert(k > 0);
          can_use_thread_local_packed_[m].store(false, std::memory_order_relaxed);
        }
      }

      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local), lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        eigen_assert(!use_thread_local);
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) {
          bool sync = parallelize_by_sharding_dim_only_ || n == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      }
    }
    void pack_rhs(Index n, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
          can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          eigen_assert(k > 0);
          can_use_thread_local_packed_[n].store(false, std::memory_order_relaxed);
        }
      }

      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (!TensorContractionKernel::HasBeta && k == 0) {
          // If the kernel cannot apply beta, the output has to be cleared
          // before the first gemm call accumulates into it. Zero the relevant
          // output columns here, in parallel with packing.
          std::fill_n(buffer_ + n1 * bn_ * m_, bn(n1) * m_, Scalar(0));
        }
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local), rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) {
          bool sync = parallelize_by_sharding_dim_only_ || m == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      } else {
        eigen_assert(!use_thread_local);
        signal_packing(k);
      }
    }
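    // kernel() runs the gebp kernels of one (m, n) task for the k-th slice.
    // The iteration order is chosen so that the packed block reused across
    // consecutive invocations stays in the faster cache level. The output
    // kernel is applied to an output block once its last k-slice is done.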
    void kernel(Index m, Index n, Index k, bool use_thread_local) {
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);

      // NOTE: output = alpha * LHS * RHS + beta * output.
      const Scalar alpha = Scalar(1);
      const Scalar beta = (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);

      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                           packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_, m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                           packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_, m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
      }
      signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
      signal_switch(k + 2);
    }
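    // The signal_* methods implement the slice state machine: every packing
    // task and kernel decrements an atomic counter when it finishes, and the
    // task that brings a counter to its terminal value is responsible for
    // enqueueing the next stage (more packing, a kernel, or the switch to the
    // next k-slice).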
    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }

    void signal_kernel(Index m, Index n, Index k, bool sync, bool use_thread_local) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) {
        eigen_assert(!use_thread_local);
        return;
      }
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync) {
        kernel(m, n, k, use_thread_local);
      } else {
        eigen_assert(!use_thread_local);
        device_.enqueue([this, m, n, k, use_thread_local]() { kernel(m, n, k, use_thread_local); });
      }
    }
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k-slice: re-initialize the counter for the
      // next iteration and issue the packing tasks.
      state_switch_[k % P] = (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) + nm_ * nn_;
      if (k < nk_) {
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
        // Because kernel completion signals the k + 2 switch, the last two
        // switches are handled below without issuing new tasks.
      } else if (k == nk_) {
        signal_switch(k + 1, parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        done_.Notify();
      }
    }
    // Enqueue all rhs/lhs packing tasks for the k-th slice.
    void enqueue_packing(Index k, bool rhs) { enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs); }

    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        while (end - start > 1) {
          Index mid = (start + end) / 2;
          device_.enqueue([this, mid, end, k, rhs]() { enqueue_packing_helper(mid, end, k, rhs); });
          end = mid;
        }

        // The first packing task (start == 0) must go through the thread pool
        // when we parallelize only by the sharding dimension: it needs the
        // pre-allocated thread-local buffers, and running it inline could start
        // kernels of the next slice before the current one has finished.
        bool pack_async = (start == 0) && (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
                          (k > 0 || std::this_thread::get_id() == created_by_thread_id_);

        if (pack_async) {
          device_.enqueue([this, start, end, k, rhs]() { enqueue_packing_helper(start, end, k, rhs); });
        } else {
          enqueue_packing_helper(start, end, k, rhs);
        }
      }
    }
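    // Block and grain sizes, accounting for a potentially incomplete last
    // block/task. For example, with m_ = 100 and bm_ = 32 there are nm0_ = 4
    // blocks and the last one has 100 + 32 - 32 * 4 = 4 rows.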
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }

    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }

    EvalParallelContext(const EvalParallelContext&) = delete;
    void operator=(const EvalParallelContext&) = delete;
  };
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                                                      rhs_inner_dim_reordered, Alignment>;
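  // EvalShardedByInnerDimContext orchestrates sync/async contraction evaluation
  // when sharding by the inner (contraction) dimension: each task computes a
  // partial m x n product over a k-range into its own buffer, the partial
  // buffers are then summed, and finally the output kernel is applied. In
  // asynchronous mode the context owns all state shared by the block tasks.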
  template <typename DoneCallback>
  struct EvalShardedByInnerDimContext {
    EvalShardedByInnerDimContext(const Self* self, int num_threads, Scalar* result_buffer, Index m_size, Index n_size,
                                 Index k_size, DoneCallback done_callback)
        : evaluator(self),
          m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
          m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
          m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
          result(result_buffer),
          m(m_size),
          n(n_size),
          k(k_size),
          done(std::move(done_callback)),
          buffer_size_bytes(m * n * sizeof(Scalar)),
          block_size(blockSize(k, num_threads)),
          num_blocks(numext::div_ceil<Index>(k, block_size)),
          num_pending_blocks(internal::convert_index<int>(num_blocks)),
          l0_ranges(numext::div_ceil<Index>(num_blocks, l0_size)),
          l0_state(l0_ranges),
          block_buffers(num_blocks) {
      // Keep count of pending gemm tasks for each l0 range.
      for (int i = 0; i < l0_ranges; ++i) {
        const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
        l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
      }

      // Allocate temporary buffers for each block (block 0 accumulates directly
      // into the result buffer).
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        Scalar* buf =
            block_idx == 0 ? result : static_cast<Scalar*>(evaluator->m_device.allocate(buffer_size_bytes));
        block_buffers.emplace_back(buf);
      }
    }
    ~EvalShardedByInnerDimContext() {
      for (Index i = 1; i < num_blocks; ++i) {
        evaluator->m_device.deallocate(block_buffers[i]);
      }
    }

    template <int Alignment>
    void run() {
      Barrier barrier(internal::convert_index<int>(num_blocks));
      eval<Alignment>(barrier, 0, num_blocks);
      barrier.Wait();

      // Aggregate partial sums from l0 ranges.
      aggregateL0Blocks<Alignment>();

      // Apply output kernel.
      applyOutputKernel();
    }

    template <int Alignment>
    void runAsync() {
      evalAsync<Alignment>(0, num_blocks);
    }
   private:
    // The underlying GEMM kernel assumes that k is a multiple of the packet size.
    static const Index packet_size = internal::packet_traits<RhsScalar>::size;

    const Self* evaluator;  // TensorContraction evaluator

    // Fields copied from the TensorContraction evaluator.
    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;

    Scalar* result;
    Index m, n, k;
    DoneCallback done;

    Index buffer_size_bytes;
    Index block_size;
    Index num_blocks;
    std::atomic<int> num_pending_blocks;

    // Partial results are aggregated in ranges of l0_size blocks.
    static const Index l0_size = 4;
    Index l0_ranges;

    // Atomic counter of pending gemm tasks for each l0 range.
    MaxSizeVector<std::atomic<int>> l0_state;  // [0, l0_ranges)

    // Temporary buffers for each block.
    MaxSizeVector<Scalar*> block_buffers;  // [0, num_blocks)
    template <int Alignment>
    void processBlock(Index block_idx, Index begin, Index end) {
      Scalar* buf = block_buffers[block_idx];

      TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
                                  (buf, begin, end,
                                   /*num_threads=*/internal::convert_index<int>(num_blocks)));

      // Check if it was the last task in the l0 range.
      const Index l0_index = block_idx / l0_size;
      const int v = l0_state[l0_index].fetch_sub(1);
      eigen_assert(v >= 1);

      // If all tasks in the range are done, aggregate the partial results into
      // the first block of the range.
      if (v == 1) {
        const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
        const Index dst_block_idx = l0_index * l0_size;

        if (rng_size == l0_size) {
          addAllToBuffer<Alignment>(m * n,
                                    block_buffers[dst_block_idx + 1],
                                    block_buffers[dst_block_idx + 2],
                                    block_buffers[dst_block_idx + 3],
                                    block_buffers[dst_block_idx]);
        } else {
          // Aggregate blocks of the potentially incomplete last range.
          for (int i = 1; i < rng_size; ++i) {
            addToBuffer<Alignment>(m * n,
                                   block_buffers[dst_block_idx + i],
                                   block_buffers[dst_block_idx]);
          }
        }
      }
    }
    // Aggregate partial sums from the l0 ranges into the result buffer.
    template <int Alignment>
    void aggregateL0Blocks() const {
      Index l0_index = 1;

      for (; l0_index + 2 < l0_ranges; l0_index += 3) {
        addAllToBuffer<Alignment>(m * n,
                                  block_buffers[(l0_index + 0) * l0_size],
                                  block_buffers[(l0_index + 1) * l0_size],
                                  block_buffers[(l0_index + 2) * l0_size],
                                  block_buffers[0]);
      }

      for (; l0_index < l0_ranges; ++l0_index) {
        addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size], block_buffers[0]);
      }
    }

    void applyOutputKernel() const {
      typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
      evaluator->m_output_kernel(OutputMapper(result, m), evaluator->m_tensor_contraction_params,
                                 /*i=*/0, /*j=*/0, /*num_rows=*/m, /*num_cols=*/n);
    }

    // Compute the block/range size accounting for a potentially incomplete
    // last block/range.
    Index actualBlockSize(Index block_idx) const {
      return block_idx + 1 < num_blocks ? block_size : k + block_size - block_size * num_blocks;
    }

    Index actualRangeSize(Index num_ranges, Index range_size, Index range_idx) const {
      eigen_assert(range_idx < num_ranges);
      return range_idx + 1 < num_ranges ? range_size : num_blocks + range_size - range_size * num_ranges;
    }
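    // addToBuffer/addAllToBuffer accumulate partial-sum buffers into the target
    // buffer using packet (SIMD) operations for the bulk of the data and a
    // scalar tail loop for the remainder.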
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf, Scalar* tgt_buf) {
      const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const PacketReturnType src_val = internal::pload<PacketReturnType>(src_buf + i);
        const PacketReturnType tgt_val = internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
        const PacketReturnType sum = internal::padd(src_val, tgt_val);
        internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum);
      }
      for (; i < n; ++i) {
        tgt_buf[i] += src_buf[i];
      }
    }
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n, const Scalar* src_buf0, const Scalar* src_buf1,
                                                   const Scalar* src_buf2, Scalar* dst_buf) {
      using ::Eigen::internal::padd;
      using ::Eigen::internal::pload;
      using ::Eigen::internal::ploadt;
      using ::Eigen::internal::pstoret;

      const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;

      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
        const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
        const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);

        const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
        const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));

        pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
      }
      for (; i < n; ++i) {
        dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
      }
    }
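    // eval/evalAsync recursively split [start_block_idx, end_block_idx) in
    // half, enqueueing the upper half on the device and processing the first
    // block in the current thread. The synchronous version counts completed
    // blocks with a Barrier; the asynchronous version uses an atomic counter,
    // and the last finishing task performs the aggregation, applies the output
    // kernel and invokes the done callback.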
    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueue([this, &barrier, mid_block_idx, end_block_idx]() {
          eval<Alignment>(barrier, mid_block_idx, end_block_idx);
        });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);
      barrier.Notify();
    }
    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueue(
            [this, mid_block_idx, end_block_idx]() { evalAsync<Alignment>(mid_block_idx, end_block_idx); });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      eigen_assert(v >= 1);

      if (v == 1) {
        // Aggregate partial sums from l0 ranges and apply the output kernel.
        aggregateL0Blocks<Alignment>();
        applyOutputKernel();

        // NOTE: calling `done` before deleting this context could deallocate
        // the evaluator captured by the context and break the destructor, so
        // move the callback to the stack, delete the context, then call it.
        DoneCallback done_copy = std::move(done);
        delete this;
        done_copy();
      }
    }
    // The cost model does not capture the cost of constructing tensor
    // contraction mappers and computing loop bounds in gemm_pack_lhs/rhs, so we
    // enforce a minimum desired block size.
    static Index blockSize(Index k, int num_threads) {
      const auto round_up = [=](Index index) -> Index {
        const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
        return numext::div_ceil<Index>(index, kmultiple) * kmultiple;
      };

      const Index target_block_size = round_up(numext::div_ceil<Index>(k, num_threads));
      const Index desired_min_block_size = 12 * packet_size;

      return numext::mini<Index>(k, numext::maxi<Index>(desired_min_block_size, target_block_size));
    }

    EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
    void operator=(const EvalShardedByInnerDimContext&) = delete;
  };
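  // The helpers below implement the heuristics used by evalProductImpl:
  // choosing the sharding dimension, coarsening task grain sizes, and
  // estimating per-coefficient costs for the cost model.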
  // Decide whether we want to shard the m x n contraction by columns or by rows.
  static bool shardByCol(Index m, Index n, Index num_threads) {
    // Note: both n and m are compared against Traits::nr; the heuristic tries
    // to figure out how each dimension fits into the main sharding dimension.
    // Prefer sharding by row when there is enough data per thread for
    // vectorization over rows but not (or barely) over columns. Part of the
    // original condition is elided in this extract.
    if (m / num_threads >= Traits::nr &&
        (n / num_threads < Traits::nr ||
         (n / num_threads < 4 * Traits::nr && (n % (num_threads * Traits::nr)) != 0 &&
          ((m % (num_threads * Traits::nr)) == 0 /* || ... */))))
      return false;
    // If n per thread is fairly small and m is much larger than n, shard by row.
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
  Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn, int num_threads, bool shard_by_col) const {
    Index gm = 1;
    Index gm1 = 1;
    Index nm0 = numext::div_ceil(m, bm);
    Index nm1 = nm0;
    for (;;) {
      // Find the next candidate for the m grain size. It needs to result in a
      // different number of tasks, e.g. for 10 kernels we want to try 5 and 10,
      // but not 6, 7, 8 and 9.
      while (gm1 <= nm0 && nm1 == numext::div_ceil(nm0, gm1)) gm1++;
      if (gm1 > nm0) break;
      // Check the candidate and stop as soon as tasks become too large.
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads, shard_by_col);
      if (res < 0) break;
      nm1 = numext::div_ceil(nm0, gm1);
      if (res == 0) continue;
      // Commit the new grain size.
      gm = gm1;
    }
    return gm;
  }

  Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm, int num_threads, bool shard_by_col) const {
    Index gn = 1;
    Index gn1 = 1;
    Index nn0 = numext::div_ceil(n, bn);
    Index nn1 = nn0;
    for (;;) {
      while (gn1 <= nn0 && nn1 == numext::div_ceil(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads, shard_by_col);
      if (res < 0) break;
      nn1 = numext::div_ceil(nn0, gn1);
      if (res == 0) continue;
      gn = gn1;
    }
    return gn;
  }
  // checkGrain checks whether the grain (gm, gn) is suitable and is better than
  // the existing grain (oldgm, oldgn). Returns 1 if the candidate should be
  // committed, 0 if it is acceptable but not better, and -1 if it (and all
  // larger grains) should be rejected.
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm, Index gn, Index oldgm, Index oldgn,
                 int num_threads, bool shard_by_col) const {
    const TensorOpCost cost = contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is too small, accept it regardless of anything else;
    // otherwise synchronization overheads will dominate.
    if (taskSize < 1) return 1;
    // If it is too large, reject it and all larger tasks.
    if (taskSize > 2) return -1;
    // In the presumably good task size range the main deciding factor is
    // parallelism: e.g. with 12 kernels and 4 threads, grains of 2 and 4 yield
    // 6 and 3 tasks (parallelism 0.75), while a grain of 3 yields 4 tasks and
    // keeps all cores busy (parallelism 1).
    Index nm0 = numext::div_ceil(m, bm);
    Index nn0 = numext::div_ceil(n, bn);
    Index new_tasks = numext::div_ceil(nm0, gm) * numext::div_ceil(nn0, gn);
    double new_parallelism =
        static_cast<double>(new_tasks) / (numext::div_ceil<Index>(new_tasks, num_threads) * num_threads);
    Index old_tasks = numext::div_ceil(nm0, oldgm) * numext::div_ceil(nn0, oldgn);
    double old_parallelism =
        static_cast<double>(old_tasks) / (numext::div_ceil<Index>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }
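  // Estimate the cost of one output coefficient of an (m, n) x k contraction
  // with the given block sizes; used both for choosing the number of threads
  // and (with prepacked == true) for sizing tasks in checkGrain.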
  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk, bool shard_by_col,
                               bool prepacked) const {
    const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size, PacketType<RhsScalar, Device>::size);
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    const double kd = static_cast<double>(bk);
    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
    // Computations.
    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    if (prepacked) {
      // Packing and kernels are executed in different tasks. When we calculate
      // task grain size we look only at the kernel cost, assuming that the
      // kernel is more expensive than packing.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // Lhs packing memory cost does not contribute considerably to the overall
    // execution time because lhs is prefetched early and accessed sequentially.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  // Decide whether we want to shard the m x n x k contraction over the inner
  // (contraction) dimension k.
  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads, int num_threads_by_k) {
    std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
    bool shard_by_k = false;
    if (n == 1 ||                          // mat*vec, or
        num_threads_by_k < 2 ||            // sharding by k gives no parallelism, or
        num_threads_by_k < num_threads ||  // it gives less parallelism than sharding the output, or
        bufsize > l3CacheSize() / num_threads_by_k ||  // per-thread buffers do not fit the L3 budget, or
        k / num_threads_by_k < 2 * Traits::nr) {       // the per-thread k range is tiny.
      shard_by_k = false;
    } else if (numext::maxi(m, n) / num_threads < Traits::nr ||  // both m and n are tiny, or
               // the result is small but k is sufficiently large
               (k / num_threads_by_k > 8 * Traits::nr &&
                (numext::mini(m, n) < 2 * Traits::nr || num_threads_by_k > num_threads))) {
      shard_by_k = true;
    }
    return shard_by_k;
  }
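  // Cost model for the shard-by-inner-dimension path: the parallel part scales
  // with k, while the reduction of the per-thread buffers (m * n additions per
  // thread) and fixed dispatch overheads grow with the number of threads, so
  // numThreadsInnerDim picks the thread count that minimizes the sum.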
  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
    // Compute cost.
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
    // Since the inner gemm kernel is always sharded by column, the lhs load
    // cost is negligible.
    lhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  int numThreadsInnerDim(Index m, Index n, Index k) const {
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
    double total_parallel_cost = TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
    // Cost of reducing the m * n per-thread buffers into the result buffer.
    double reduction_cost =
        TensorCostModel<ThreadPoolDevice>::totalCost(m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
    int num_threads = 1;
    double min_cost = total_parallel_cost;
    double kPerThreadOverHead = 3000;
    double kFixedOverHead = 100000;
    for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
      double sequential_cost = kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
      double parallel_cost = total_parallel_cost / nt + sequential_cost;
      if (parallel_cost < min_cost) {
        num_threads = nt;
        min_cost = parallel_cost;
      } else {
        break;
      }
    }
    return num_threads;
  }
  double computeBandwidth(bool shard_by_col, Index bm, Index bn, Index bk) const {
    // Peak VFMA bandwidth is 0.5. If there is not enough data for
    // vectorization, the bandwidth drops; the 4.0 and 2.0 values were
    // determined experimentally.
    double computeBandwidth = bk == 1 ? 4.0
                              : (shard_by_col ? bn : bm) < Traits::nr || (shard_by_col ? bm : bn) < Traits::mr ? 2.0
                                                                                                               : 0.5;
#ifndef EIGEN_VECTORIZE_FMA
    // Without FMA the multiply and add form a dependent pair of instructions,
    // so the effective bandwidth is 1.0 rather than 0.5.
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif  // EIGEN_VECTORIZE_FMA
    return computeBandwidth;
  }
};

}  // end namespace Eigen

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H