diff options
-rw-r--r-- | ruy/thread_pool.cc | 212 | ||||
-rw-r--r-- | ruy/thread_pool.h | 4 |
2 files changed, 136 insertions, 80 deletions
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc index 298820e..2e2ca2c 100644 --- a/ruy/thread_pool.cc +++ b/ruy/thread_pool.cc @@ -34,125 +34,178 @@ namespace ruy { // A worker thread. class Thread { public: - enum class State { - Startup, // The initial state before the thread main loop runs. - Ready, // Is not working, has not yet received new work to do. - HasWork, // Has work to do. - ExitAsSoonAsPossible // Should exit at earliest convenience. - }; - - explicit Thread(BlockingCounter* counter_to_decrement_when_ready, - Duration spin_duration) - : task_(nullptr), - state_(State::Startup), - counter_to_decrement_when_ready_(counter_to_decrement_when_ready), + explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration) + : state_(State::Startup), + count_busy_threads_(count_busy_threads), spin_duration_(spin_duration) { thread_.reset(new std::thread(ThreadFunc, this)); } + void RequestExitAsSoonAsPossible() { + ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible); + } + ~Thread() { - ChangeState(State::ExitAsSoonAsPossible); + RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible); thread_->join(); } - // Changes State; may be called from either the worker thread - // or the master thread; however, not all state transitions are legal, - // which is guarded by assertions. + // Called by an outside thead to give work to the worker thread. + void StartWork(Task* task) { + ChangeStateFromOutsideThread(State::HasWork, task); + } + + private: + enum class State { + Startup, // The initial state before the thread loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + // Implements the state_ change to State::Ready, which is where we consume + // task_. Only called on the worker thread. + // Reads task_, so assumes ordering past any prior writes to task_. + void RevertToReadyState() { + RUY_TRACE_SCOPE_NAME("Worker thread task"); + // See task_ member comment for the ordering of accesses. + if (task_) { + task_->Run(); + task_ = nullptr; + } + // No need to notify state_cond_, since only the worker thread ever waits + // on it, and we are that thread. + // Relaxed order because ordering is already provided by the + // count_busy_threads_->DecrementCount() at the next line, since the next + // state_ mutation will be to give new work and that won't happen before + // the outside thread has finished the current batch with a + // count_busy_threads_->Wait(). + state_.store(State::Ready, std::memory_order_relaxed); + count_busy_threads_->DecrementCount(); + } + + // Changes State, from outside thread. // // The Task argument is to be used only with new_state==HasWork. // It specifies the Task being handed to this Thread. - void ChangeState(State new_state, Task* task = nullptr) { - state_mutex_.lock(); - State old_state = state_.load(std::memory_order_relaxed); + // + // new_task is only used with State::HasWork. + void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) { + RUY_DCHECK(new_state == State::ExitAsSoonAsPossible || + new_state == State::HasWork); + RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork)); + +#ifndef NDEBUG + // Debug-only sanity checks based on old_state. + State old_state = state_.load(); + RUY_DCHECK_NE(old_state, new_state); + RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork); RUY_DCHECK_NE(old_state, new_state); +#endif + switch (new_state) { - case State::Ready: - RUY_DCHECK(old_state == State::Startup || old_state == State::HasWork); - if (task_) { - // Doing work is part of reverting to 'ready' state. - task_->Run(); - task_ = nullptr; - } - break; case State::HasWork: - RUY_DCHECK(old_state == State::Ready); + // See task_ member comment for the ordering of accesses. RUY_DCHECK(!task_); - task_ = task; + task_ = new_task; break; case State::ExitAsSoonAsPossible: - RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork); break; default: abort(); } - state_.store(new_state, std::memory_order_relaxed); - state_cond_.notify_all(); - state_mutex_.unlock(); - if (new_state == State::Ready) { - counter_to_decrement_when_ready_->DecrementCount(); - } + // Release order because the worker thread will read this with acquire + // order. + state_.store(new_state, std::memory_order_release); + state_cond_mutex_.lock(); + state_cond_.notify_one(); // Only this one worker thread cares. + state_cond_mutex_.unlock(); } static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } - // Called by the master thead to give this thread work to do. - void StartWork(Task* task) { ChangeState(State::HasWork, task); } + // Waits for state_ to be different from State::Ready, and returns that + // new value. + State GetNewStateOtherThanReady() { + State new_state; + const auto& new_state_not_ready = [this, &new_state]() { + new_state = state_.load(std::memory_order_acquire); + return new_state != State::Ready; + }; + RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING); + Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_); + return new_state; + } - private: // Thread entry point. void ThreadFuncImpl() { RUY_TRACE_SCOPE_NAME("Ruy worker thread function"); - ChangeState(State::Ready); + RevertToReadyState(); // Suppress denormals to avoid computation inefficiency. ScopedSuppressDenormals suppress_denormals; - // Thread main loop - while (true) { - RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration"); - // In the 'Ready' state, we have nothing to do but to wait until - // we switch to another state. - const auto& condition = [this]() { - return state_.load(std::memory_order_acquire) != State::Ready; - }; - RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING); - Wait(condition, spin_duration_, &state_cond_, &state_mutex_); - - // Act on new state. - switch (state_.load(std::memory_order_acquire)) { - case State::HasWork: { - RUY_TRACE_SCOPE_NAME("Worker thread task"); - // Got work to do! So do it, and then revert to 'Ready' state. - ChangeState(State::Ready); - break; - } - case State::ExitAsSoonAsPossible: - return; - default: - abort(); - } + // Thread loop + while (GetNewStateOtherThanReady() == State::HasWork) { + RevertToReadyState(); } + + // Thread end. We should only get here if we were told to exit. + RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible); } - // The underlying thread. + // The underlying thread. Used to join on destruction. std::unique_ptr<std::thread> thread_; // The task to be worked on. - Task* task_; + // + // The ordering of reads and writes to task_ is as follows. + // + // 1. The outside thread gives new work by calling + // ChangeStateFromOutsideThread(State::HasWork, new_task); + // That does: + // - a. Write task_ = new_task (non-atomic). + // - b. Store state_ = State::HasWork (memory_order_release). + // 2. The worker thread picks up the new state by calling + // GetNewStateOtherThanReady() + // That does: + // - c. Load state (memory_order_acquire). + // The worker thread then reads the new task in RevertToReadyState(). + // That does: + // - d. Read task_ (non-atomic). + // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and + // does: + // - e. Write task_ = nullptr (non-atomic). + // And then calls Call count_busy_threads_->DecrementCount() + // which does + // - f. Store count_busy_threads_ (memory_order_release). + // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker + // threads by calling count_busy_threads_->Wait(), which does: + // - g. Load count_busy_threads_ (memory_order_acquire). + // + // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are + // ordered by the release-acquire relationship of accesses to state_ (b. -> + // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by + // the release-acquire relationship of accesses to count_busy_threads_ (f. -> + // g.). + Task* task_ = nullptr; - // The condition variable and mutex guarding state changes. + // Condition variable used by the outside thread to notify the worker thread + // of a state change. std::condition_variable state_cond_; - std::mutex state_mutex_; + + // Mutex used to guard state_cond_ + std::mutex state_cond_mutex_; // The state enum tells if we're currently working, waiting for work, etc. - // Its concurrent accesses by the thread and main threads are guarded by - // state_mutex_, and can thus use memory_order_relaxed. This still needs - // to be a std::atomic because we use WaitForVariableChange. + // It is written to from either the outside thread or the worker thread, + // in the ChangeState method. + // It is only ever read by the worker thread. std::atomic<State> state_; // pointer to the master's thread BlockingCounter object, to notify the // master thread of when this thread switches to the 'Ready' state. - BlockingCounter* const counter_to_decrement_when_ready_; + BlockingCounter* const count_busy_threads_; // See ThreadPool::spin_duration_. const Duration spin_duration_; @@ -170,7 +223,7 @@ void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { // Task #0 will be run on the current thread. CreateThreads(task_count - 1); - counter_to_decrement_when_ready_.Reset(task_count - 1); + count_busy_threads_.Reset(task_count - 1); for (int i = 1; i < task_count; i++) { RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK); auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride; @@ -183,7 +236,7 @@ void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS); // Wait for the threads submitted above to finish. - counter_to_decrement_when_ready_.Wait(spin_duration_); + count_busy_threads_.Wait(spin_duration_); } // Ensures that the pool has at least the given count of threads. @@ -195,15 +248,18 @@ void ThreadPool::CreateThreads(int threads_count) { if (threads_.size() >= unsigned_threads_count) { return; } - counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); + count_busy_threads_.Reset(threads_count - threads_.size()); while (threads_.size() < unsigned_threads_count) { - threads_.push_back( - new Thread(&counter_to_decrement_when_ready_, spin_duration_)); + threads_.push_back(new Thread(&count_busy_threads_, spin_duration_)); } - counter_to_decrement_when_ready_.Wait(spin_duration_); + count_busy_threads_.Wait(spin_duration_); } ThreadPool::~ThreadPool() { + // Send all exit requests upfront so threads can work on them in parallel. + for (auto w : threads_) { + w->RequestExitAsSoonAsPossible(); + } for (auto w : threads_) { delete w; } diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h index e3b6803..946be3d 100644 --- a/ruy/thread_pool.h +++ b/ruy/thread_pool.h @@ -98,12 +98,12 @@ class ThreadPool { // copy construction disallowed ThreadPool(const ThreadPool&) = delete; - // The threads in this pool. They are owned by the pool: + // The worker threads in this pool. They are owned by the pool: // the pool creates threads and destroys them in its destructor. std::vector<Thread*> threads_; // The BlockingCounter used to wait for the threads. - BlockingCounter counter_to_decrement_when_ready_; + BlockingCounter count_busy_threads_; // This value was empirically derived with some microbenchmark, we don't have // high confidence in it. |