diff options
author | Linxiao ZENG <linxiao.zeng@gmail.com> | 2021-04-08 19:33:52 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-08 19:33:52 +0300 |
commit | 1186decfa6fd584ed05713738bd8249219321b1d (patch) | |
tree | 76decb2515456ed59dcbd7bf85887546c03526fa | |
parent | 2f70dfcd425e7c04aa1d8ac38ad1618b1c1c8137 (diff) |
Fix beam warning and buffers reuse (#2033)
-rw-r--r-- | onmt/translate/beam_search.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index 8d5f6534..3550352e 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -34,7 +34,7 @@ class BeamSearchBase(DecodeStrategy): _batch_offset (LongTensor): Shape ``(B,)``. _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``. alive_seq (LongTensor): See base. - topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These + topk_log_probs (FloatTensor): Shape ``(B, beam_size,)``. These are the scores used for the topk operation. memory_lengths (LongTensor): Lengths of encodings. Used for masking attentions. @@ -105,7 +105,7 @@ class BeamSearchBase(DecodeStrategy): dtype=torch.long, device=device) self.topk_log_probs = torch.tensor( [0.0] + [float("-inf")] * (self.beam_size - 1), device=device - ).repeat(self.batch_size) + ).repeat(self.batch_size).reshape(self.batch_size, self.beam_size) # buffers for the topk scores and 'backpointer' self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device) @@ -128,11 +128,12 @@ class BeamSearchBase(DecodeStrategy): def batch_offset(self): return self._batch_offset - def _pick(self, log_probs): - """Return token decision for a step. + def _pick(self, log_probs, out=None): + """Take a token pick decision for a step. Args: - log_probs (FloatTensor): (B, vocab_size) + log_probs (FloatTensor): (B * beam_size, vocab_size) + out (Tensor, LongTensor): output buffers to reuse, optional. Returns: topk_scores (FloatTensor): (B, beam_size) @@ -144,6 +145,9 @@ class BeamSearchBase(DecodeStrategy): # Flatten probs into a list of possibilities. curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size) + if out is not None: + torch.topk(curr_scores, self.beam_size, dim=-1, out=out) + return topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1) return topk_scores, topk_ids @@ -267,7 +271,7 @@ class BeamSearchBase(DecodeStrategy): self.block_ngram_repeats(curr_scores) # Pick up candidate token by curr_scores - self.topk_scores, self.topk_ids = self._pick(curr_scores) + self._pick(curr_scores, out=(self.topk_scores, self.topk_ids)) # Recover log probs. # Length penalty is just a scalar. It doesn't matter if it's applied |