diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2022-07-06 18:32:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-06 18:32:19 +0300 |
commit | 739a5b1a3166f6fb0433ef5a102eaba0b9a61408 (patch) | |
tree | 34ec3a8593d7c558545d43a6d8b7b8b3d5dda8eb /tests | |
parent | 0598ef90043489eccb5cdfe39c85d6ecf3df7b15 (diff) |
Fix application of max_decoding_length in return_alternatives mode (#866)
Diffstat (limited to 'tests')
-rw-r--r-- | tests/translator_test.cc | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tests/translator_test.cc b/tests/translator_test.cc index fda3a9f5..be95a382 100644 --- a/tests/translator_test.cc +++ b/tests/translator_test.cc @@ -793,6 +793,40 @@ TEST(TranslatorTest, AlternativesFromFullTarget) { EXPECT_EQ(result.hypotheses[0], (std::vector<std::string>{"a", "t", "z", "m", "o", "n", "e"})); } +TEST(TranslatorTest, AlternativesMaxDecodingLength) { + Translator translator = default_translator(); + TranslationOptions options; + options.num_hypotheses = 4; + options.max_decoding_length = 2; + options.return_alternatives = true; + options.return_scores = true; + options.return_attention = true; + + const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"}; + const std::vector<std::vector<std::string>> target_samples = { + {}, {"a"}, {"a", "t"}, {"a", "t", "z"} + }; + + for (const auto& target : target_samples) { + const auto result = translator.translate_with_prefix(input, target, options); + + for (size_t i = 0; i < result.num_hypotheses(); ++i) { + EXPECT_EQ(result.hypotheses[i].size(), options.max_decoding_length); + EXPECT_EQ(result.attention[i].size(), options.max_decoding_length); + + for (size_t t = 0; t < std::min(target.size(), options.max_decoding_length); ++t) { + EXPECT_EQ(result.hypotheses[i][t], target[t]); + } + + if (target.size() < options.max_decoding_length) { + EXPECT_NE(result.scores[i], 0); + } else { + EXPECT_EQ(result.scores[i], 0); + } + } + } +} + TEST(TranslatorTest, DetachModel) { const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"}; Translator translator = default_translator(); |