diff options
Diffstat (limited to 'include/ctranslate2/replica_pool.h')
-rw-r--r-- | include/ctranslate2/replica_pool.h | 63 |
1 files changed, 48 insertions, 15 deletions
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index 5560a3c7..1cb9ed1a 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -1,6 +1,7 @@ #pragma once #include <chrono> +#include <future> #include "batch_reader.h" #include "models/model.h" @@ -51,16 +52,46 @@ namespace ctranslate2 { max_queued_batches); } + // Posts a function and return its result as a future. + // The function will be run with the first available replica. + // The function must have the signature: Result(Replica&) template <typename Result, typename Func> std::future<Result> post(Func func) { - std::function<Result()> wrapped_func = [func = std::move(func)]() { + auto batched_func = [func = std::move(func)](Replica& replica) { + std::vector<Result> results; + results.reserve(1); + results.emplace_back(func(replica)); + return results; + }; + + auto futures = post_batch<Result>(std::move(batched_func), 1); + return std::move(futures[0]); + } + + // Posts a function and return one future per result. + // The function will be run with the first available replica. + // The function must have the signature: std::vector<Result>(Replica&) + template <typename Result, typename Func> + std::vector<std::future<Result>> post_batch(Func func, size_t num_results) { + std::vector<std::promise<Result>> promises(num_results); + std::vector<std::future<Result>> futures; + futures.reserve(promises.size()); + for (auto& promise : promises) + futures.emplace_back(promise.get_future()); + + post_batch(std::move(func), std::move(promises)); + + return futures; + } + + // Same as above, but taking the list of promises directly. + template <typename Result, typename Func> + void post_batch(Func func, std::vector<std::promise<Result>> promises) { + auto wrapped_func = [func = std::move(func)]() { return func(get_thread_replica()); }; - auto job = std::make_unique<FunctionJob<Result>>(std::move(wrapped_func)); - auto future = job->get_future(); - _thread_pool->post(std::move(job)); - return future; + post_func(std::move(wrapped_func), std::move(promises)); } // Number of batches in the work queue. @@ -144,11 +175,9 @@ namespace ctranslate2 { for (const size_t index : batch.example_index) batch_promises.emplace_back(std::move(promises[index])); - auto job = std::make_unique<BatchJob<Result, Func>>(std::move(batch), - std::move(batch_promises), - func); - - _thread_pool->post(std::move(job)); + post_batch<Result>( + [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); }, + std::move(batch_promises)); } } @@ -233,11 +262,16 @@ namespace ctranslate2 { } template <typename Result, typename Func> + void post_func(Func func, std::vector<std::promise<Result>> promises) { + _thread_pool->post(std::make_unique<BatchJob<Result, Func>>(std::move(promises), + std::move(func))); + } + + template <typename Result, typename Func> class BatchJob : public Job { public: - BatchJob(Batch batch, std::vector<std::promise<Result>> promises, Func func) - : _batch(std::move(batch)) - , _promises(std::move(promises)) + BatchJob(std::vector<std::promise<Result>> promises, Func func) + : _promises(std::move(promises)) , _func(std::move(func)) { } @@ -247,7 +281,7 @@ namespace ctranslate2 { std::exception_ptr exception; try { - results = _func(get_thread_replica(), _batch); + results = _func(); } catch (...) { exception = std::current_exception(); } @@ -261,7 +295,6 @@ namespace ctranslate2 { } private: - const Batch _batch; std::vector<std::promise<Result>> _promises; const Func _func; }; |