Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2022-04-19 06:06:00 +0300
committerCopybara-Service <copybara-worker@google.com>2022-04-19 06:07:15 +0300
commita09683b8da7164b9c5704f88aef2dc65aa583e5d (patch)
tree5f885bb4616e1287ba31923b33973aaaaf3922bd
parent915898ed1a46401f1dfb3b23563cad7f89b83fa0 (diff)
Refactor Thread internals for clarity and efficiency.
On the clarity side, the thread main loop is now just: while (GetNewStateOtherThanReady() == State::HasWork) { RevertToReadyState(); } On the efficiency side: * Locking and atomic ops have been reduced, we used to lock state_mutex_ around the entire thread task execution,now we are only locking it anymore around notify/wait on the state_cond_ condition_variable, so this mutex is renamed state_cond_mutex_, which clarifies its purpose. * We used to perform a redundant reload-acquire of the new state_ in the main thread loop. * Some accesses are demoted to relaxed because they are already ordered by other release-acquire relationships. * A notify_all becomes notify_one. * Send all thread exit requests upfront so threads can exit in parallel. A comment is added on Thread::task_ to explain the release-acquire relationships making this all work. Internal code is broken into functions that are only ever called from the main thread, and functions that are only ever called from the worker thread. That specialization made further simplifications and performance gains obvious. It was found by continuous integration that some TFLite users construct and destroy the context from two different threads, due to the use of reference-counting. That means that the notion of "main thread" is not that solid. Accordingly, instances of "main thread" in comments and identifiers have been rephrased as "outside thread" as opposed to worker thread. Tested with TSan (also enabled on presubmits) so fairly confident that this is correct. PiperOrigin-RevId: 442697771
-rw-r--r--ruy/thread_pool.cc212
-rw-r--r--ruy/thread_pool.h4
2 files changed, 136 insertions, 80 deletions
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 298820e..2e2ca2c 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -34,125 +34,178 @@ namespace ruy {
// A worker thread.
class Thread {
public:
- enum class State {
- Startup, // The initial state before the thread main loop runs.
- Ready, // Is not working, has not yet received new work to do.
- HasWork, // Has work to do.
- ExitAsSoonAsPossible // Should exit at earliest convenience.
- };
-
- explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
- Duration spin_duration)
- : task_(nullptr),
- state_(State::Startup),
- counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
+ explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration)
+ : state_(State::Startup),
+ count_busy_threads_(count_busy_threads),
spin_duration_(spin_duration) {
thread_.reset(new std::thread(ThreadFunc, this));
}
+ void RequestExitAsSoonAsPossible() {
+ ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible);
+ }
+
~Thread() {
- ChangeState(State::ExitAsSoonAsPossible);
+ RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible);
thread_->join();
}
- // Changes State; may be called from either the worker thread
- // or the master thread; however, not all state transitions are legal,
- // which is guarded by assertions.
+ // Called by an outside thead to give work to the worker thread.
+ void StartWork(Task* task) {
+ ChangeStateFromOutsideThread(State::HasWork, task);
+ }
+
+ private:
+ enum class State {
+ Startup, // The initial state before the thread loop runs.
+ Ready, // Is not working, has not yet received new work to do.
+ HasWork, // Has work to do.
+ ExitAsSoonAsPossible // Should exit at earliest convenience.
+ };
+
+ // Implements the state_ change to State::Ready, which is where we consume
+ // task_. Only called on the worker thread.
+ // Reads task_, so assumes ordering past any prior writes to task_.
+ void RevertToReadyState() {
+ RUY_TRACE_SCOPE_NAME("Worker thread task");
+ // See task_ member comment for the ordering of accesses.
+ if (task_) {
+ task_->Run();
+ task_ = nullptr;
+ }
+ // No need to notify state_cond_, since only the worker thread ever waits
+ // on it, and we are that thread.
+ // Relaxed order because ordering is already provided by the
+ // count_busy_threads_->DecrementCount() at the next line, since the next
+ // state_ mutation will be to give new work and that won't happen before
+ // the outside thread has finished the current batch with a
+ // count_busy_threads_->Wait().
+ state_.store(State::Ready, std::memory_order_relaxed);
+ count_busy_threads_->DecrementCount();
+ }
+
+ // Changes State, from outside thread.
//
// The Task argument is to be used only with new_state==HasWork.
// It specifies the Task being handed to this Thread.
- void ChangeState(State new_state, Task* task = nullptr) {
- state_mutex_.lock();
- State old_state = state_.load(std::memory_order_relaxed);
+ //
+ // new_task is only used with State::HasWork.
+ void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) {
+ RUY_DCHECK(new_state == State::ExitAsSoonAsPossible ||
+ new_state == State::HasWork);
+ RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork));
+
+#ifndef NDEBUG
+ // Debug-only sanity checks based on old_state.
+ State old_state = state_.load();
+ RUY_DCHECK_NE(old_state, new_state);
+ RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork);
RUY_DCHECK_NE(old_state, new_state);
+#endif
+
switch (new_state) {
- case State::Ready:
- RUY_DCHECK(old_state == State::Startup || old_state == State::HasWork);
- if (task_) {
- // Doing work is part of reverting to 'ready' state.
- task_->Run();
- task_ = nullptr;
- }
- break;
case State::HasWork:
- RUY_DCHECK(old_state == State::Ready);
+ // See task_ member comment for the ordering of accesses.
RUY_DCHECK(!task_);
- task_ = task;
+ task_ = new_task;
break;
case State::ExitAsSoonAsPossible:
- RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork);
break;
default:
abort();
}
- state_.store(new_state, std::memory_order_relaxed);
- state_cond_.notify_all();
- state_mutex_.unlock();
- if (new_state == State::Ready) {
- counter_to_decrement_when_ready_->DecrementCount();
- }
+ // Release order because the worker thread will read this with acquire
+ // order.
+ state_.store(new_state, std::memory_order_release);
+ state_cond_mutex_.lock();
+ state_cond_.notify_one(); // Only this one worker thread cares.
+ state_cond_mutex_.unlock();
}
static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
- // Called by the master thead to give this thread work to do.
- void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+ // Waits for state_ to be different from State::Ready, and returns that
+ // new value.
+ State GetNewStateOtherThanReady() {
+ State new_state;
+ const auto& new_state_not_ready = [this, &new_state]() {
+ new_state = state_.load(std::memory_order_acquire);
+ return new_state != State::Ready;
+ };
+ RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
+ Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_);
+ return new_state;
+ }
- private:
// Thread entry point.
void ThreadFuncImpl() {
RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
- ChangeState(State::Ready);
+ RevertToReadyState();
// Suppress denormals to avoid computation inefficiency.
ScopedSuppressDenormals suppress_denormals;
- // Thread main loop
- while (true) {
- RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
- // In the 'Ready' state, we have nothing to do but to wait until
- // we switch to another state.
- const auto& condition = [this]() {
- return state_.load(std::memory_order_acquire) != State::Ready;
- };
- RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
- Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
-
- // Act on new state.
- switch (state_.load(std::memory_order_acquire)) {
- case State::HasWork: {
- RUY_TRACE_SCOPE_NAME("Worker thread task");
- // Got work to do! So do it, and then revert to 'Ready' state.
- ChangeState(State::Ready);
- break;
- }
- case State::ExitAsSoonAsPossible:
- return;
- default:
- abort();
- }
+ // Thread loop
+ while (GetNewStateOtherThanReady() == State::HasWork) {
+ RevertToReadyState();
}
+
+ // Thread end. We should only get here if we were told to exit.
+ RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible);
}
- // The underlying thread.
+ // The underlying thread. Used to join on destruction.
std::unique_ptr<std::thread> thread_;
// The task to be worked on.
- Task* task_;
+ //
+ // The ordering of reads and writes to task_ is as follows.
+ //
+ // 1. The outside thread gives new work by calling
+ // ChangeStateFromOutsideThread(State::HasWork, new_task);
+ // That does:
+ // - a. Write task_ = new_task (non-atomic).
+ // - b. Store state_ = State::HasWork (memory_order_release).
+ // 2. The worker thread picks up the new state by calling
+ // GetNewStateOtherThanReady()
+ // That does:
+ // - c. Load state (memory_order_acquire).
+ // The worker thread then reads the new task in RevertToReadyState().
+ // That does:
+ // - d. Read task_ (non-atomic).
+ // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and
+ // does:
+ // - e. Write task_ = nullptr (non-atomic).
+ // And then calls Call count_busy_threads_->DecrementCount()
+ // which does
+ // - f. Store count_busy_threads_ (memory_order_release).
+ // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker
+ // threads by calling count_busy_threads_->Wait(), which does:
+ // - g. Load count_busy_threads_ (memory_order_acquire).
+ //
+ // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are
+ // ordered by the release-acquire relationship of accesses to state_ (b. ->
+ // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by
+ // the release-acquire relationship of accesses to count_busy_threads_ (f. ->
+ // g.).
+ Task* task_ = nullptr;
- // The condition variable and mutex guarding state changes.
+ // Condition variable used by the outside thread to notify the worker thread
+ // of a state change.
std::condition_variable state_cond_;
- std::mutex state_mutex_;
+
+ // Mutex used to guard state_cond_
+ std::mutex state_cond_mutex_;
// The state enum tells if we're currently working, waiting for work, etc.
- // Its concurrent accesses by the thread and main threads are guarded by
- // state_mutex_, and can thus use memory_order_relaxed. This still needs
- // to be a std::atomic because we use WaitForVariableChange.
+ // It is written to from either the outside thread or the worker thread,
+ // in the ChangeState method.
+ // It is only ever read by the worker thread.
std::atomic<State> state_;
// pointer to the master's thread BlockingCounter object, to notify the
// master thread of when this thread switches to the 'Ready' state.
- BlockingCounter* const counter_to_decrement_when_ready_;
+ BlockingCounter* const count_busy_threads_;
// See ThreadPool::spin_duration_.
const Duration spin_duration_;
@@ -170,7 +223,7 @@ void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
// Task #0 will be run on the current thread.
CreateThreads(task_count - 1);
- counter_to_decrement_when_ready_.Reset(task_count - 1);
+ count_busy_threads_.Reset(task_count - 1);
for (int i = 1; i < task_count; i++) {
RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
@@ -183,7 +236,7 @@ void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
// Wait for the threads submitted above to finish.
- counter_to_decrement_when_ready_.Wait(spin_duration_);
+ count_busy_threads_.Wait(spin_duration_);
}
// Ensures that the pool has at least the given count of threads.
@@ -195,15 +248,18 @@ void ThreadPool::CreateThreads(int threads_count) {
if (threads_.size() >= unsigned_threads_count) {
return;
}
- counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
+ count_busy_threads_.Reset(threads_count - threads_.size());
while (threads_.size() < unsigned_threads_count) {
- threads_.push_back(
- new Thread(&counter_to_decrement_when_ready_, spin_duration_));
+ threads_.push_back(new Thread(&count_busy_threads_, spin_duration_));
}
- counter_to_decrement_when_ready_.Wait(spin_duration_);
+ count_busy_threads_.Wait(spin_duration_);
}
ThreadPool::~ThreadPool() {
+ // Send all exit requests upfront so threads can work on them in parallel.
+ for (auto w : threads_) {
+ w->RequestExitAsSoonAsPossible();
+ }
for (auto w : threads_) {
delete w;
}
diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h
index e3b6803..946be3d 100644
--- a/ruy/thread_pool.h
+++ b/ruy/thread_pool.h
@@ -98,12 +98,12 @@ class ThreadPool {
// copy construction disallowed
ThreadPool(const ThreadPool&) = delete;
- // The threads in this pool. They are owned by the pool:
+ // The worker threads in this pool. They are owned by the pool:
// the pool creates threads and destroys them in its destructor.
std::vector<Thread*> threads_;
// The BlockingCounter used to wait for the threads.
- BlockingCounter counter_to_decrement_when_ready_;
+ BlockingCounter count_busy_threads_;
// This value was empirically derived with some microbenchmark, we don't have
// high confidence in it.