From ab31d7e116ef03dcc8309a8847c597a57a8bfc63 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Fri, 28 Oct 2022 11:00:15 +0200 Subject: Add method post_batch in ReplicaPool class (#952) --- include/ctranslate2/replica_pool.h | 63 +++++++++++++++++++++++++++++--------- include/ctranslate2/thread_pool.h | 27 ---------------- 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index 5560a3c7..1cb9ed1a 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "batch_reader.h" #include "models/model.h" @@ -51,16 +52,46 @@ namespace ctranslate2 { max_queued_batches); } + // Posts a function and return its result as a future. + // The function will be run with the first available replica. + // The function must have the signature: Result(Replica&) template std::future post(Func func) { - std::function wrapped_func = [func = std::move(func)]() { + auto batched_func = [func = std::move(func)](Replica& replica) { + std::vector results; + results.reserve(1); + results.emplace_back(func(replica)); + return results; + }; + + auto futures = post_batch(std::move(batched_func), 1); + return std::move(futures[0]); + } + + // Posts a function and return one future per result. + // The function will be run with the first available replica. + // The function must have the signature: std::vector(Replica&) + template + std::vector> post_batch(Func func, size_t num_results) { + std::vector> promises(num_results); + std::vector> futures; + futures.reserve(promises.size()); + for (auto& promise : promises) + futures.emplace_back(promise.get_future()); + + post_batch(std::move(func), std::move(promises)); + + return futures; + } + + // Same as above, but taking the list of promises directly. + template + void post_batch(Func func, std::vector> promises) { + auto wrapped_func = [func = std::move(func)]() { return func(get_thread_replica()); }; - auto job = std::make_unique>(std::move(wrapped_func)); - auto future = job->get_future(); - _thread_pool->post(std::move(job)); - return future; + post_func(std::move(wrapped_func), std::move(promises)); } // Number of batches in the work queue. @@ -144,11 +175,9 @@ namespace ctranslate2 { for (const size_t index : batch.example_index) batch_promises.emplace_back(std::move(promises[index])); - auto job = std::make_unique>(std::move(batch), - std::move(batch_promises), - func); - - _thread_pool->post(std::move(job)); + post_batch( + [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); }, + std::move(batch_promises)); } } @@ -232,12 +261,17 @@ namespace ctranslate2 { get_core_offset()); } + template + void post_func(Func func, std::vector> promises) { + _thread_pool->post(std::make_unique>(std::move(promises), + std::move(func))); + } + template class BatchJob : public Job { public: - BatchJob(Batch batch, std::vector> promises, Func func) - : _batch(std::move(batch)) - , _promises(std::move(promises)) + BatchJob(std::vector> promises, Func func) + : _promises(std::move(promises)) , _func(std::move(func)) { } @@ -247,7 +281,7 @@ namespace ctranslate2 { std::exception_ptr exception; try { - results = _func(get_thread_replica(), _batch); + results = _func(); } catch (...) { exception = std::current_exception(); } @@ -261,7 +295,6 @@ namespace ctranslate2 { } private: - const Batch _batch; std::vector> _promises; const Func _func; }; diff --git a/include/ctranslate2/thread_pool.h b/include/ctranslate2/thread_pool.h index 1cffb017..826b7e57 100644 --- a/include/ctranslate2/thread_pool.h +++ b/include/ctranslate2/thread_pool.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -26,32 +25,6 @@ namespace ctranslate2 { std::atomic* _counter = nullptr; }; - // Job running a function. - template - class FunctionJob : public Job { - public: - FunctionJob(std::function func) - : _func(std::move(func)) - { - } - - std::future get_future() { - return _promise.get_future(); - } - - void run() override { - try { - _promise.set_value(_func()); - } catch (...) { - _promise.set_exception(std::current_exception()); - } - } - - private: - std::function _func; - std::promise _promise; - }; - // A thread-safe queue of jobs. class JobQueue { public: -- cgit v1.2.3