#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H

#include "./InternalHeaderCheck.h"

namespace Eigen {

template <typename Environment>
class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::EnvThread Thread;
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;

  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
    ThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
    uint64_t rand;          // Random generator state.
    int thread_id;          // Worker thread index in pool.
  };

  struct ThreadData {
    constexpr ThreadData() : thread(), steal_partition(0), queue() {}
    std::unique_ptr<Thread> thread;
    std::atomic<unsigned> steal_partition;
    Queue queue;
  };

  ThreadPoolTempl(int num_threads, Environment env = Environment())
      : ThreadPoolTempl(num_threads, true, env) {}

  ThreadPoolTempl(int num_threads, bool allow_spinning, Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        spin_count_(allow_spinning && num_threads > 0 ? kSpinCount / kMaxSpinningThreads / num_threads : 0),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        global_steal_partition_(EncodePartition(0, num_threads_)),
        spinning_state_(0),
        blocked_(0),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads_);
    // Precompute the coprimes of [1, num_threads]. They drive the
    // pseudo-random permutation of victim indices used by Steal() and
    // NonEmptyQueueIndex(): starting from a random index and stepping by a
    // coprime of the range size visits every thread exactly once.
    eigen_plain_assert(num_threads_ < kMaxThreads);
    for (int i = 1; i <= num_threads_; ++i) {
      all_coprimes_.emplace_back(i);
      ComputeCoprimes(i, &all_coprimes_.back());
    }
#ifndef EIGEN_THREAD_LOCAL
    init_barrier_.reset(new Barrier(num_threads_));
#endif
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
#ifndef EIGEN_THREAD_LOCAL
    // Wait for the workers to initialize per_thread_map_; otherwise we might
    // race with them in Schedule or CurrentThreadId.
    init_barrier_->Wait();
#endif
  }

  ~ThreadPoolTempl() {
    done_ = true;
    if (!cancelled_) {
      // Wake up the parked threads so that they observe done_ and exit.
      ec_.Notify(true);
    } else {
      // The pool was cancelled; the queues may still hold tasks, so flush them
      // to keep their destructors from asserting.
      for (size_t i = 0; i < thread_data_.size(); i++) {
        thread_data_[i].queue.Flush();
      }
    }
    // Join the worker threads by destroying them.
    for (size_t i = 0; i < thread_data_.size(); ++i) thread_data_[i].thread.reset();
  }

  void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
    eigen_plain_assert(partitions.size() == static_cast<std::size_t>(num_threads_));

    // Pass this information to each thread queue.
    for (int i = 0; i < num_threads_; i++) {
      const auto& pair = partitions[i];
      unsigned start = pair.first, end = pair.second;
      AssertBounds(start, end);
      unsigned val = EncodePartition(start, end);
      SetStealPartition(i, val);
    }
  }

  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE { ScheduleWithHint(std::move(fn), 0, num_threads_); }

  void ScheduleWithHint(std::function<void()> fn, int start, int limit) override {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's own queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue within [start, limit).
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    if (!t.f) {
      // The task was pushed successfully; wake a parked thread only if no
      // spinning thread will pick it up.
      if (IsNotifyParkedThreadRequired()) {
        ec_.Notify(false);
      }
    } else {
      // Push failed (queue full), execute the task directly.
      env_.ExecuteTask(t);
    }
  }

  void MaybeGetTask(Task* t) {
    PerThread* pt = GetPerThread();
    const int thread_id = pt->thread_id;
    // If we are not a worker thread of this pool, we can't get any work.
    if (thread_id < 0) return;
    Queue& q = thread_data_[thread_id].queue;
    *t = q.PopFront();
    if (EIGEN_PREDICT_TRUE(t->f)) return;
    if (num_threads_ == 1) {
      // With a single thread there is nobody to steal from, so just keep
      // re-checking the local queue for a bounded number of iterations.
      for (int i = 0; i < spin_count_ && !t->f; ++i) *t = q.PopFront();
    } else {
      if (EIGEN_PREDICT_FALSE(!t->f)) *t = LocalSteal();
      if (EIGEN_PREDICT_FALSE(!t->f)) *t = GlobalSteal();
      if (EIGEN_PREDICT_FALSE(!t->f)) {
        if (allow_spinning_ && StartSpinning()) {
          for (int i = 0; i < spin_count_ && !t->f; ++i) *t = GlobalSteal();
          // If a task was submitted without notifying a parked thread, make
          // one more steal attempt so that the task is not lost.
          bool has_no_notify_task = StopSpinning();
          if (has_no_notify_task && !t->f) *t = GlobalSteal();
        }
      }
    }
  }

  void Cancel() EIGEN_OVERRIDE {
    cancelled_ = true;
    done_ = true;

    // Let each thread know it's been cancelled.
#ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
    for (size_t i = 0; i < thread_data_.size(); i++) {
      thread_data_[i].thread->OnCancel();
    }
#endif

    // Wake up the threads without work to let them exit on their own.
    ec_.Notify(true);
  }

  int NumThreads() const EIGEN_FINAL { return num_threads_; }

  int CurrentThreadId() const EIGEN_FINAL {
    const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }

 private:
  // A steal partition [start, limit) is packed into a single unsigned, with
  // the start index in the upper bits and the limit in the lower bits, which
  // caps the pool at kMaxThreads threads.
  static constexpr int kMaxPartitionBits = 16;
  static constexpr int kMaxThreads = 1 << kMaxPartitionBits;

  inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }

  inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }
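
  // Worked example (illustrative, not from the original header): with
  // kMaxPartitionBits == 16, EncodePartition(2, 8) == (2 << 16) | 8 ==
  // 0x00020008, and DecodePartition(0x00020008, &s, &l) recovers s == 2, l == 8.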

  void AssertBounds(int start, int end) {
    eigen_plain_assert(start >= 0);
    eigen_plain_assert(start < end);  // non-zero sized partition
    eigen_plain_assert(end <= num_threads_);
  }

  inline void SetStealPartition(size_t i, unsigned val) {
    thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
  }

  inline unsigned GetStealPartition(int i) { return thread_data_[i].steal_partition.load(std::memory_order_relaxed); }

  void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
    for (int i = 1; i <= N; i++) {
      // Keep i if gcd(i, N) == 1 (Euclid's algorithm).
      unsigned a = i;
      unsigned b = N;
      while (b != 0) {
        unsigned tmp = a;
        a = b;
        b = tmp % b;
      }
      if (a == 1) coprimes->push_back(i);
    }
  }
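
  // Illustration (an assumption for exposition, not from the original source):
  // for N == 10 the coprimes are {1, 3, 7, 9}; starting at any index v and
  // repeatedly stepping v = (v + 7) % 10 visits all ten indices exactly once,
  // which is how the steal loops below obtain a pseudo-random victim order.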

  // Maximum number of threads that can spin in the steal loop.
  static constexpr int kMaxSpinningThreads = 1;

  // Number of steal-loop spin iterations before parking (divided by the number
  // of threads to obtain each thread's spin count).
  static constexpr int kSpinCount = 5000;

  // If there are enough active threads with empty pending-task queues, a
  // thread that runs out of work can be parked without spinning, because those
  // threads will enter the steal loop once they finish their current tasks.
  static constexpr int kMinActiveThreadsToStartSpinning = 4;

  struct SpinningState {
    // Low 32 bits: the number of threads spinning in the steal loop.
    // High 32 bits: the number of tasks submitted to the pool without calling
    // ec_.Notify(); it can never exceed the number of spinning threads.
    static constexpr uint64_t kNumSpinningMask = 0x00000000FFFFFFFF;
    static constexpr uint64_t kNumNoNotifyMask = 0xFFFFFFFF00000000;
    static constexpr uint64_t kNumNoNotifyShift = 32;

    uint64_t num_spinning;         // number of spinning threads
    uint64_t num_no_notification;  // number of tasks submitted without notification

    // Decode a spinning_state_ value.
    static SpinningState Decode(uint64_t state) {
      uint64_t num_spinning = (state & kNumSpinningMask);
      uint64_t num_no_notification = (state & kNumNoNotifyMask) >> kNumNoNotifyShift;
      eigen_plain_assert(num_no_notification <= num_spinning);
      return {num_spinning, num_no_notification};
    }

    // Encode as a spinning_state_ value.
    uint64_t Encode() const {
      eigen_plain_assert(num_no_notification <= num_spinning);
      return (num_no_notification << kNumNoNotifyShift) | num_spinning;
    }
  };
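
  // Example for the SpinningState encoding above (illustrative assumption):
  // num_spinning == 2 and num_no_notification == 1 encode to
  // 0x0000000100000002, and Decode() of that value returns {2, 1} again.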

  Environment env_;
  const int num_threads_;
  const bool allow_spinning_;
  const int spin_count_;
  MaxSizeVector<ThreadData> thread_data_;
  MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  unsigned global_steal_partition_;
  std::atomic<uint64_t> spinning_state_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> done_;
  std::atomic<bool> cancelled_;
  EventCount ec_;
#ifndef EIGEN_THREAD_LOCAL
  std::unique_ptr<Barrier> init_barrier_;
  EIGEN_MUTEX per_thread_map_mutex_;  // Protects per_thread_map_.
  std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif

  unsigned NumBlockedThreads() const { return blocked_.load(); }
  unsigned NumActiveThreads() const { return num_threads_ - blocked_.load(); }

  // Main worker thread loop.
  void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
    std::unique_ptr<PerThread> new_pt(new PerThread());
    per_thread_map_mutex_.lock();
    bool insertOK = per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
    eigen_plain_assert(insertOK);
    EIGEN_UNUSED_VARIABLE(insertOK);
    per_thread_map_mutex_.unlock();
    init_barrier_->Notify();
    init_barrier_->Wait();
#endif
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = GlobalThreadIdHash();
    pt->thread_id = thread_id;

    while (!cancelled_.load(std::memory_order_relaxed)) {
      Task t;
      MaybeGetTask(&t);
      if (EIGEN_PREDICT_FALSE(!t.f)) {
        // No work found; park on the EventCount until new work arrives or the
        // pool shuts down.
        EventCount::Waiter* waiter = &waiters_[pt->thread_id];
        if (!WaitForWork(waiter, &t)) return;
      }
      if (EIGEN_PREDICT_TRUE(t.f)) env_.ExecuteTask(t);
    }
  }

  // Steal tries to steal work from other worker threads in the range
  // [start, limit) in a best-effort manner.
  Task Steal(unsigned start, unsigned limit) {
    PerThread* pt = GetPerThread();
    const size_t size = limit - start;
    unsigned r = Rand(&pt->rand);
    // Reduce r into [0, size) with a multiply-and-shift instead of a modulo.
    eigen_plain_assert(all_coprimes_[size - 1].size() < (1 << 30));
    unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
    unsigned index = ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
    unsigned inc = all_coprimes_[size - 1][index];

    // Visit all candidate queues in a pseudo-random order determined by inc.
    for (unsigned i = 0; i < size; i++) {
      eigen_plain_assert(start + victim < limit);
      Task t = thread_data_[start + victim].queue.PopBack();
      if (t.f) {
        return t;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return Task();
  }
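
  // Worked example of the reduction above (illustrative, not from the original
  // source): with r == 0x80000000 (i.e. 2^31) and size == 6, victim ==
  // (2^31 * 6) >> 32 == 3, so the walk starts at the fourth queue in the range.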

  // Steals work from the threads that belong to this thread's partition.
  Task LocalSteal() {
    PerThread* pt = GetPerThread();
    unsigned partition = GetStealPartition(pt->thread_id);
    // If the thread's steal partition is the same as the global partition,
    // there is no need to go through the steal loop twice.
    if (global_steal_partition_ == partition) return Task();
    unsigned start, limit;
    DecodePartition(partition, &start, &limit);
    AssertBounds(start, limit);

    return Steal(start, limit);
  }

  // Steals work from any other thread in the pool.
  Task GlobalSteal() { return Steal(0, num_threads_); }

  // WaitForWork blocks until new work is available (returns true), or it is
  // time to exit (returns false). It may also return a task to execute in t.
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_plain_assert(!t->f);
    // Steal already did a best-effort emptiness check, so prepare to block and
    // then do a reliable emptiness check.
    ec_.Prewait();
    int victim = NonEmptyQueueIndex();
    if (victim != -1) {
      ec_.CancelWait();
      if (cancelled_) return false;
      *t = thread_data_[victim].queue.PopBack();
      return true;
    }
    // The number of blocked threads is the termination condition: when we are
    // shutting down and all workers are blocked without work, we are done.
    blocked_++;
    if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
      ec_.CancelWait();
      // Re-check the queues: work may have been submitted right before done_
      // was set, and exiting now would leave it unexecuted.
      if (NonEmptyQueueIndex() != -1) {
        blocked_--;
        return true;
      }
      // Reached a stable termination state: wake everyone up so they can exit.
      ec_.Notify(true);
      return false;
    }
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }

  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    // Scan all queues (not just this thread's partition) so that threads do
    // not block in WaitForWork() forever when their partition goes to sleep.
    const size_t size = thread_data_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!thread_data_[victim].queue.Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return -1;
  }

  // StartSpinning() checks whether the number of threads in the spin loop is
  // below the allowed maximum; if so it increments the spinning-thread count
  // and returns true, and the caller must enter the spin loop.
  bool StartSpinning() {
    if (NumActiveThreads() > kMinActiveThreadsToStartSpinning) return false;

    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);
      if ((state.num_spinning - state.num_no_notification) >= kMaxSpinningThreads) {
        return false;
      }
      // Increment the number of spinning threads.
      ++state.num_spinning;
      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return true;
      }
    }
  }

  // StopSpinning() decrements the number of spinning threads and consumes one
  // "submitted without notification" task if there is one. It returns true in
  // that case, and the caller must make one more steal attempt.
  bool StopSpinning() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);
      // Decrement the number of spinning threads.
      --state.num_spinning;
      // Maybe decrement the number of tasks submitted without notification.
      bool has_no_notify_task = state.num_no_notification > 0;
      if (has_no_notify_task) --state.num_no_notification;
      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return has_no_notify_task;
      }
    }
  }

  // IsNotifyParkedThreadRequired() returns true if a parked thread must be
  // notified about a newly added task; if a spinning thread can pick the task
  // up instead, it records that fact and returns false.
  bool IsNotifyParkedThreadRequired() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);
      // If every spinning thread already has a task assigned this way, we
      // must wake up one of the parked threads.
      if (state.num_no_notification == state.num_spinning) return true;
      // Increment the number of tasks submitted without notification.
      ++state.num_no_notification;
      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return false;
      }
    }
  }

  static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
    return std::hash<std::thread::id>()(std::this_thread::get_id());
  }

  EIGEN_STRONG_INLINE PerThread* GetPerThread() {
#ifndef EIGEN_THREAD_LOCAL
    static PerThread dummy;
    auto it = per_thread_map_.find(GlobalThreadIdHash());
    if (it == per_thread_map_.end()) {
      return &dummy;
    } else {
      return it->second.get();
    }
#else
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
#endif
  }

  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state.
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output (using the PCG-XSH-RS scheme).
    return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
  }
};

typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;
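
// A minimal usage sketch (illustrative only, not part of this header). It
// relies only on the public interface above plus Eigen's Barrier helper, and
// assumes the ThreadPool module header has been included:
//
//   Eigen::ThreadPool pool(4);        // four worker threads
//   std::atomic<int> sum(0);
//   Eigen::Barrier barrier(100);
//   for (int i = 0; i < 100; ++i) {
//     pool.Schedule([&]() {
//       sum.fetch_add(1, std::memory_order_relaxed);
//       barrier.Notify();
//     });
//   }
//   barrier.Wait();                   // all 100 tasks have finished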

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H