Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Nguyen <vince62s@yahoo.com>2022-09-09 13:47:56 +0300
committerGitHub <noreply@github.com>2022-09-09 13:47:56 +0300
commit4452c83ad0cd1f47d2ff0e109bfe962c1dee655f (patch)
tree653802d43b96f5d6f99761b9568e96533853c220
parent948a15002d25378fd5ad26db813250abdce7e81e (diff)
Change beam exit condition (#2190)
-rw-r--r--onmt/tests/test_beam_search.py39
-rw-r--r--onmt/translate/beam_search.py8
2 files changed, 18 insertions, 29 deletions
diff --git a/onmt/tests/test_beam_search.py b/onmt/tests/test_beam_search.py
index 61fb7334..39f57263 100644
--- a/onmt/tests/test_beam_search.py
+++ b/onmt/tests/test_beam_search.py
@@ -252,7 +252,7 @@ class TestBeamSearch(unittest.TestCase):
# since only beam 0 terminates and n_best = 2
pass
- def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
+ def test_beam_is_done_when_X_beams_eos_using_min_length(self):
# this is also a test that when block_ngram_repeat=0,
# repeating is acceptable
beam_sz = 5
@@ -290,13 +290,8 @@ class TestBeamSearch(unittest.TestCase):
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score
else:
- word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
- word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
- # provide beam_sz other good predictions in other beams
- for k, (j, score) in enumerate(
- zip(_non_eos_idxs, valid_score_dist[1:])):
- beam_idx = min(beam_sz - 1, k)
- word_probs[beam_idx::beam_sz, j] = score
+ for j in range(beam_sz):
+ word_probs[j::beam_sz, eos_idx] = valid_score_dist[0]
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
@@ -351,13 +346,8 @@ class TestBeamSearch(unittest.TestCase):
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score
else:
- word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
- word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
- # provide beam_sz other good predictions in other beams
- for k, (j, score) in enumerate(
- zip(_non_eos_idxs, valid_score_dist[1:])):
- beam_idx = min(beam_sz - 1, k)
- word_probs[beam_idx::beam_sz, j] = score
+ for j in range(beam_sz):
+ word_probs[j::beam_sz, eos_idx] = valid_score_dist[0]
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
@@ -502,11 +492,11 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase):
def third_step(self, beam, expected_beam_scores, expected_len_pen):
# assumes beam 0 finished on last step
scores_3 = torch.log_softmax(torch.tensor(
- [[0, 0, 5000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont
+ [[0, 0, 10000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont
[0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 5000, 0, 0],
- [0, 0, 0, .2, .2, .2, .2, .2],
- [0, 0, 50, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die
+ [0, 0, 10000, 0, 0, 5000, 0, 0],
+ [0, 0, 50, .2, .2, .2, .2, .2], # beam 3 -> beam 1 should die
+ [0, 0, 50, 0, .2, .2, .2, .2]]
), dim=1)
scores_3 = scores_3.repeat(self.BATCH_SZ, 1)
@@ -526,11 +516,10 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase):
expected_beam_scores / expected_len_pen))
self.assertTrue(beam.topk_ids.equal(expected_preds_3))
self.assertTrue(beam.current_backptr.equal(expected_bptr_3))
- self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
- # new beam 1 finished
- self.assertTrue(beam.is_finished[:, 1].all())
- # new beam 1 is old beam 4
- self.assertTrue(expected_bptr_3[:, 1].eq(4).all())
+ # we finish 3 hyps per example in this step
+ self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ * 3)
+ # new beam 1 is old beam 3
+ self.assertTrue(expected_bptr_3[:, 1].eq(3).all())
beam.update_finished()
self.assertTrue(beam.top_beam_finished.all())
self.assertTrue(beam.done)
@@ -579,7 +568,7 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
[0, 0, 0, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die
), dim=1)
scores_finish = scores_finish.repeat(self.BATCH_SZ, 1)
- scores_finish[:self.BEAM_SZ, beam.eos] = 0
+ scores_finish[:self.BEAM_SZ, beam.eos] = 100
beam.advance(scores_finish, None)
any_finished = beam.is_finished.any()
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index b48f41f7..fc338272 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -189,18 +189,18 @@ class BeamSearchBase(DecodeStrategy):
self.is_finished[i].all()
else:
finish_flag = self.top_beam_finished[i] != 0
- if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+ if finish_flag and len(self.hypotheses[b]) >= self.beam_size:
best_hyp = sorted(
- self.hypotheses[b], key=lambda x: x[0], reverse=True)
+ self.hypotheses[b], key=lambda x: x[0],
+ reverse=True)[:self.n_best]
for n, (score, pred, attn) in enumerate(best_hyp):
- if n >= self.n_best:
- break
self.scores[b].append(score)
self.predictions[b].append(pred) # ``(batch, n_best,)``
self.attention[b].append(
attn if attn is not None else [])
else:
non_finished_batch.append(i)
+
non_finished = torch.tensor(non_finished_batch)
# If all sentences are translated, no need to go further.
if len(non_finished) == 0: