Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGuillaume Klein <guillaume.klein@systrangroup.com>2021-06-09 11:15:50 +0300
committerGuillaume Klein <guillaume.klein@systrangroup.com>2021-06-09 11:15:50 +0300
commit6f949c07991a41a2663008201d3ffb225276e753 (patch)
treeff3b533c073b3c83156de0ad26968a640290873f /tests
parentca52e19c9f300d2c98dd9e187a559ab1b2ec14a3 (diff)
Add a translation wrapper that buffers and batches incoming inputs
Diffstat (limited to 'tests')
-rw-r--r--tests/translator_test.cc25
1 files changed, 23 insertions, 2 deletions
diff --git a/tests/translator_test.cc b/tests/translator_test.cc
index 3cbfa21a..20895d72 100644
--- a/tests/translator_test.cc
+++ b/tests/translator_test.cc
@@ -1,4 +1,4 @@
-#include <ctranslate2/translator.h>
+#include <ctranslate2/buffered_translation_wrapper.h>
#include <ctranslate2/decoding.h>
#include <algorithm>
@@ -99,8 +99,12 @@ INSTANTIATE_TEST_CASE_P(
class SearchVariantTest : public ::testing::TestWithParam<size_t> {
};
+static std::string default_model_dir() {
+ return g_data_dir + "/models/v2/aren-transliteration";
+}
+
static Translator default_translator(Device device = Device::CPU) {
- return Translator(g_data_dir + "/models/v2/aren-transliteration", device);
+ return Translator(default_model_dir(), device);
}
TEST_P(SearchVariantTest, SetMaxDecodingLength) {
@@ -739,3 +743,20 @@ TEST(TranslatorTest, SameBeamAndGreedyScore) {
const auto beam_score = translator.translate(input, options).score();
EXPECT_NEAR(greedy_score, beam_score, 1e-5);
}
+
+TEST(BufferedTranslationWrapperTest, Basic) {
+ auto translator_pool = std::make_shared<TranslatorPool>(/*num_translators=*/1,
+ /*num_threads_per_translator=*/2,
+ default_model_dir());
+ BufferedTranslationWrapper wrapper(translator_pool,
+ /*max_batch_size=*/32,
+ /*batch_timeout_in_micros=*/5000);
+
+ auto future1 = wrapper.translate_async({"آ", "ز", "ا"});
+ auto future2 = wrapper.translate_async({"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"});
+
+ EXPECT_EQ(future1.get().hypotheses[0],
+ (std::vector<std::string>{"a", "z", "z", "a"}));
+ EXPECT_EQ(future2.get().hypotheses[0],
+ (std::vector<std::string>{"a", "t", "z", "m", "o", "n"}));
+}