diff options
-rw-r--r-- | include/ctranslate2/replica_pool.h | 63 | ||||
-rw-r--r-- | include/ctranslate2/thread_pool.h | 27 |
2 files changed, 48 insertions, 42 deletions
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index 5560a3c7..1cb9ed1a 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -1,6 +1,7 @@ #pragma once #include <chrono> +#include <future> #include "batch_reader.h" #include "models/model.h" @@ -51,16 +52,46 @@ namespace ctranslate2 { max_queued_batches); } + // Posts a function and return its result as a future. + // The function will be run with the first available replica. + // The function must have the signature: Result(Replica&) template <typename Result, typename Func> std::future<Result> post(Func func) { - std::function<Result()> wrapped_func = [func = std::move(func)]() { + auto batched_func = [func = std::move(func)](Replica& replica) { + std::vector<Result> results; + results.reserve(1); + results.emplace_back(func(replica)); + return results; + }; + + auto futures = post_batch<Result>(std::move(batched_func), 1); + return std::move(futures[0]); + } + + // Posts a function and return one future per result. + // The function will be run with the first available replica. + // The function must have the signature: std::vector<Result>(Replica&) + template <typename Result, typename Func> + std::vector<std::future<Result>> post_batch(Func func, size_t num_results) { + std::vector<std::promise<Result>> promises(num_results); + std::vector<std::future<Result>> futures; + futures.reserve(promises.size()); + for (auto& promise : promises) + futures.emplace_back(promise.get_future()); + + post_batch(std::move(func), std::move(promises)); + + return futures; + } + + // Same as above, but taking the list of promises directly. + template <typename Result, typename Func> + void post_batch(Func func, std::vector<std::promise<Result>> promises) { + auto wrapped_func = [func = std::move(func)]() { return func(get_thread_replica()); }; - auto job = std::make_unique<FunctionJob<Result>>(std::move(wrapped_func)); - auto future = job->get_future(); - _thread_pool->post(std::move(job)); - return future; + post_func(std::move(wrapped_func), std::move(promises)); } // Number of batches in the work queue. @@ -144,11 +175,9 @@ namespace ctranslate2 { for (const size_t index : batch.example_index) batch_promises.emplace_back(std::move(promises[index])); - auto job = std::make_unique<BatchJob<Result, Func>>(std::move(batch), - std::move(batch_promises), - func); - - _thread_pool->post(std::move(job)); + post_batch<Result>( + [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); }, + std::move(batch_promises)); } } @@ -233,11 +262,16 @@ namespace ctranslate2 { } template <typename Result, typename Func> + void post_func(Func func, std::vector<std::promise<Result>> promises) { + _thread_pool->post(std::make_unique<BatchJob<Result, Func>>(std::move(promises), + std::move(func))); + } + + template <typename Result, typename Func> class BatchJob : public Job { public: - BatchJob(Batch batch, std::vector<std::promise<Result>> promises, Func func) - : _batch(std::move(batch)) - , _promises(std::move(promises)) + BatchJob(std::vector<std::promise<Result>> promises, Func func) + : _promises(std::move(promises)) , _func(std::move(func)) { } @@ -247,7 +281,7 @@ namespace ctranslate2 { std::exception_ptr exception; try { - results = _func(get_thread_replica(), _batch); + results = _func(); } catch (...) { exception = std::current_exception(); } @@ -261,7 +295,6 @@ namespace ctranslate2 { } private: - const Batch _batch; std::vector<std::promise<Result>> _promises; const Func _func; }; diff --git a/include/ctranslate2/thread_pool.h b/include/ctranslate2/thread_pool.h index 1cffb017..826b7e57 100644 --- a/include/ctranslate2/thread_pool.h +++ b/include/ctranslate2/thread_pool.h @@ -3,7 +3,6 @@ #include <atomic> #include <condition_variable> #include <functional> -#include <future> #include <limits> #include <memory> #include <mutex> @@ -26,32 +25,6 @@ namespace ctranslate2 { std::atomic<size_t>* _counter = nullptr; }; - // Job running a function. - template <typename Result> - class FunctionJob : public Job { - public: - FunctionJob(std::function<Result()> func) - : _func(std::move(func)) - { - } - - std::future<Result> get_future() { - return _promise.get_future(); - } - - void run() override { - try { - _promise.set_value(_func()); - } catch (...) { - _promise.set_exception(std::current_exception()); - } - } - - private: - std::function<Result()> _func; - std::promise<Result> _promise; - }; - // A thread-safe queue of jobs. class JobQueue { public: |