Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'include/ctranslate2/replica_pool.h')
-rw-r--r--include/ctranslate2/replica_pool.h63
1 files changed, 48 insertions, 15 deletions
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h
index 5560a3c7..1cb9ed1a 100644
--- a/include/ctranslate2/replica_pool.h
+++ b/include/ctranslate2/replica_pool.h
@@ -1,6 +1,7 @@
#pragma once
#include <chrono>
+#include <future>
#include "batch_reader.h"
#include "models/model.h"
@@ -51,16 +52,46 @@ namespace ctranslate2 {
max_queued_batches);
}
+ // Posts a function and return its result as a future.
+ // The function will be run with the first available replica.
+ // The function must have the signature: Result(Replica&)
template <typename Result, typename Func>
std::future<Result> post(Func func) {
- std::function<Result()> wrapped_func = [func = std::move(func)]() {
+ auto batched_func = [func = std::move(func)](Replica& replica) {
+ std::vector<Result> results;
+ results.reserve(1);
+ results.emplace_back(func(replica));
+ return results;
+ };
+
+ auto futures = post_batch<Result>(std::move(batched_func), 1);
+ return std::move(futures[0]);
+ }
+
+ // Posts a function and return one future per result.
+ // The function will be run with the first available replica.
+ // The function must have the signature: std::vector<Result>(Replica&)
+ template <typename Result, typename Func>
+ std::vector<std::future<Result>> post_batch(Func func, size_t num_results) {
+ std::vector<std::promise<Result>> promises(num_results);
+ std::vector<std::future<Result>> futures;
+ futures.reserve(promises.size());
+ for (auto& promise : promises)
+ futures.emplace_back(promise.get_future());
+
+ post_batch(std::move(func), std::move(promises));
+
+ return futures;
+ }
+
+ // Same as above, but taking the list of promises directly.
+ template <typename Result, typename Func>
+ void post_batch(Func func, std::vector<std::promise<Result>> promises) {
+ auto wrapped_func = [func = std::move(func)]() {
return func(get_thread_replica());
};
- auto job = std::make_unique<FunctionJob<Result>>(std::move(wrapped_func));
- auto future = job->get_future();
- _thread_pool->post(std::move(job));
- return future;
+ post_func(std::move(wrapped_func), std::move(promises));
}
// Number of batches in the work queue.
@@ -144,11 +175,9 @@ namespace ctranslate2 {
for (const size_t index : batch.example_index)
batch_promises.emplace_back(std::move(promises[index]));
- auto job = std::make_unique<BatchJob<Result, Func>>(std::move(batch),
- std::move(batch_promises),
- func);
-
- _thread_pool->post(std::move(job));
+ post_batch<Result>(
+ [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); },
+ std::move(batch_promises));
}
}
@@ -233,11 +262,16 @@ namespace ctranslate2 {
}
template <typename Result, typename Func>
+ void post_func(Func func, std::vector<std::promise<Result>> promises) {
+ _thread_pool->post(std::make_unique<BatchJob<Result, Func>>(std::move(promises),
+ std::move(func)));
+ }
+
+ template <typename Result, typename Func>
class BatchJob : public Job {
public:
- BatchJob(Batch batch, std::vector<std::promise<Result>> promises, Func func)
- : _batch(std::move(batch))
- , _promises(std::move(promises))
+ BatchJob(std::vector<std::promise<Result>> promises, Func func)
+ : _promises(std::move(promises))
, _func(std::move(func))
{
}
@@ -247,7 +281,7 @@ namespace ctranslate2 {
std::exception_ptr exception;
try {
- results = _func(get_thread_replica(), _batch);
+ results = _func();
} catch (...) {
exception = std::current_exception();
}
@@ -261,7 +295,6 @@ namespace ctranslate2 {
}
private:
- const Batch _batch;
std::vector<std::promise<Result>> _promises;
const Func _func;
};