Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2022-10-28 12:00:15 +0300
committerGitHub <noreply@github.com>2022-10-28 12:00:15 +0300
commitab31d7e116ef03dcc8309a8847c597a57a8bfc63 (patch)
tree0bace20e99d81f587e435f593aeb5aa2ea7eac95
parent39bb3c2da601ff51a8784b250ec827c17b6dbfa4 (diff)
Add method post_batch in ReplicaPool class (#952)
-rw-r--r--include/ctranslate2/replica_pool.h63
-rw-r--r--include/ctranslate2/thread_pool.h27
2 files changed, 48 insertions, 42 deletions
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h
index 5560a3c7..1cb9ed1a 100644
--- a/include/ctranslate2/replica_pool.h
+++ b/include/ctranslate2/replica_pool.h
@@ -1,6 +1,7 @@
#pragma once
#include <chrono>
+#include <future>
#include "batch_reader.h"
#include "models/model.h"
@@ -51,16 +52,46 @@ namespace ctranslate2 {
max_queued_batches);
}
+ // Posts a function and return its result as a future.
+ // The function will be run with the first available replica.
+ // The function must have the signature: Result(Replica&)
template <typename Result, typename Func>
std::future<Result> post(Func func) {
- std::function<Result()> wrapped_func = [func = std::move(func)]() {
+ auto batched_func = [func = std::move(func)](Replica& replica) {
+ std::vector<Result> results;
+ results.reserve(1);
+ results.emplace_back(func(replica));
+ return results;
+ };
+
+ auto futures = post_batch<Result>(std::move(batched_func), 1);
+ return std::move(futures[0]);
+ }
+
+ // Posts a function and return one future per result.
+ // The function will be run with the first available replica.
+ // The function must have the signature: std::vector<Result>(Replica&)
+ template <typename Result, typename Func>
+ std::vector<std::future<Result>> post_batch(Func func, size_t num_results) {
+ std::vector<std::promise<Result>> promises(num_results);
+ std::vector<std::future<Result>> futures;
+ futures.reserve(promises.size());
+ for (auto& promise : promises)
+ futures.emplace_back(promise.get_future());
+
+ post_batch(std::move(func), std::move(promises));
+
+ return futures;
+ }
+
+ // Same as above, but taking the list of promises directly.
+ template <typename Result, typename Func>
+ void post_batch(Func func, std::vector<std::promise<Result>> promises) {
+ auto wrapped_func = [func = std::move(func)]() {
return func(get_thread_replica());
};
- auto job = std::make_unique<FunctionJob<Result>>(std::move(wrapped_func));
- auto future = job->get_future();
- _thread_pool->post(std::move(job));
- return future;
+ post_func(std::move(wrapped_func), std::move(promises));
}
// Number of batches in the work queue.
@@ -144,11 +175,9 @@ namespace ctranslate2 {
for (const size_t index : batch.example_index)
batch_promises.emplace_back(std::move(promises[index]));
- auto job = std::make_unique<BatchJob<Result, Func>>(std::move(batch),
- std::move(batch_promises),
- func);
-
- _thread_pool->post(std::move(job));
+ post_batch<Result>(
+ [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); },
+ std::move(batch_promises));
}
}
@@ -233,11 +262,16 @@ namespace ctranslate2 {
}
template <typename Result, typename Func>
+ void post_func(Func func, std::vector<std::promise<Result>> promises) {
+ _thread_pool->post(std::make_unique<BatchJob<Result, Func>>(std::move(promises),
+ std::move(func)));
+ }
+
+ template <typename Result, typename Func>
class BatchJob : public Job {
public:
- BatchJob(Batch batch, std::vector<std::promise<Result>> promises, Func func)
- : _batch(std::move(batch))
- , _promises(std::move(promises))
+ BatchJob(std::vector<std::promise<Result>> promises, Func func)
+ : _promises(std::move(promises))
, _func(std::move(func))
{
}
@@ -247,7 +281,7 @@ namespace ctranslate2 {
std::exception_ptr exception;
try {
- results = _func(get_thread_replica(), _batch);
+ results = _func();
} catch (...) {
exception = std::current_exception();
}
@@ -261,7 +295,6 @@ namespace ctranslate2 {
}
private:
- const Batch _batch;
std::vector<std::promise<Result>> _promises;
const Func _func;
};
diff --git a/include/ctranslate2/thread_pool.h b/include/ctranslate2/thread_pool.h
index 1cffb017..826b7e57 100644
--- a/include/ctranslate2/thread_pool.h
+++ b/include/ctranslate2/thread_pool.h
@@ -3,7 +3,6 @@
#include <atomic>
#include <condition_variable>
#include <functional>
-#include <future>
#include <limits>
#include <memory>
#include <mutex>
@@ -26,32 +25,6 @@ namespace ctranslate2 {
std::atomic<size_t>* _counter = nullptr;
};
- // Job running a function.
- template <typename Result>
- class FunctionJob : public Job {
- public:
- FunctionJob(std::function<Result()> func)
- : _func(std::move(func))
- {
- }
-
- std::future<Result> get_future() {
- return _promise.get_future();
- }
-
- void run() override {
- try {
- _promise.set_value(_func());
- } catch (...) {
- _promise.set_exception(std::current_exception());
- }
- }
-
- private:
- std::function<Result()> _func;
- std::promise<Result> _promise;
- };
-
// A thread-safe queue of jobs.
class JobQueue {
public: