#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H

#include "./InternalHeaderCheck.h"

namespace Eigen {
template <typename Environment>
class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::EnvThread Thread;
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;
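  // Each worker thread owns one fixed-capacity RunQueue of 1024 tasks. The
  // owning worker pushes and pops at the front of its queue, while Schedule()
  // calls from outside the pool and work stealing go through the back.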
  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
    ThreadPoolTempl* pool;  // Parent pool, or null for threads that do not belong to a pool.
    uint64_t rand;          // Random generator state.
    int thread_id;          // Worker thread index in the pool.
  };

  struct ThreadData {
    constexpr ThreadData() : thread(), steal_partition(0), queue() {}
    std::unique_ptr<Thread> thread;
    std::atomic<unsigned> steal_partition;
    Queue queue;
  };
  ThreadPoolTempl(int num_threads, Environment env = Environment()) : ThreadPoolTempl(num_threads, true, env) {}

  ThreadPoolTempl(int num_threads, bool allow_spinning, Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        spin_count_(allow_spinning && num_threads > 0 ? kSpinCount / kMaxSpinningThreads / num_threads : 0),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        global_steal_partition_(EncodePartition(0, num_threads_)),
        spinning_state_(0),
        blocked_(0),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads_);
    // Calculate the coprimes of all numbers in [1, num_threads]. They are used
    // for the pseudo-random walks over the worker queues in Steal() and
    // NonEmptyQueueIndex().
    eigen_plain_assert(num_threads_ < kMaxThreads);
    for (int i = 1; i <= num_threads_; ++i) {
      all_coprimes_.emplace_back(i);
      ComputeCoprimes(i, &all_coprimes_.back());
    }
#ifndef EIGEN_THREAD_LOCAL
    init_barrier_.reset(new Barrier(num_threads_));
#endif
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
#ifndef EIGEN_THREAD_LOCAL
    // Wait for the workers to initialize per_thread_map_; otherwise we might
    // race with them in Schedule or CurrentThreadId.
    init_barrier_->Wait();
#endif
  }
  ~ThreadPoolTempl() {
    done_ = true;
    // Wake up waiting threads so they can exit on their own. If we were
    // cancelled, drain the queues instead so that pending tasks do not trip
    // the queue destructors.
    if (!cancelled_) {
      ec_.Notify(true);
    } else {
      for (size_t i = 0; i < thread_data_.size(); i++) {
        thread_data_[i].queue.Flush();
      }
    }
    // Join threads explicitly (by destroying them) to control destruction order.
    for (size_t i = 0; i < thread_data_.size(); ++i) thread_data_[i].thread.reset();
  }
  void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
    eigen_plain_assert(partitions.size() == static_cast<std::size_t>(num_threads_));

    // Pass this information to each thread's queue.
    for (int i = 0; i < num_threads_; i++) {
      const auto& pair = partitions[i];
      unsigned start = pair.first, end = pair.second;
      AssertBounds(start, end);
      unsigned val = EncodePartition(start, end);
      SetStealPartition(i, val);
    }
  }
  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE { ScheduleWithHint(std::move(fn), 0, num_threads_); }

  void ScheduleWithHint(std::function<void()> fn, int start, int limit) override {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's own queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue within the [start, limit) hint.
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    if (!t.f) {
      // The task was pushed; wake a parked thread only if no spinning thread
      // is guaranteed to pick it up.
      if (IsNotifyParkedThreadRequired()) {
        ec_.Notify(false);
      }
    } else {
      // Push failed (queue full), execute the task directly.
      env_.ExecuteTask(t);
    }
  }
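  // MaybeGetTask tries to obtain a task for the calling worker thread: first
  // from its own queue, then by stealing within its steal partition, then by
  // stealing globally, and finally (if spinning is allowed) by spinning on the
  // global steal loop for a bounded number of iterations.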
  void MaybeGetTask(Task* t) {
    PerThread* pt = GetPerThread();
    Queue& q = thread_data_[pt->thread_id].queue;
    // Try to pop a task from our own queue first.
    *t = q.PopFront();
    if (t->f) return;
    if (num_threads_ == 1) {
      // With a single thread there is nobody to steal from, so just spin on
      // our own queue for a while before giving up.
      for (int i = 0; i < spin_count_ && !t->f; ++i) *t = q.PopFront();
      return;
    }
    // Otherwise steal: first within our partition, then globally.
    if (EIGEN_PREDICT_FALSE(!t->f)) *t = LocalSteal();
    if (EIGEN_PREDICT_FALSE(!t->f)) *t = GlobalSteal();
    if (EIGEN_PREDICT_FALSE(!t->f)) {
      // Maybe keep this thread spinning for a while; this reduces latency.
      if (allow_spinning_ && StartSpinning()) {
        for (int i = 0; i < spin_count_ && !t->f; ++i) *t = GlobalSteal();
        // Record in `spinning_state_` that we are no longer spinning.
        bool has_no_notify_task = StopSpinning();
        // If a task was submitted without a call to `ec_.Notify()` and we did
        // not steal anything above, check our own queue one more time so that
        // task is not lost.
        if (has_no_notify_task && !t->f) *t = q.PopFront();
      }
    }
  }
  void Cancel() EIGEN_OVERRIDE {
    cancelled_ = true;
    done_ = true;

    // Let each thread know it has been cancelled.
#ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
    for (size_t i = 0; i < thread_data_.size(); i++) {
      thread_data_[i].thread->OnCancel();
    }
#endif

    // Wake up the threads without work to let them exit on their own.
    ec_.Notify(true);
  }
  int NumThreads() const EIGEN_FINAL { return num_threads_; }
  int CurrentThreadId() const EIGEN_FINAL {
    const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }

 private:
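  // A steal partition [start, limit) is packed into a single unsigned value:
  // the upper kMaxPartitionBits bits hold the start index and the lower bits
  // hold the limit, so the pool supports at most 1 << kMaxPartitionBits
  // (65536) threads.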
  static constexpr int kMaxPartitionBits = 16;
  static constexpr int kMaxThreads = 1 << kMaxPartitionBits;
  inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }

  inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }
  void AssertBounds(int start, int end) {
    eigen_plain_assert(start >= 0);
    eigen_plain_assert(start < end);  // non-zero sized partition
    eigen_plain_assert(end <= num_threads_);
  }
  inline void SetStealPartition(size_t i, unsigned val) {
    thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
  }

  inline unsigned GetStealPartition(int i) { return thread_data_[i].steal_partition.load(std::memory_order_relaxed); }
  void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
    for (int i = 1; i <= N; i++) {
      // i is coprime with N iff gcd(i, N) == 1 (Euclid's algorithm).
      unsigned a = i;
      unsigned b = N;
      while (b != 0) {
        unsigned tmp = a;
        a = b;
        b = tmp % b;
      }
      if (a == 1) {
        coprimes->push_back(i);
      }
    }
  }
  // Maximum number of threads that can spin in the steal loop at once.
  static constexpr int kMaxSpinningThreads = 1;

  // Total steal-loop spin budget; the per-thread spin count is
  // kSpinCount / kMaxSpinningThreads / num_threads (see the constructor).
  static constexpr int kSpinCount = 5000;

  // Spinning is only worthwhile when few threads are active. If more than this
  // many threads are active (not parked), StartSpinning() refuses to spin,
  // because those threads will enter the steal loop soon anyway.
  static constexpr int kMinActiveThreadsToStartSpinning = 4;
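  // Spinning state layout (packed into a single 64-bit atomic):
  // - the low 32 bits count the threads currently spinning in the steal loop;
  // - the high 32 bits count the tasks that were submitted to the pool without
  //   a call to ec_.Notify(). This number can never exceed the number of
  //   spinning threads, because each such task is charged to one spinner.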
  struct SpinningState {
    static constexpr uint64_t kNumSpinningMask = 0x00000000FFFFFFFF;
    static constexpr uint64_t kNumNoNotifyMask = 0xFFFFFFFF00000000;
    static constexpr uint64_t kNumNoNotifyShift = 32;

    uint64_t num_spinning;         // number of spinning threads
    uint64_t num_no_notification;  // number of tasks submitted without notifying parked threads

    // Decodes a `spinning_state_` value.
    static SpinningState Decode(uint64_t state) {
      uint64_t num_spinning = (state & kNumSpinningMask);
      uint64_t num_no_notification = (state & kNumNoNotifyMask) >> kNumNoNotifyShift;

      eigen_plain_assert(num_no_notification <= num_spinning);
      return {num_spinning, num_no_notification};
    }

    // Encodes the state back into a `spinning_state_` value.
    uint64_t Encode() const {
      eigen_plain_assert(num_no_notification <= num_spinning);
      return (num_no_notification << kNumNoNotifyShift) | num_spinning;
    }
  };
  Environment env_;
  const int num_threads_;
  const bool allow_spinning_;
  const int spin_count_;
  MaxSizeVector<ThreadData> thread_data_;
  MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  unsigned global_steal_partition_;
  std::atomic<uint64_t> spinning_state_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> done_;
  std::atomic<bool> cancelled_;
  EventCount ec_;
#ifndef EIGEN_THREAD_LOCAL
  std::unique_ptr<Barrier> init_barrier_;
  EIGEN_MUTEX per_thread_map_mutex_;  // Protects per_thread_map_.
  std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif
  unsigned NumBlockedThreads() const { return blocked_.load(); }
  unsigned NumActiveThreads() const { return num_threads_ - blocked_.load(); }
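  // Main loop of a worker thread: repeatedly obtain a task via MaybeGetTask()
  // and execute it; when no task is found, park on the EventCount through
  // WaitForWork(). The loop exits when the pool is cancelled or when
  // WaitForWork() reports that the pool is shutting down.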
  void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
    std::unique_ptr<PerThread> new_pt(new PerThread());
    per_thread_map_mutex_.lock();
    bool insertOK = per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
    eigen_plain_assert(insertOK);
    EIGEN_UNUSED_VARIABLE(insertOK);
    per_thread_map_mutex_.unlock();
    init_barrier_->Notify();
    init_barrier_->Wait();
#endif
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = GlobalThreadIdHash();
    pt->thread_id = thread_id;
    Task t;
    while (!cancelled_.load(std::memory_order_relaxed)) {
      MaybeGetTask(&t);
      // If we still don't have a task, wait for one; return if the pool is
      // shutting down.
      if (EIGEN_PREDICT_FALSE(!t.f)) {
        EventCount::Waiter* waiter = &waiters_[pt->thread_id];
        if (!WaitForWork(waiter, &t)) return;
      }
      if (EIGEN_PREDICT_TRUE(t.f)) env_.ExecuteTask(t);
    }
  }
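  // Steal() walks the victim queues in the range [start, limit) in a
  // pseudo-random order: the starting victim is picked with a multiply-shift
  // reduction of a random number, and subsequent victims are visited by
  // stepping with an increment that is coprime with the range size, which
  // covers every queue exactly once. For example, with size = 4 and inc = 3,
  // starting at victim 1 the walk visits 1, 0, 3, 2.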
  Task Steal(unsigned start, unsigned limit) {
    PerThread* pt = GetPerThread();
    const size_t size = limit - start;
    unsigned r = Rand(&pt->rand);
    eigen_plain_assert(all_coprimes_[size - 1].size() < (1 << 30));
    unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
    unsigned index = ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
    unsigned inc = all_coprimes_[size - 1][index];

    for (unsigned i = 0; i < size; i++) {
      eigen_plain_assert(start + victim < limit);
      Task t = thread_data_[start + victim].queue.PopBack();
      if (t.f) {
        return t;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return Task();
  }
  // Steals work from threads within the current thread's steal partition.
  Task LocalSteal() {
    PerThread* pt = GetPerThread();
    unsigned partition = GetStealPartition(pt->thread_id);
    // If the thread's steal partition is the same as the global partition,
    // there is no need to go through the steal loop twice.
    if (global_steal_partition_ == partition) return Task();
    unsigned start, limit;
    DecodePartition(partition, &start, &limit);
    AssertBounds(start, limit);

    return Steal(start, limit);
  }
  // Steals work from any other thread in the pool.
  Task GlobalSteal() { return Steal(0, num_threads_); }
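  // WaitForWork() blocks until new work is available (returns true, possibly
  // with a task already stored in *t), or until it is time for the calling
  // thread to exit (returns false). The blocked_ counter doubles as the
  // termination condition: when done_ is set and every worker is blocked
  // without work, the pool has reached a stable shutdown state.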
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_plain_assert(!t->f);
    // Prepare to block, then do a reliable emptiness check.
    ec_.Prewait();
    int victim = NonEmptyQueueIndex();
    if (victim != -1) {
      ec_.CancelWait();
      if (cancelled_) return false;
      *t = thread_data_[victim].queue.PopBack();
      return true;
    }
    blocked_++;
    if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
      ec_.CancelWait();
      // Almost done, but the queues must be re-checked: work could have been
      // submitted (and done_ set) right after the emptiness check above.
      if (NonEmptyQueueIndex() != -1) {
        blocked_--;
        return true;
      }
      // Reached stable termination state.
      ec_.Notify(true);
      return false;
    }
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }
  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    // Scan all queues (not just the caller's steal partition) so that threads
    // do not block in WaitForWork() forever when every thread in their
    // partition goes to sleep.
    const size_t size = thread_data_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!thread_data_[victim].queue.Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return -1;
  }
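  // Spinning protocol: StartSpinning()/StopSpinning() bracket the bounded spin
  // loop in MaybeGetTask(), while IsNotifyParkedThreadRequired() lets
  // Schedule() skip the comparatively expensive ec_.Notify() call when a
  // spinning thread is guaranteed to pick up the new task. StopSpinning()
  // reports whether such a "no notification" task may still be pending, in
  // which case the leaving spinner makes one final attempt to take it.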
  // Returns true if the calling thread may enter the spin loop (and records it
  // as spinning); returns false otherwise.
  bool StartSpinning() {
    if (NumActiveThreads() > kMinActiveThreadsToStartSpinning) return false;

    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      if ((state.num_spinning - state.num_no_notification) >= kMaxSpinningThreads) return false;

      // Increment the number of spinning threads.
      ++state.num_spinning;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) return true;
    }
  }
  // Removes the calling thread from the spinning set. Returns true if a task
  // was submitted without notification while we were spinning, in which case
  // the caller must make one more attempt to get a task.
  bool StopSpinning() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      // Decrement the number of spinning threads.
      --state.num_spinning;

      // Maybe decrement the number of tasks submitted without notification.
      bool has_no_notify_task = state.num_no_notification > 0;
      if (has_no_notify_task) --state.num_no_notification;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return has_no_notify_task;
      }
    }
  }
  // Returns true if a parked thread must be woken for a newly added task;
  // returns false (and charges the task to a spinning thread) otherwise.
  bool IsNotifyParkedThreadRequired() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      // If the number of tasks submitted without notification is equal to the
      // number of spinning threads, we must wake up one of the parked threads.
      if (state.num_no_notification == state.num_spinning) return true;

      // Increment the number of tasks submitted without notification.
      ++state.num_no_notification;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) return false;
    }
  }
  static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
    return std::hash<std::thread::id>()(std::this_thread::get_id());
  }
  EIGEN_STRONG_INLINE PerThread* GetPerThread() {
#ifndef EIGEN_THREAD_LOCAL
    static PerThread dummy;
    auto it = per_thread_map_.find(GlobalThreadIdHash());
    if (it == per_thread_map_.end()) {
      return &dummy;
    } else {
      return it->second.get();
    }
#else
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
#endif
  }
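  // Pseudo-random number generator used for victim selection. The 64-bit state
  // advances with a linear congruential step, and the 32-bit output is derived
  // with a PCG-XSH-RS style xorshift and data-dependent right shift.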
  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state.
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output.
    return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
  }
};
typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;
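// Usage sketch (illustrative only, not part of this header), assuming the
// ThreadPool module header is on the include path and using the Eigen::Barrier
// helper shipped with this module:
//
//   #include <unsupported/Eigen/CXX11/ThreadPool>
//
//   int main() {
//     Eigen::ThreadPool pool(/*num_threads=*/4);
//     Eigen::Barrier barrier(/*count=*/8);
//     for (int i = 0; i < 8; ++i) {
//       pool.Schedule([&barrier]() {
//         // ... do some work ...
//         barrier.Notify();  // Signal completion of one task.
//       });
//     }
//     barrier.Wait();  // Block until all eight tasks have run.
//     return 0;
//   }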
}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H