Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaume.klein@systrangroup.com>2018-10-22 16:33:40 +0300
committerGuillaume Klein <guillaume.klein@systrangroup.com>2018-10-22 16:33:40 +0300
commit17200099e7f6ad04928e56a53682019a2db8e2a7 (patch)
tree7beb4d6646b65af3daafdb88eed14ebcc271f503
parent0a9c3bc191e4edd55a3496d496ce283325801757 (diff)
Improve beam search end condition to always return complete hypothesesv0.7.3
-rw-r--r--src/decoder.cc12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/decoder.cc b/src/decoder.cc
index ffc50b94..25a10b51 100644
--- a/src/decoder.cc
+++ b/src/decoder.cc
@@ -140,6 +140,7 @@ namespace ctranslate2 {
scores.clear();
scores.resize(batch_size);
+ std::vector<bool> top_beam_finished(batch_size, false);
std::vector<size_t> batch_offset(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
batch_offset[i] = i;
@@ -208,11 +209,11 @@ namespace ctranslate2 {
size_t finished_count = 0;
for (size_t i = 0; i < cur_batch_size; ++i) {
size_t batch_id = batch_offset[i];
- bool batch_finished = (topk_ids.at<int32_t>({i, 0}) == static_cast<int32_t>(end_token)
- || step + 1 == max_steps);
-
for (size_t k = 0; k < beam_size; ++k) {
- if (topk_ids.at<int32_t>({i, k}) == static_cast<int32_t>(end_token) || batch_finished) {
+ if (topk_ids.at<int32_t>({i, k}) == static_cast<int32_t>(end_token)
+ || step + 1 == max_steps) {
+ if (k == 0)
+ top_beam_finished[i] = true;
float score = topk_log_probs.at<float>({i, k});
// Save the finished hypothesis only if it is still a candidate.
if (hypotheses[batch_id].size() < num_hypotheses
@@ -234,7 +235,7 @@ namespace ctranslate2 {
}
}
- if (batch_finished) {
+ if (top_beam_finished[i] && hypotheses[batch_id].size() >= num_hypotheses) {
++finished_count;
finished[i] = true;
@@ -263,6 +264,7 @@ namespace ctranslate2 {
for (; read_index < finished.size(); ++read_index) {
if (!finished[read_index]) {
keep_batches.at<int32_t>(write_index) = read_index;
+ top_beam_finished[write_index] = top_beam_finished[read_index];
batch_offset[write_index] = batch_offset[read_index];
++write_index;
}