diff options
author | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-10-22 16:33:40 +0300 |
---|---|---|
committer | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-10-22 16:33:40 +0300 |
commit | 17200099e7f6ad04928e56a53682019a2db8e2a7 (patch) | |
tree | 7beb4d6646b65af3daafdb88eed14ebcc271f503 | |
parent | 0a9c3bc191e4edd55a3496d496ce283325801757 (diff) |
Improve beam search end condition to always return complete hypotheses (v0.7.3)
-rw-r--r-- | src/decoder.cc | 12 |
1 file changed, 7 insertions, 5 deletions
```diff
diff --git a/src/decoder.cc b/src/decoder.cc
index ffc50b94..25a10b51 100644
--- a/src/decoder.cc
+++ b/src/decoder.cc
@@ -140,6 +140,7 @@ namespace ctranslate2 {
     scores.clear();
     scores.resize(batch_size);
 
+    std::vector<bool> top_beam_finished(batch_size, false);
     std::vector<size_t> batch_offset(batch_size);
     for (size_t i = 0; i < batch_size; ++i) {
       batch_offset[i] = i;
@@ -208,11 +209,11 @@ namespace ctranslate2 {
       size_t finished_count = 0;
       for (size_t i = 0; i < cur_batch_size; ++i) {
         size_t batch_id = batch_offset[i];
-        bool batch_finished = (topk_ids.at<int32_t>({i, 0}) == static_cast<int32_t>(end_token)
-                               || step + 1 == max_steps);
-
         for (size_t k = 0; k < beam_size; ++k) {
-          if (topk_ids.at<int32_t>({i, k}) == static_cast<int32_t>(end_token) || batch_finished) {
+          if (topk_ids.at<int32_t>({i, k}) == static_cast<int32_t>(end_token)
+              || step + 1 == max_steps) {
+            if (k == 0)
+              top_beam_finished[i] = true;
             float score = topk_log_probs.at<float>({i, k});
             // Save the finished hypothesis only if it is still a candidate.
             if (hypotheses[batch_id].size() < num_hypotheses
@@ -234,7 +235,7 @@ namespace ctranslate2 {
           }
         }
 
-        if (batch_finished) {
+        if (top_beam_finished[i] && hypotheses[batch_id].size() >= num_hypotheses) {
           ++finished_count;
           finished[i] = true;
 
@@ -263,6 +264,7 @@ namespace ctranslate2 {
         for (; read_index < finished.size(); ++read_index) {
           if (!finished[read_index]) {
             keep_batches.at<int32_t>(write_index) = read_index;
+            top_beam_finished[write_index] = top_beam_finished[read_index];
             batch_offset[write_index] = batch_offset[read_index];
             ++write_index;
           }
```