NonBlockingThreadPool.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

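// ThreadPoolTempl is a non-blocking, work-stealing thread pool. Each worker
// thread owns a fixed-size RunQueue of tasks; an idle worker first drains its
// own queue, then tries to steal from other workers (optionally restricted to
// a steal partition), and finally spins briefly or parks on the EventCount
// `ec_` until new work is scheduled.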
template <typename Environment>
class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::EnvThread Thread;
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;

  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
    ThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
    uint64_t rand;          // Random generator state.
    int thread_id;          // Worker thread index in pool.
  };

  struct ThreadData {
    constexpr ThreadData() : thread(), steal_partition(0), queue() {}
    std::unique_ptr<Thread> thread;
    std::atomic<unsigned> steal_partition;
    Queue queue;
  };

  ThreadPoolTempl(int num_threads, Environment env = Environment()) : ThreadPoolTempl(num_threads, true, env) {}

  ThreadPoolTempl(int num_threads, bool allow_spinning, Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        spin_count_(
            // TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is proportional to num_threads_ and
            // we assume that new work is scheduled at a constant rate, so we divide `kSpinCount` by the number of
            // threads and the number of spinning threads. The constant was picked based on a fair dice roll, tune it.
            allow_spinning && num_threads > 0 ? kSpinCount / kMaxSpinningThreads / num_threads : 0),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        global_steal_partition_(EncodePartition(0, num_threads_)),
        spinning_state_(0),
        blocked_(0),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads_);
    // Calculate coprimes of all numbers [1, num_threads].
    // Coprimes are used for random walks over all threads in Steal
    // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
    // a random starting thread index t and calculate num_threads - 1 subsequent
    // indices as (t + coprime) % num_threads, we will cover all threads without
    // repetitions (effectively getting a pseudo-random permutation of thread
    // indices).
    eigen_plain_assert(num_threads_ < kMaxThreads);
    for (int i = 1; i <= num_threads_; ++i) {
      all_coprimes_.emplace_back(i);
      ComputeCoprimes(i, &all_coprimes_.back());
    }
#ifndef EIGEN_THREAD_LOCAL
    init_barrier_.reset(new Barrier(num_threads_));
#endif
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
#ifndef EIGEN_THREAD_LOCAL
    // Wait for workers to initialize per_thread_map_. Otherwise we might race
    // with them in Schedule or CurrentThreadId.
    init_barrier_->Wait();
#endif
  }

  ~ThreadPoolTempl() {
    done_ = true;

    // Now if all threads block without work, they will start exiting.
    // But note that threads can continue to work arbitrarily long,
    // block, submit new work, unblock and otherwise live a full life.
    if (!cancelled_) {
      ec_.Notify(true);
    } else {
      // Since we were cancelled, there might be entries in the queues.
      // Empty them to prevent their destructor from asserting.
      for (size_t i = 0; i < thread_data_.size(); i++) {
        thread_data_[i].queue.Flush();
      }
    }
    // Join threads explicitly (by destroying them) to avoid destruction order
    // issues within this class.
    for (size_t i = 0; i < thread_data_.size(); ++i) thread_data_[i].thread.reset();
  }

  void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
    eigen_plain_assert(partitions.size() == static_cast<std::size_t>(num_threads_));

    // Pass this information to each thread queue.
    for (int i = 0; i < num_threads_; i++) {
      const auto& pair = partitions[i];
      unsigned start = pair.first, end = pair.second;
      AssertBounds(start, end);
      unsigned val = EncodePartition(start, end);
      SetStealPartition(i, val);
    }
  }

  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE { ScheduleWithHint(std::move(fn), 0, num_threads_); }

  void ScheduleWithHint(std::function<void()> fn, int start, int limit) override {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue.
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    // Note: below we touch `this` after making the task available to worker
    // threads. Strictly speaking, this can lead to a racy use-after-free.
    // Consider that Schedule is called from a thread that is neither the main
    // thread nor a worker thread of this pool. Then, execution of the task
    // directly or indirectly completes the overall computation, which in turn
    // leads to destruction of `this`. We expect that such a scenario is
    // prevented by the program, that is, `this` is kept alive while any
    // threads can potentially be in Schedule.
    if (!t.f) {
      if (IsNotifyParkedThreadRequired()) {
        ec_.Notify(false);
      }
    } else {
      env_.ExecuteTask(t);  // Push failed, execute directly.
    }
  }
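
  // Note on the scheduling paths above: RunQueue::PushFront/PushBack return an
  // empty task on success and hand the task back when the queue is full, in
  // which case the caller runs it inline (natural backpressure). Illustrative
  // use of the hint (this snippet is a sketch, not part of this header):
  //
  //   Eigen::ThreadPool pool(8);
  //   // Prefer the first two workers for this task; any worker may still
  //   // steal it later.
  //   pool.ScheduleWithHint([] { /* work */ }, /*start=*/0, /*limit=*/2);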

  // Tries to get a task for the calling worker thread, stealing from other
  // queues if necessary.
  void MaybeGetTask(Task* t) {
    PerThread* pt = GetPerThread();
    const int thread_id = pt->thread_id;
    // If we are not a worker thread of this pool, we can't get any work.
    if (thread_id < 0) return;
    Queue& q = thread_data_[thread_id].queue;
    *t = q.PopFront();
    if (t->f) return;
    if (num_threads_ == 1) {
      // For num_threads_ == 1 there is no point in going through the expensive
      // steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the
      // victim queues it might reverse the order in which ops are executed
      // compared to the order in which they are scheduled, which tends to be
      // counter-productive for the types of I/O workloads single thread pools
      // tend to be used for.
      for (int i = 0; i < spin_count_ && !t->f; ++i) *t = q.PopFront();
    } else {
      if (EIGEN_PREDICT_FALSE(!t->f)) *t = LocalSteal();
      if (EIGEN_PREDICT_FALSE(!t->f)) *t = GlobalSteal();
      if (EIGEN_PREDICT_FALSE(!t->f)) {
        if (allow_spinning_ && StartSpinning()) {
          for (int i = 0; i < spin_count_ && !t->f; ++i) *t = GlobalSteal();
          // Update `spinning_state_` to record that we are no longer spinning.
          bool has_no_notify_task = StopSpinning();
          // If a task was submitted to the queue without a call to
          // `ec_.Notify()` (if `IsNotifyParkedThreadRequired()` returned
          // false), and we didn't steal anything above, we must try to
          // steal one more time, to make sure that this task will be
          // executed. We will not necessarily find it, because it might
          // have been already stolen by some other thread.
          if (has_no_notify_task && !t->f) *t = GlobalSteal();
        }
      }
    }
  }

  void Cancel() EIGEN_OVERRIDE {
    cancelled_ = true;
    done_ = true;

    // Let each thread know it's been cancelled.
#ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
    for (size_t i = 0; i < thread_data_.size(); i++) {
      thread_data_[i].thread->OnCancel();
    }
#endif

    // Wake up the threads without work to let them exit on their own.
    ec_.Notify(true);
  }

  int NumThreads() const EIGEN_FINAL { return num_threads_; }

  int CurrentThreadId() const EIGEN_FINAL {
    const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }

 private:
  // Create a single atomic<unsigned> that encodes start and limit information
  // for each thread.
  // We expect num_threads_ < 65536, so we can store them in a single
  // std::atomic<unsigned>.
  // Exposed publicly as static functions so that external callers can reuse
  // this encode/decode logic for maintaining their own thread-safe copies of
  // scheduling and steal domain(s).
  static constexpr int kMaxPartitionBits = 16;
  static constexpr int kMaxThreads = 1 << kMaxPartitionBits;

  inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }

  inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }
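
  // Illustrative example (not part of the pool's logic): with
  // kMaxPartitionBits == 16, EncodePartition(2, 6) yields
  // (2 << 16) | 6 == 0x00020006, and DecodePartition(0x00020006, &s, &l)
  // recovers s == 2 and l == 6.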

  void AssertBounds(int start, int end) {
    eigen_plain_assert(start >= 0);
    eigen_plain_assert(start < end);  // non-zero sized partition
    eigen_plain_assert(end <= num_threads_);
  }

  inline void SetStealPartition(size_t i, unsigned val) {
    thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
  }

  inline unsigned GetStealPartition(int i) { return thread_data_[i].steal_partition.load(std::memory_order_relaxed); }

  void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
    for (int i = 1; i <= N; i++) {
      unsigned a = i;
      unsigned b = N;
      // If GCD(a, b) == 1, then a and b are coprimes.
      while (b != 0) {
        unsigned tmp = a;
        a = b;
        b = tmp % b;
      }
      if (a == 1) {
        coprimes->push_back(i);
      }
    }
  }
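
  // Worked example (illustrative): for N == 10 the coprimes are {1, 3, 7, 9}.
  // Starting at a random thread index t == 4 and repeatedly adding the
  // coprime 3 modulo 10 visits 4, 7, 0, 3, 6, 9, 2, 5, 8, 1, i.e. every
  // thread exactly once, which is what Steal() and NonEmptyQueueIndex()
  // rely on.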

  // Maximum number of threads that can spin in the steal loop.
  static constexpr int kMaxSpinningThreads = 1;

  // The number of steal loop spin iterations before parking (this number is
  // divided by the number of threads, to get the spin count for each thread).
  static constexpr int kSpinCount = 5000;

  // If there are enough active threads with empty pending-task queues, a thread
  // that runs out of work can just be parked without spinning, because these
  // active threads will go into a steal loop after finishing their current
  // tasks.
  //
  // In the worst case, when all active threads are executing long/expensive
  // tasks, the next Schedule() will have to wait until one of the parked
  // threads is unparked; however, this should be very rare in practice.
  static constexpr int kMinActiveThreadsToStartSpinning = 4;

  struct SpinningState {
    // Spinning state layout:
    //
    // - Low 32 bits encode the number of threads that are spinning in the steal
    //   loop.
    //
    // - High 32 bits encode the number of tasks that were submitted to the pool
    //   without a call to `ec_.Notify()`. This number can't be larger than
    //   the number of spinning threads. Each spinning thread, when it exits the
    //   spin loop, must check if this number is greater than zero, and maybe
    //   make another attempt to steal a task and decrement it by one.
    static constexpr uint64_t kNumSpinningMask = 0x00000000FFFFFFFF;
    static constexpr uint64_t kNumNoNotifyMask = 0xFFFFFFFF00000000;
    static constexpr uint64_t kNumNoNotifyShift = 32;

    uint64_t num_spinning;         // number of spinning threads
    uint64_t num_no_notification;  // number of tasks submitted without
                                   // notifying waiting threads

    // Decodes a `spinning_state_` value.
    static SpinningState Decode(uint64_t state) {
      uint64_t num_spinning = (state & kNumSpinningMask);
      uint64_t num_no_notification = (state & kNumNoNotifyMask) >> kNumNoNotifyShift;

      eigen_plain_assert(num_no_notification <= num_spinning);
      return {num_spinning, num_no_notification};
    }

    // Encodes as a `spinning_state_` value.
    uint64_t Encode() const {
      eigen_plain_assert(num_no_notification <= num_spinning);
      return (num_no_notification << kNumNoNotifyShift) | num_spinning;
    }
  };
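
  // Illustrative example (not used by the code): a state with 3 spinning
  // threads, one of which still owes a post-spin steal attempt for a task
  // submitted without notification, is {num_spinning = 3,
  // num_no_notification = 1} and encodes to 0x0000000100000003.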

  Environment env_;
  const int num_threads_;
  const bool allow_spinning_;
  const int spin_count_;
  MaxSizeVector<ThreadData> thread_data_;
  MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  unsigned global_steal_partition_;
  std::atomic<uint64_t> spinning_state_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> done_;
  std::atomic<bool> cancelled_;
  EventCount ec_;
#ifndef EIGEN_THREAD_LOCAL
  std::unique_ptr<Barrier> init_barrier_;
  EIGEN_MUTEX per_thread_map_mutex_;  // Protects per_thread_map_.
  std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif

  unsigned NumBlockedThreads() const { return blocked_.load(); }
  unsigned NumActiveThreads() const { return num_threads_ - blocked_.load(); }

  // Main worker thread loop.
  void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
    std::unique_ptr<PerThread> new_pt(new PerThread());
    per_thread_map_mutex_.lock();
    bool insertOK = per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
    eigen_plain_assert(insertOK);
    EIGEN_UNUSED_VARIABLE(insertOK);
    per_thread_map_mutex_.unlock();
    init_barrier_->Notify();
    init_barrier_->Wait();
#endif
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = GlobalThreadIdHash();
    pt->thread_id = thread_id;
    Task t;
    while (!cancelled_.load(std::memory_order_relaxed)) {
      MaybeGetTask(&t);
      // If we still don't have a task, wait for one. Return if the thread pool
      // is in the cancelled state.
      if (EIGEN_PREDICT_FALSE(!t.f)) {
        EventCount::Waiter* waiter = &waiters_[pt->thread_id];
        if (!WaitForWork(waiter, &t)) return;
      }
      if (EIGEN_PREDICT_TRUE(t.f)) env_.ExecuteTask(t);
    }
  }

  // Steal tries to steal work from other worker threads in the range [start,
  // limit) in a best-effort manner.
  Task Steal(unsigned start, unsigned limit) {
    PerThread* pt = GetPerThread();
    const size_t size = limit - start;
    unsigned r = Rand(&pt->rand);
    // Reduce r into the [0, size) range; this uses the trick from
    // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
    eigen_plain_assert(all_coprimes_[size - 1].size() < (1 << 30));
    unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
    unsigned index = ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
    unsigned inc = all_coprimes_[size - 1][index];
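    // For example (illustrative), with r == 0x80000000 (about half of 2^32)
    // and size == 8, ((uint64_t)r * 8) >> 32 == 4, i.e. r is mapped
    // proportionally into [0, 8) without a modulo operation.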

    for (unsigned i = 0; i < size; i++) {
      eigen_plain_assert(start + victim < limit);
      Task t = thread_data_[start + victim].queue.PopBack();
      if (t.f) {
        return t;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return Task();
  }

  // Steals work within threads belonging to the partition.
  Task LocalSteal() {
    PerThread* pt = GetPerThread();
    unsigned partition = GetStealPartition(pt->thread_id);
    // If the thread's steal partition is the same as the global partition,
    // there is no need to go through the steal loop twice.
    if (global_steal_partition_ == partition) return Task();
    unsigned start, limit;
    DecodePartition(partition, &start, &limit);
    AssertBounds(start, limit);

    return Steal(start, limit);
  }

  // Steals work from any other thread in the pool.
  Task GlobalSteal() { return Steal(0, num_threads_); }

  // WaitForWork blocks until new work is available (returns true), or if it is
  // time to exit (returns false). Can optionally return a task to execute in t
  // (in that case t.f != nullptr on return).
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_plain_assert(!t->f);
    // We already did a best-effort emptiness check in Steal, so prepare for
    // blocking.
    ec_.Prewait();
    // Now do a reliable emptiness check.
    int victim = NonEmptyQueueIndex();
    if (victim != -1) {
      ec_.CancelWait();
      if (cancelled_) {
        return false;
      } else {
        *t = thread_data_[victim].queue.PopBack();
        return true;
      }
    }
    // The number of blocked threads is used as the termination condition.
    // If we are shutting down and all worker threads are blocked without work,
    // then we are done.
    blocked_++;
    // TODO is blocked_ required to be unsigned?
    if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
      ec_.CancelWait();
      // Almost done, but we need to re-check the queues.
      // Consider that all queues are empty and all worker threads are preempted
      // right after incrementing blocked_ above. Now a free-standing thread
      // submits work and calls destructor (which sets done_). If we don't
      // re-check queues, we will exit leaving the work unexecuted.
      if (NonEmptyQueueIndex() != -1) {
        // Note: we must not pop from queues before we decrement blocked_,
        // otherwise the following scenario is possible. Consider that instead
        // of checking for emptiness we popped the only element from the queues.
        // Now other worker threads can start exiting, which is bad if the
        // work item submits other work. So we just check emptiness here,
        // which ensures that all worker threads exit at the same time.
        blocked_--;
        return true;
      }
      // Reached stable termination state.
      ec_.Notify(true);
      return false;
    }
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }

  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    // We intentionally design NonEmptyQueueIndex to look for work in any queue
    // in the pool, so threads don't block in WaitForWork() forever when all
    // threads in their partition go to sleep. Steal is still local.
    const size_t size = thread_data_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!thread_data_[victim].queue.Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return -1;
  }

  // StartSpinning() checks if the number of threads in the spin loop is less
  // than the allowed maximum. If so, it increments the number of spinning
  // threads by one and returns true (the caller must enter the spin loop).
  // Otherwise it returns false, and the caller must not enter the spin loop.
  bool StartSpinning() {
    if (NumActiveThreads() > kMinActiveThreadsToStartSpinning) return false;

    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      if ((state.num_spinning - state.num_no_notification) >= kMaxSpinningThreads) {
        return false;
      }

      // Increment the number of spinning threads.
      ++state.num_spinning;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return true;
      }
    }
  }

  // StopSpinning() decrements the number of spinning threads by one. It also
  // checks if there were any tasks submitted into the pool without notifying
  // parked threads, and if so decrements that count by one. Returns true if the
  // number of tasks submitted without notification was decremented; in this
  // case the caller thread might have to call Steal() one more time.
  bool StopSpinning() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      // Decrement the number of spinning threads.
      --state.num_spinning;

      // Maybe decrement the number of tasks submitted without notification.
      bool has_no_notify_task = state.num_no_notification > 0;
      if (has_no_notify_task) --state.num_no_notification;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return has_no_notify_task;
      }
    }
  }

  // IsNotifyParkedThreadRequired() returns true if a parked thread must be
  // notified about a newly added task. If there are threads spinning in the
  // steal loop, there is no need to unpark any of the waiting threads; the task
  // will be picked up by one of the spinning threads.
  bool IsNotifyParkedThreadRequired() {
    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
    for (;;) {
      SpinningState state = SpinningState::Decode(spinning);

      // If the number of tasks submitted without notifying parked threads is
      // equal to the number of spinning threads, we must wake up one of the
      // parked threads.
      if (state.num_no_notification == state.num_spinning) return true;

      // Increment the number of tasks submitted without notification.
      ++state.num_no_notification;

      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(), std::memory_order_relaxed)) {
        return false;
      }
    }
  }

  static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
    return std::hash<std::thread::id>()(std::this_thread::get_id());
  }

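  // Returns the PerThread record for the calling thread. With thread-local
  // storage this is a plain thread-local variable; without it, the record is
  // looked up in per_thread_map_ keyed by GlobalThreadIdHash(), and threads
  // that are not workers of this pool get a shared default-constructed dummy
  // (pool == NULL, thread_id == -1).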
  EIGEN_STRONG_INLINE PerThread* GetPerThread() {
#ifndef EIGEN_THREAD_LOCAL
    static PerThread dummy;
    auto it = per_thread_map_.find(GlobalThreadIdHash());
    if (it == per_thread_map_.end()) {
      return &dummy;
    } else {
      return it->second.get();
    }
#else
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
#endif
  }

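  // Rand() is a small PCG-style generator: a 64-bit LCG state update (the
  // multiplier 6364136223846793005 is the constant popularized by Knuth's
  // MMIX LCG and used by PCG) followed by the PCG-XSH-RS output permutation,
  // giving cheap per-thread random numbers for victim selection.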
  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output (using the PCG-XSH-RS scheme)
    return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
  }
};

typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;
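
// Example usage (an illustrative sketch, not part of this header): a pool is
// constructed with a worker count and tasks are submitted with Schedule();
// Eigen::Barrier (from this ThreadPool module) is one way to wait for them.
//
//   Eigen::ThreadPool pool(4);
//   Eigen::Barrier barrier(8);
//   for (int i = 0; i < 8; ++i) {
//     pool.Schedule([&barrier]() {
//       // ... do some work ...
//       barrier.Notify();
//     });
//   }
//   barrier.Wait();  // Returns once all 8 tasks have run.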

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H