diff options
Diffstat (limited to 'ruy/thread_pool.h')
-rw-r--r-- | ruy/thread_pool.h | 102 |
1 file changed, 102 insertions, 0 deletions
diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h new file mode 100644 index 0000000..04c201c --- /dev/null +++ b/ruy/thread_pool.h @@ -0,0 +1,102 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 +// license. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ + +#include <vector> + +#include "ruy/blocking_counter.h" + +namespace ruy { + +// A workload for a thread. +struct Task { + virtual ~Task() {} + virtual void Run() = 0; +}; + +class Thread; + +// A simple pool of threads, that only allows the very +// specific parallelization pattern that we use here: +// One thread, which we call the 'main thread', calls Execute, distributing +// a Task each to N threads, being N-1 'worker threads' and the main thread +// itself. After the main thread has completed its own Task, it waits for +// the worker threads to have all completed. That is the only synchronization +// performed by this ThreadPool. +// +// In particular, there is a naive 1:1 mapping of Tasks to threads. +// This ThreadPool considers it outside of its own scope to try to work +// with fewer threads than there are Tasks. 
The idea is that such N:M mappings +// of tasks to threads can be implemented as a higher-level feature on top of +// the present low-level 1:1 threadpool. For example, a user might have a +// Task subclass referencing a shared atomic counter indexing into a vector of +// finer-granularity subtasks. Different threads would then concurrently +// increment this atomic counter, getting each their own subtasks to work on. +// That approach is the one used in ruy's multi-thread matrix multiplication +// implementation --- see ruy's TrMulTask. +class ThreadPool { + public: + ThreadPool() {} + + ~ThreadPool(); + + // Executes task_count tasks on task_count threads. + // Grows the threadpool as needed to have at least (task_count-1) threads. + // The 0-th task is run on the thread on which Execute is called: that + // is by definition what we call the "main thread". Synchronization of all + // threads is performed before this function returns. + // + // As explained in the class comment, there is a 1:1 mapping of tasks to + // threads. If you need something smarter than that, for instance if you + // want to run an unbounded number of tasks on a bounded number of threads, + // then you need something higher-level than this ThreadPool, that can + // be layered on top of it by appropriately subclassing Tasks. + // + // TaskType must be a subclass of ruy::Task. That is implicitly guarded by + // the static_cast in this inline implementation. + template <typename TaskType> + void Execute(int task_count, TaskType* tasks) { + ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks)); + } + + private: + // Ensures that the pool has at least the given count of threads. + // If any new thread has to be created, this function waits for it to + // be ready. + void CreateThreads(int threads_count); + + // Non-templatized implementation of the public Execute method. + // See the inline implementation of Execute for how this is used. 
+ void ExecuteImpl(int task_count, int stride, Task* tasks); + + // copy construction disallowed + ThreadPool(const ThreadPool&) = delete; + + // The threads in this pool. They are owned by the pool: + // the pool creates threads and destroys them in its destructor. + std::vector<Thread*> threads_; + + // The BlockingCounter used to wait for the threads. + BlockingCounter counter_to_decrement_when_ready_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ |