Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLinxiao ZENG <linxiao.zeng@gmail.com>2021-04-08 19:33:52 +0300
committerGitHub <noreply@github.com>2021-04-08 19:33:52 +0300
commit1186decfa6fd584ed05713738bd8249219321b1d (patch)
tree76decb2515456ed59dcbd7bf85887546c03526fa
parent2f70dfcd425e7c04aa1d8ac38ad1618b1c1c8137 (diff)
Fix beam warning and buffers reuse (#2033)
-rw-r--r--onmt/translate/beam_search.py16
1 file changed, 10 insertions, 6 deletions
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index 8d5f6534..3550352e 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -34,7 +34,7 @@ class BeamSearchBase(DecodeStrategy):
_batch_offset (LongTensor): Shape ``(B,)``.
_beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``.
alive_seq (LongTensor): See base.
- topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
+ topk_log_probs (FloatTensor): Shape ``(B, beam_size,)``. These
are the scores used for the topk operation.
memory_lengths (LongTensor): Lengths of encodings. Used for
masking attentions.
@@ -105,7 +105,7 @@ class BeamSearchBase(DecodeStrategy):
dtype=torch.long, device=device)
self.topk_log_probs = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), device=device
- ).repeat(self.batch_size)
+ ).repeat(self.batch_size).reshape(self.batch_size, self.beam_size)
# buffers for the topk scores and 'backpointer'
self.topk_scores = torch.empty((self.batch_size, self.beam_size),
dtype=torch.float, device=device)
@@ -128,11 +128,12 @@ class BeamSearchBase(DecodeStrategy):
def batch_offset(self):
return self._batch_offset
- def _pick(self, log_probs):
- """Return token decision for a step.
+ def _pick(self, log_probs, out=None):
+ """Take a token pick decision for a step.
Args:
- log_probs (FloatTensor): (B, vocab_size)
+ log_probs (FloatTensor): (B * beam_size, vocab_size)
+ out (Tensor, LongTensor): output buffers to reuse, optional.
Returns:
topk_scores (FloatTensor): (B, beam_size)
@@ -144,6 +145,9 @@ class BeamSearchBase(DecodeStrategy):
# Flatten probs into a list of possibilities.
curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
+ if out is not None:
+ torch.topk(curr_scores, self.beam_size, dim=-1, out=out)
+ return
topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
return topk_scores, topk_ids
@@ -267,7 +271,7 @@ class BeamSearchBase(DecodeStrategy):
self.block_ngram_repeats(curr_scores)
# Pick up candidate token by curr_scores
- self.topk_scores, self.topk_ids = self._pick(curr_scores)
+ self._pick(curr_scores, out=(self.topk_scores, self.topk_ids))
# Recover log probs.
# Length penalty is just a scalar. It doesn't matter if it's applied