diff options
author | Vincent Nguyen <vince62s@yahoo.com> | 2022-10-26 13:27:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-26 13:27:57 +0300 |
commit | a0761136982ac7c30eba58a31cec1327db9dcc37 (patch) | |
tree | 358076e7b05cb0dd5b5c8bbefbd34c1dd6cdd066 | |
parent | 39c0adcca03258c73af29b06773725cdbe3adfac (diff) | |
parent | d7f1cca887a641d6fb99b978d2122f33a7d32736 (diff) |
Merge pull request #2229 from vince62s/dedup
dedup batches on the fly
-rw-r--r-- | onmt/inputters/dynamic_iterator.py | 54 |
1 file changed, 29 insertions, 25 deletions
```diff
diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 23edbbd8..a663804b 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -227,33 +227,37 @@ class DynamicDatasetIter(torch.utils.data.IterableDataset):
         if batch_size_fn is None:
             def batch_size_fn(new, count, sofar):
                 return count
-        minibatch, size_so_far = [], 0
+        minibatch, size_so_far, seen = [], 0, []
         for ex in data:
-            minibatch.append(ex)
-            size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
-            if size_so_far >= batch_size:
-                overflowed = 0
-                if size_so_far > batch_size:
-                    overflowed += 1
-                if batch_size_multiple > 1:
-                    overflowed += (
-                        (len(minibatch) - overflowed) % batch_size_multiple)
-                if overflowed == 0:
-                    yield minibatch
-                    minibatch, size_so_far = [], 0
-                else:
-                    if overflowed == len(minibatch):
-                        logger.warning(
-                            "The batch will be filled until we reach %d,"
-                            "its size may exceed %d tokens"
-                            % (batch_size_multiple, batch_size)
-                        )
+            if ex['src']['src'] not in seen:
+                seen.append(ex['src']['src'])
+                minibatch.append(ex)
+                size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
+                if size_so_far >= batch_size:
+                    overflowed = 0
+                    if size_so_far > batch_size:
+                        overflowed += 1
+                    if batch_size_multiple > 1:
+                        overflowed += (
+                            (len(minibatch) - overflowed)
+                            % batch_size_multiple)
+                    if overflowed == 0:
+                        yield minibatch
+                        minibatch, size_so_far, seen = [], 0, []
                     else:
-                        yield minibatch[:-overflowed]
-                        minibatch = minibatch[-overflowed:]
-                        size_so_far = 0
-                        for i, ex in enumerate(minibatch):
-                            size_so_far = batch_size_fn(ex, i + 1, size_so_far)
+                        if overflowed == len(minibatch):
+                            logger.warning(
+                                "The batch will be filled until we reach %d,"
+                                "its size may exceed %d tokens"
+                                % (batch_size_multiple, batch_size)
+                            )
+                        else:
+                            yield minibatch[:-overflowed]
+                            minibatch = minibatch[-overflowed:]
+                            size_so_far, seen = 0, []
+                            for i, ex in enumerate(minibatch):
+                                size_so_far = batch_size_fn(ex, i + 1,
+                                                            size_so_far)
         if minibatch:
             yield minibatch
 
```