diff options
author | Vincent Nguyen <vince62s@yahoo.com> | 2022-09-09 13:47:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-09 13:47:56 +0300 |
commit | 4452c83ad0cd1f47d2ff0e109bfe962c1dee655f (patch) | |
tree | 653802d43b96f5d6f99761b9568e96533853c220 | |
parent | 948a15002d25378fd5ad26db813250abdce7e81e (diff) |
Change beam exit condition (#2190)
-rw-r--r-- | onmt/tests/test_beam_search.py | 39 | ||||
-rw-r--r-- | onmt/translate/beam_search.py | 8 |
2 files changed, 18 insertions, 29 deletions
diff --git a/onmt/tests/test_beam_search.py b/onmt/tests/test_beam_search.py index 61fb7334..39f57263 100644 --- a/onmt/tests/test_beam_search.py +++ b/onmt/tests/test_beam_search.py @@ -252,7 +252,7 @@ class TestBeamSearch(unittest.TestCase): # since only beam 0 terminates and n_best = 2 pass - def test_beam_is_done_when_n_best_beams_eos_using_min_length(self): + def test_beam_is_done_when_X_beams_eos_using_min_length(self): # this is also a test that when block_ngram_repeat=0, # repeating is acceptable beam_sz = 5 @@ -290,13 +290,8 @@ class TestBeamSearch(unittest.TestCase): beam_idx = min(beam_sz - 1, k) word_probs[beam_idx::beam_sz, j] = score else: - word_probs[0::beam_sz, eos_idx] = valid_score_dist[0] - word_probs[1::beam_sz, eos_idx] = valid_score_dist[0] - # provide beam_sz other good predictions in other beams - for k, (j, score) in enumerate( - zip(_non_eos_idxs, valid_score_dist[1:])): - beam_idx = min(beam_sz - 1, k) - word_probs[beam_idx::beam_sz, j] = score + for j in range(beam_sz): + word_probs[j::beam_sz, eos_idx] = valid_score_dist[0] attns = torch.randn(1, batch_sz * beam_sz, 53) beam.advance(word_probs, attns) @@ -351,13 +346,8 @@ class TestBeamSearch(unittest.TestCase): beam_idx = min(beam_sz - 1, k) word_probs[beam_idx::beam_sz, j] = score else: - word_probs[0::beam_sz, eos_idx] = valid_score_dist[0] - word_probs[1::beam_sz, eos_idx] = valid_score_dist[0] - # provide beam_sz other good predictions in other beams - for k, (j, score) in enumerate( - zip(_non_eos_idxs, valid_score_dist[1:])): - beam_idx = min(beam_sz - 1, k) - word_probs[beam_idx::beam_sz, j] = score + for j in range(beam_sz): + word_probs[j::beam_sz, eos_idx] = valid_score_dist[0] attns = torch.randn(1, batch_sz * beam_sz, 53) beam.advance(word_probs, attns) @@ -502,11 +492,11 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase): def third_step(self, beam, expected_beam_scores, expected_len_pen): # assumes beam 0 finished on last step scores_3 = torch.log_softmax(torch.tensor( - [[0, 0, 5000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont + [[0, 0, 10000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 5000, 0, 0], - [0, 0, 0, .2, .2, .2, .2, .2], - [0, 0, 50, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die + [0, 0, 10000, 0, 0, 5000, 0, 0], + [0, 0, 50, .2, .2, .2, .2, .2], # beam 3 -> beam 1 should die + [0, 0, 50, 0, .2, .2, .2, .2]] ), dim=1) scores_3 = scores_3.repeat(self.BATCH_SZ, 1) @@ -526,11 +516,10 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase): expected_beam_scores / expected_len_pen)) self.assertTrue(beam.topk_ids.equal(expected_preds_3)) self.assertTrue(beam.current_backptr.equal(expected_bptr_3)) - self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ) - # new beam 1 finished - self.assertTrue(beam.is_finished[:, 1].all()) - # new beam 1 is old beam 4 - self.assertTrue(expected_bptr_3[:, 1].eq(4).all()) + # we finish 3 hyps per example in this step + self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ * 3) + # new beam 1 is old beam 3 + self.assertTrue(expected_bptr_3[:, 1].eq(3).all()) beam.update_finished() self.assertTrue(beam.top_beam_finished.all()) self.assertTrue(beam.done) @@ -579,7 +568,7 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase): [0, 0, 0, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die ), dim=1) scores_finish = scores_finish.repeat(self.BATCH_SZ, 1) - scores_finish[:self.BEAM_SZ, beam.eos] = 0 + scores_finish[:self.BEAM_SZ, beam.eos] = 100 beam.advance(scores_finish, None) any_finished = beam.is_finished.any() diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index b48f41f7..fc338272 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -189,18 +189,18 @@ class BeamSearchBase(DecodeStrategy): self.is_finished[i].all() else: finish_flag = self.top_beam_finished[i] != 0 - if finish_flag and len(self.hypotheses[b]) >= self.n_best: + if finish_flag and len(self.hypotheses[b]) >= self.beam_size: best_hyp = sorted( - self.hypotheses[b], key=lambda x: x[0], reverse=True) + self.hypotheses[b], key=lambda x: x[0], + reverse=True)[:self.n_best] for n, (score, pred, attn) in enumerate(best_hyp): - if n >= self.n_best: - break self.scores[b].append(score) self.predictions[b].append(pred) # ``(batch, n_best,)`` self.attention[b].append( attn if attn is not None else []) else: non_finished_batch.append(i) + non_finished = torch.tensor(non_finished_batch) # If all sentences are translated, no need to go further. if len(non_finished) == 0: |