Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2022-07-06 18:32:19 +0300
committerGitHub <noreply@github.com>2022-07-06 18:32:19 +0300
commit739a5b1a3166f6fb0433ef5a102eaba0b9a61408 (patch)
tree34ec3a8593d7c558545d43a6d8b7b8b3d5dda8eb /tests
parent0598ef90043489eccb5cdfe39c85d6ecf3df7b15 (diff)
Fix application of max_decoding_length in return_alternatives mode (#866)
Diffstat (limited to 'tests')
-rw-r--r--tests/translator_test.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tests/translator_test.cc b/tests/translator_test.cc
index fda3a9f5..be95a382 100644
--- a/tests/translator_test.cc
+++ b/tests/translator_test.cc
@@ -793,6 +793,40 @@ TEST(TranslatorTest, AlternativesFromFullTarget) {
EXPECT_EQ(result.hypotheses[0], (std::vector<std::string>{"a", "t", "z", "m", "o", "n", "e"}));
}
+TEST(TranslatorTest, AlternativesMaxDecodingLength) {
+ Translator translator = default_translator();
+ TranslationOptions options;
+ options.num_hypotheses = 4;
+ options.max_decoding_length = 2;
+ options.return_alternatives = true;
+ options.return_scores = true;
+ options.return_attention = true;
+
+ const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
+ const std::vector<std::vector<std::string>> target_samples = {
+ {}, {"a"}, {"a", "t"}, {"a", "t", "z"}
+ };
+
+ for (const auto& target : target_samples) {
+ const auto result = translator.translate_with_prefix(input, target, options);
+
+ for (size_t i = 0; i < result.num_hypotheses(); ++i) {
+ EXPECT_EQ(result.hypotheses[i].size(), options.max_decoding_length);
+ EXPECT_EQ(result.attention[i].size(), options.max_decoding_length);
+
+ for (size_t t = 0; t < std::min(target.size(), options.max_decoding_length); ++t) {
+ EXPECT_EQ(result.hypotheses[i][t], target[t]);
+ }
+
+ if (target.size() < options.max_decoding_length) {
+ EXPECT_NE(result.scores[i], 0);
+ } else {
+ EXPECT_EQ(result.scores[i], 0);
+ }
+ }
+ }
+}
+
TEST(TranslatorTest, DetachModel) {
const std::vector<std::string> input = {"آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"};
Translator translator = default_translator();