Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Nguyen <vince62s@yahoo.com>2022-10-26 13:27:57 +0300
committerGitHub <noreply@github.com>2022-10-26 13:27:57 +0300
commita0761136982ac7c30eba58a31cec1327db9dcc37 (patch)
tree358076e7b05cb0dd5b5c8bbefbd34c1dd6cdd066
parent39c0adcca03258c73af29b06773725cdbe3adfac (diff)
parentd7f1cca887a641d6fb99b978d2122f33a7d32736 (diff)
Merge pull request #2229 from vince62s/dedup
dedup batches on the fly
-rw-r--r--onmt/inputters/dynamic_iterator.py54
1 files changed, 29 insertions, 25 deletions
diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 23edbbd8..a663804b 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -227,33 +227,37 @@ class DynamicDatasetIter(torch.utils.data.IterableDataset):
if batch_size_fn is None:
def batch_size_fn(new, count, sofar):
return count
- minibatch, size_so_far = [], 0
+ minibatch, size_so_far, seen = [], 0, []
for ex in data:
- minibatch.append(ex)
- size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
- if size_so_far >= batch_size:
- overflowed = 0
- if size_so_far > batch_size:
- overflowed += 1
- if batch_size_multiple > 1:
- overflowed += (
- (len(minibatch) - overflowed) % batch_size_multiple)
- if overflowed == 0:
- yield minibatch
- minibatch, size_so_far = [], 0
- else:
- if overflowed == len(minibatch):
- logger.warning(
- "The batch will be filled until we reach %d,"
- "its size may exceed %d tokens"
- % (batch_size_multiple, batch_size)
- )
+ if ex['src']['src'] not in seen:
+ seen.append(ex['src']['src'])
+ minibatch.append(ex)
+ size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
+ if size_so_far >= batch_size:
+ overflowed = 0
+ if size_so_far > batch_size:
+ overflowed += 1
+ if batch_size_multiple > 1:
+ overflowed += (
+ (len(minibatch) - overflowed)
+ % batch_size_multiple)
+ if overflowed == 0:
+ yield minibatch
+ minibatch, size_so_far, seen = [], 0, []
else:
- yield minibatch[:-overflowed]
- minibatch = minibatch[-overflowed:]
- size_so_far = 0
- for i, ex in enumerate(minibatch):
- size_so_far = batch_size_fn(ex, i + 1, size_so_far)
+ if overflowed == len(minibatch):
+ logger.warning(
+ "The batch will be filled until we reach %d,"
+ "its size may exceed %d tokens"
+ % (batch_size_multiple, batch_size)
+ )
+ else:
+ yield minibatch[:-overflowed]
+ minibatch = minibatch[-overflowed:]
+ size_so_far, seen = 0, []
+ for i, ex in enumerate(minibatch):
+ size_so_far = batch_size_fn(ex, i + 1,
+ size_so_far)
if minibatch:
yield minibatch