diff options
author | Guillaume Klein <guillaume.klein@systrangroup.com> | 2021-06-09 11:15:50 +0300 |
---|---|---|
committer | Guillaume Klein <guillaume.klein@systrangroup.com> | 2021-06-09 11:15:50 +0300 |
commit | 6f949c07991a41a2663008201d3ffb225276e753 (patch) | |
tree | ff3b533c073b3c83156de0ad26968a640290873f /tests | |
parent | ca52e19c9f300d2c98dd9e187a559ab1b2ec14a3 (diff) |
Add a translation wrapper that buffers and batches incoming inputs
Diffstat (limited to 'tests')
-rw-r--r-- | tests/translator_test.cc | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/tests/translator_test.cc b/tests/translator_test.cc index 3cbfa21a..20895d72 100644 --- a/tests/translator_test.cc +++ b/tests/translator_test.cc @@ -1,4 +1,4 @@ -#include <ctranslate2/translator.h> +#include <ctranslate2/buffered_translation_wrapper.h> #include <ctranslate2/decoding.h> #include <algorithm> @@ -99,8 +99,12 @@ INSTANTIATE_TEST_CASE_P( class SearchVariantTest : public ::testing::TestWithParam<size_t> { }; +static std::string default_model_dir() { + return g_data_dir + "/models/v2/aren-transliteration"; +} + static Translator default_translator(Device device = Device::CPU) { - return Translator(g_data_dir + "/models/v2/aren-transliteration", device); + return Translator(default_model_dir(), device); } TEST_P(SearchVariantTest, SetMaxDecodingLength) { @@ -739,3 +743,20 @@ TEST(TranslatorTest, SameBeamAndGreedyScore) { const auto beam_score = translator.translate(input, options).score(); EXPECT_NEAR(greedy_score, beam_score, 1e-5); } + +TEST(BufferedTranslationWrapperTest, Basic) { + auto translator_pool = std::make_shared<TranslatorPool>(/*num_translators=*/1, + /*num_threads_per_translator=*/2, + default_model_dir()); + BufferedTranslationWrapper wrapper(translator_pool, + /*max_batch_size=*/32, + /*batch_timeout_in_micros=*/5000); + + auto future1 = wrapper.translate_async({"آ", "ز", "ا"}); + auto future2 = wrapper.translate_async({"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"}); + + EXPECT_EQ(future1.get().hypotheses[0], + (std::vector<std::string>{"a", "z", "z", "a"})); + EXPECT_EQ(future2.get().hypotheses[0], + (std::vector<std::string>{"a", "t", "z", "m", "o", "n"})); +} |