github.com/OpenNMT/OpenNMT-py.git
author     anderleich <andercorral95@gmail.com>  2022-02-09 19:15:17 +0300
committer  GitHub <noreply@github.com>           2022-02-09 19:15:17 +0300
commit     fc3409ce08785915728dac9e9c9fb49f73027917 (patch)
tree       68d4ff764cc330073623b77865c236b37aa01c19
parent     908881444a814f7ab05c77c14316a2f05d135960 (diff)
Dynamic data loading for inference (#2145)
-rw-r--r--  onmt/bin/translate_dynamic.py  |  61
-rw-r--r--  onmt/inputters/__init__.py     |   4
-rw-r--r--  onmt/inputters/dataset_base.py |  54
-rw-r--r--  onmt/inputters/text_dataset.py | 129
-rw-r--r--  onmt/opts.py                   |  10
-rw-r--r--  onmt/transforms/tokenize.py    |  10
-rw-r--r--  onmt/transforms/transform.py   |   8
-rw-r--r--  onmt/translate/translator.py   |  66
-rw-r--r--  onmt/utils/parse.py            |  10
-rw-r--r--  setup.py                       |   1
-rw-r--r--  translate_dynamic.py           |   6
11 files changed, 353 insertions(+), 6 deletions(-)
diff --git a/onmt/bin/translate_dynamic.py b/onmt/bin/translate_dynamic.py
new file mode 100644
index 00000000..2ff217f0
--- /dev/null
+++ b/onmt/bin/translate_dynamic.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from onmt.utils.logging import init_logger
+from onmt.translate.translator import build_translator
+from onmt.inputters.text_dataset import InferenceDataReader
+from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
+
+import onmt.opts as opts
+from onmt.utils.parse import ArgumentParser
+
+
+def translate(opt):
+    ArgumentParser.validate_translate_opts(opt)
+    ArgumentParser._get_all_transform_translate(opt)
+    ArgumentParser._validate_transforms_opts(opt)
+    ArgumentParser.validate_translate_opts_dynamic(opt)
+    logger = init_logger(opt.log_file)
+
+    translator = build_translator(opt, logger=logger, report_score=True)
+
+    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)
+
+    # Build transforms
+    transforms_cls = get_transforms_cls(opt._all_transform)
+    transforms = make_transforms(opt, transforms_cls, translator.fields)
+    data_transform = [
+        transforms[name] for name in opt.transforms if name in transforms
+    ]
+    transform = TransformPipe.build_from(data_transform)
+
+    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
+        logger.info("Translating shard %d." % i)
+        translator.translate_dynamic(
+            src=src_shard,
+            transform=transform,
+            src_feats=feats_shard,
+            tgt=tgt_shard,
+            batch_size=opt.batch_size,
+            batch_type=opt.batch_type,
+            attn_debug=opt.attn_debug,
+            align_debug=opt.align_debug
+        )
+
+
+def _get_parser():
+    parser = ArgumentParser(description='translate_dynamic.py')
+
+    opts.config_opts(parser)
+    opts.translate_opts(parser, dynamic=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+
+    opt = parser.parse_args()
+    translate(opt)
+
+
+if __name__ == "__main__":
+    main()
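
The new entry point is also registered further down (in setup.py) as an onmt_translate_dynamic console script. Below is a minimal sketch of driving it programmatically, with hypothetical file names; -model, -src and -output are the standard translate options, -transforms is the flag added by this patch, "onmt_tokenize" is assumed to be the registered name of ONMTTokenizerTransform, and any transform-specific options (subword models, etc.) would normally come from the YAML passed with -config:

from onmt.bin.translate_dynamic import _get_parser, translate

# Hypothetical paths: a trained checkpoint, a raw (untokenized) source file,
# and a YAML config carrying the options required by the listed transforms.
opt = _get_parser().parse_args([
    "-config", "translate_dynamic.yaml",
    "-model", "model_step_10000.pt",
    "-src", "test.src",
    "-output", "test.hyp",
    "-transforms", "onmt_tokenize",
])
translate(opt)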
diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py
index ed9aba89..7403426a 100644
--- a/onmt/inputters/__init__.py
+++ b/onmt/inputters/__init__.py
@@ -5,7 +5,7 @@ e.g., from a line of text to a sequence of embeddings.
 """
 from onmt.inputters.inputter import get_fields, build_vocab, filter_example
 from onmt.inputters.iterator import max_tok_len, OrderedIterator
-from onmt.inputters.dataset_base import Dataset
+from onmt.inputters.dataset_base import Dataset, DynamicDataset
 from onmt.inputters.text_dataset import text_sort_key, TextDataReader
 from onmt.inputters.datareader_base import DataReaderBase

@@ -17,4 +17,4 @@ str2sortkey = {

 __all__ = ['Dataset', 'get_fields', 'DataReaderBase', 'filter_example',
            'build_vocab', 'OrderedIterator', 'max_tok_len',
-           'text_sort_key', 'TextDataReader']
+           'text_sort_key', 'TextDataReader', 'DynamicDataset']
diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py
index 574ee7ad..dda21389 100644
--- a/onmt/inputters/dataset_base.py
+++ b/onmt/inputters/dataset_base.py
@@ -165,3 +165,57 @@ class Dataset(TorchtextDataset):
                 readers.append(field["reader"])
                 data.append((name, field["data"], field.get("features", {})))
         return readers, data
+
+
+class DynamicDataset(Dataset):
+
+    def __init__(self, fields, data, sort_key, filter_pred=None):
+        self.sort_key = sort_key
+        can_copy = 'src_map' in fields and 'alignment' in fields
+
+        # self.src_vocabs is used in collapse_copy_scores and Translator.py
+        self.src_vocabs = []
+        examples = []
+        for ex_dict in data:
+            if can_copy:
+                src_field = fields['src']
+                tgt_field = fields['tgt']
+                # this assumes src_field and tgt_field are both text
+                ex_dict = _dynamic_dict(
+                    ex_dict, src_field.base_field, tgt_field.base_field)
+                self.src_vocabs.append(ex_dict["src_ex_vocab"])
+            ex_fields = {k: [(k, v)] for k, v in fields.items() if
+                         k in ex_dict}
+            ex = Example.fromdict(ex_dict, ex_fields)
+            examples.append(ex)
+
+        # fields needs to have only keys that examples have as attrs
+        fields = []
+        for _, nf_list in ex_fields.items():
+            assert len(nf_list) == 1
+            fields.append(nf_list[0])
+
+        super(Dataset, self).__init__(examples, fields, filter_pred)
+
+    def __getattr__(self, attr):
+        # avoid infinite recursion when fields isn't defined
+        if 'fields' not in vars(self):
+            raise AttributeError
+        if attr in self.fields:
+            return (getattr(x, attr) for x in self.examples)
+        else:
+            raise AttributeError
+
+    def save(self, path, remove_fields=True):
+        if remove_fields:
+            self.fields = []
+        torch.save(self, path)
+
+    @staticmethod
+    def config(fields):
+        readers, data = [], []
+        for name, field in fields:
+            if field["data"] is not None:
+                readers.append(field["reader"])
+                data.append((name, field["data"], field.get("features", {})))
+        return readers, data
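
DynamicDataset mirrors Dataset but is built from an iterator of already-transformed example dicts rather than from file readers, reusing the fields loaded with the checkpoint. A rough sketch of the wiring, assuming translator comes from build_translator, transform is a TransformPipe, and the default "text" data type applies; this is essentially what translate_dynamic does later in this patch:

import onmt.inputters as inputters
from onmt.inputters.text_dataset import InferenceDataIterator

# src_lines holds raw bytes lines, as produced by split_corpus; no target here.
data_iter = InferenceDataIterator(src_lines, None, None, transform)
data = inputters.DynamicDataset(
    translator.fields,                      # fields restored from the checkpoint
    data=data_iter,                         # examples are built by iterating this once
    sort_key=inputters.str2sortkey["text"],
    filter_pred=None,
)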
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index ddb51d5c..da6a1584 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -1,18 +1,19 @@
 # -*- coding: utf-8 -*-
 from functools import partial
+from itertools import repeat

 import torch
 from torchtext.data import Field, RawField

 from onmt.constants import DefaultTokens
 from onmt.inputters.datareader_base import DataReaderBase
+from onmt.utils.misc import split_corpus


 class TextDataReader(DataReaderBase):
     def read(self, sequences, side, features={}):
         """Read text data from disk.
-
-        Args:
+        Args:
             sequences (str or Iterable[str]):
                 path to text file or iterable of the actual text data.
             side (str): Prefix used in return dict. Usually
@@ -20,7 +21,6 @@ class TextDataReader(DataReaderBase):
             features: (Dict[str or Iterable[str]]):
                 dictionary mapping feature names with the path to feature
                 file or iterable of the actual feature data.
-
         Yields:
             dictionaries whose keys are the names of fields and whose
             values are more or less the result of tokenizing with those
@@ -49,6 +49,129 @@
             yield {side: ex_dict, "indices": i}


+class InferenceDataReader(object):
+    """It handles inference data reading from disk in shards.
+
+    Args:
+        src (str): path to the source file
+        tgt (str or NoneType): path to the target file
+        src_feats (Dict[str]): paths to the features files
+        shard_size (int): divides files into smaller files of size shard_size
+
+    Returns:
+        Tuple[List[str], List[str], Dict[List[str]]]
+    """
+
+    def __init__(self, src, tgt, src_feats={}, shard_size=10000):
+        self.src = src
+        self.tgt = tgt
+        self.src_feats = src_feats
+        self.shard_size = shard_size
+
+    def __iter__(self):
+        src_shards = split_corpus(self.src, self.shard_size)
+        tgt_shards = split_corpus(self.tgt, self.shard_size)
+
+        if not self.src_feats:
+            features_shards = [repeat(None)]
+        else:
+            features_shards = []
+            features_names = []
+            for feat_name, feat_path in self.src_feats.items():
+                features_shards.append(
+                    split_corpus(feat_path, self.shard_size))
+                features_names.append(feat_name)
+
+        shard_pairs = zip(src_shards, tgt_shards, *features_shards)
+        for i, shard in enumerate(shard_pairs):
+            src_shard, tgt_shard, *features_shard = shard
+            if features_shard[0] is not None:
+                features_shard_ = dict()
+                for j, x in enumerate(features_shard):
+                    features_shard_[features_names[j]] = x
+            else:
+                features_shard_ = None
+            yield src_shard, tgt_shard, features_shard_
+
+
+class InferenceDataIterator(object):
+
+    def __init__(self, src, tgt, src_feats, transform):
+        self.src = src
+        self.tgt = tgt
+        self.src_feats = src_feats
+        self.transform = transform
+
+    def _tokenize(self, example):
+        example['src'] = example['src'].decode("utf-8").strip('\n').split()
+        example['tgt'] = example['tgt'].decode("utf-8").strip('\n').split() \
+            if example["tgt"] is not None else None
+        example['src_original'] = example['src']
+        example['tgt_original'] = example['tgt']
+        if 'src_feats' in example:
+            for k in example['src_feats'].keys():
+                example['src_feats'][k] = example['src_feats'][k] \
+                    .decode("utf-8").strip('\n').split() \
+                    if example['src_feats'][k] is not None else None
+        return example
+
+    def _transform(self, example, remove_tgt=False):
+        maybe_example = self.transform.apply(
+            example, is_train=False, corpus_name="translate")
+        assert maybe_example is not None, \
+            "Transformation on example skipped the example. " \
+            "Please check the transforms."
+        return maybe_example
+
+    def _process(self, example, remove_tgt=False):
+        example['src'] = {"src": ' '.join(example['src'])}
+        example['tgt'] = {"tgt": ' '.join(example['tgt'])}
+
+        # Make features part of src as in TextMultiField
+        # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
+        if 'src_feats' in example:
+            for feat_name, feat_value in example['src_feats'].items():
+                example['src'][feat_name] = ' '.join(feat_value)
+            del example["src_feats"]
+
+        # Cleanup
+        if remove_tgt:
+            del example["tgt"]
+            del example["tgt_original"]
+        del example["src_original"]
+
+        return example
+
+    def __iter__(self):
+        tgt = self.tgt if self.tgt is not None else repeat(None)
+
+        if self.src_feats is not None:
+            features_names = []
+            features_values = []
+            for feat_name, values in self.src_feats.items():
+                features_names.append(feat_name)
+                features_values.append(values)
+        else:
+            features_values = [repeat(None)]
+
+        for i, (src, tgt, *src_feats) in enumerate(zip(
+                self.src, tgt, *features_values)):
+            ex = {
+                "src": src,
+                "tgt": tgt if tgt is not None else b""
+            }
+            if src_feats[0] is not None:
+                src_feats_ = {}
+                for j, x in enumerate(src_feats):
+                    src_feats_[features_names[j]] = x
+                ex["src_feats"] = src_feats_
+            ex = self._tokenize(ex)
+            ex = self._transform(ex)
+            ex = self._process(ex, remove_tgt=self.tgt is None)
+            ex["indices"] = i
+            yield ex
+
+
 def text_sort_key(ex):
     """Sort using the number of tokens in the sequence."""
     if hasattr(ex, "tgt"):
diff --git a/onmt/opts.py b/onmt/opts.py
index ef12a9d7..28504d21 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -734,7 +734,7 @@ def _add_decoding_opts(parser):
                    "the table), then it will copy the source token.")


-def translate_opts(parser):
+def translate_opts(parser, dynamic=False):
     """ Translation / inference options """
     group = parser.add_argument_group('Model')
     group.add('--model', '-model', dest='models', metavar='MODEL',
@@ -801,6 +801,14 @@ def translate_opts(parser):
     group.add('--gpu', '-gpu', type=int, default=-1,
               help="Device to run on")

+    if dynamic:
+        group.add("-transforms", "--transforms", default=[], nargs="+",
+                  choices=AVAILABLE_TRANSFORMS.keys(),
+                  help="Default transform pipeline to apply to data.")
+
+        # Adding options related to Transforms
+        _add_dynamic_transform_opts(parser)
+

 # Copyright 2016 The Chromium Authors. All rights reserved.
 # Use of this source code is governed by a BSD-style license that can be
diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py
index d4b92cb0..cf519b05 100644
--- a/onmt/transforms/tokenize.py
+++ b/onmt/transforms/tokenize.py
@@ -410,6 +410,12 @@ class ONMTTokenizerTransform(TokenizerTransform):
         segmented, _ = tokenizer.tokenize(sentence)
         return segmented

+    def _detokenize(self, tokens, side='src', is_train=False):
+        """Do OpenNMT Tokenizer's detokenize."""
+        tokenizer = self.load_models[side]
+        detokenized = tokenizer.detokenize(tokens)
+        return detokenized
+
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply OpenNMT Tokenizer to src & tgt."""
         src_out = self._tokenize(example['src'], 'src')
@@ -421,6 +427,10 @@ class ONMTTokenizerTransform(TokenizerTransform):
         example['src'], example['tgt'] = src_out, tgt_out
         return example

+    def apply_reverse(self, translated):
+        """Apply OpenNMT Tokenizer to src & tgt."""
+        return self._detokenize(translated.split(), 'tgt')
+
     def _repr_args(self):
         """Return str represent key arguments for class."""
         repr_str = '{}={}'.format('share_vocab', self.share_vocab)
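
The _detokenize/apply_reverse pair gives the tokenizer transform an inverse, which is what lets raw text go in and raw text come out at inference time. A hedged sketch, assuming onmt_tok is an ONMTTokenizerTransform built via make_transforms and warmed up with its subword models:

# Toy subword hypothesis, as the decoder might produce it.
hyp = "▁this ▁is ▁a ▁test"
# apply_reverse splits on whitespace and runs detokenize() on the tgt-side tokenizer.
plain = onmt_tok.apply_reverse(hyp)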
diff --git a/onmt/transforms/transform.py b/onmt/transforms/transform.py
index 6dd5869d..88b6f40c 100644
--- a/onmt/transforms/transform.py
+++ b/onmt/transforms/transform.py
@@ -60,6 +60,9 @@ class Transform(object):
         """
         raise NotImplementedError

+    def apply_reverse(self, translated):
+        return translated
+
     def __getstate__(self):
         """Pickling following for rebuild."""
         state = {"opts": self.opts}
@@ -193,6 +196,11 @@ class TransformPipe(Transform):
                 break
         return example

+    def apply_reverse(self, translated):
+        for transform in self.transforms:
+            translated = transform.apply_reverse(translated)
+        return translated
+
     def __getstate__(self):
         """Pickling following for rebuild."""
         return (self.opts, self.transforms, self.statistics)
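
Because the base Transform.apply_reverse is the identity, only transforms that change the surface form (tokenization, for now) need to override it; TransformPipe simply threads the string through every member. A toy sketch under that assumption, using a hypothetical transform that strips a placeholder marker on the way out:

from onmt.transforms.transform import Transform, TransformPipe

class StripMarkersTransform(Transform):
    # Hypothetical transform, for illustration only.
    def apply(self, example, is_train=False, stats=None, **kwargs):
        return example  # nothing to undo on the forward pass in this sketch

    def apply_reverse(self, translated):
        return translated.replace(" <mark>", "")

pipe = TransformPipe.build_from([StripMarkersTransform(opts=None)])
print(pipe.apply_reverse("hello <mark> world"))  # -> "hello world"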
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 81a98926..69892cec 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -12,6 +12,7 @@ from onmt.constants import DefaultTokens
 import onmt.model_builder
 import onmt.inputters as inputters
 import onmt.decoders.ensemble
+from onmt.inputters.text_dataset import InferenceDataIterator
 from onmt.translate.beam_search import BeamSearch, BeamSearchLM
 from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
 from onmt.utils.misc import tile, set_random_seed, report_matrix
@@ -330,6 +331,45 @@ class Inference(object):
             gs = [0] * batch_size
         return gs

+    def translate_dynamic(
+        self,
+        src,
+        transform,
+        src_feats={},
+        tgt=None,
+        batch_size=None,
+        batch_type="sents",
+        attn_debug=False,
+        align_debug=False,
+        phrase_table=""
+    ):
+
+        if batch_size is None:
+            raise ValueError("batch_size must be set")
+
+        if self.tgt_prefix and tgt is None:
+            raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")
+
+        data_iter = InferenceDataIterator(src, tgt, src_feats, transform)
+
+        data = inputters.DynamicDataset(
+            self.fields,
+            data=data_iter,
+            sort_key=inputters.str2sortkey[self.data_type],
+            filter_pred=self._filter_pred,
+        )
+
+        return self._translate(
+            data,
+            tgt=tgt,
+            batch_size=batch_size,
+            batch_type=batch_type,
+            attn_debug=attn_debug,
+            align_debug=align_debug,
+            phrase_table=phrase_table,
+            dynamic=True,
+            transform=transform)
+
     def translate(
         self,
         src,
@@ -387,6 +427,28 @@
             filter_pred=self._filter_pred,
         )

+        return self._translate(
+            data,
+            tgt=tgt,
+            batch_size=batch_size,
+            batch_type=batch_type,
+            attn_debug=attn_debug,
+            align_debug=align_debug,
+            phrase_table=phrase_table)
+
+    def _translate(
+        self,
+        data,
+        tgt=None,
+        batch_size=None,
+        batch_type="sents",
+        attn_debug=False,
+        align_debug=False,
+        phrase_table="",
+        dynamic=False,
+        transform=None
+    ):
+
         data_iter = inputters.OrderedIterator(
             dataset=data,
             device=self._dev,
@@ -448,6 +510,10 @@
                             n_best_preds, n_best_preds_align
                         )
                     ]
+
+                if dynamic:
+                    n_best_preds = [transform.apply_reverse(x)
+                                    for x in n_best_preds]
                 all_predictions += [n_best_preds]
                 self.out_file.write("\n".join(n_best_preds) + "\n")
                 self.out_file.flush()
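
Callers that already hold a translator can invoke the new method shard by shard. A hedged sketch, assuming _translate returns the same (scores, predictions) pair as the original translate method it was factored out of, and that transform is the TransformPipe built as in translate_dynamic.py above:

# src must be raw bytes lines; InferenceDataIterator decodes them itself.
src_shard = [b"Hello world .\n", b"Another test sentence .\n"]
scores, preds = translator.translate_dynamic(
    src=src_shard,
    transform=transform,
    batch_size=30,
)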
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 35f398af..edc61f44 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -123,6 +123,10 @@ class DataOptsCheckerMixin(object):
         opt._all_transform = all_transforms

     @classmethod
+    def _get_all_transform_translate(cls, opt):
+        opt._all_transform = opt.transforms
+
+    @classmethod
     def _validate_fields_opts(cls, opt, build_vocab_only=False):
         """Check options relate to vocab and fields."""

@@ -327,3 +331,9 @@ class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin):
     @classmethod
     def validate_translate_opts(cls, opt):
         opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}
+
+    @classmethod
+    def validate_translate_opts_dynamic(cls, opt):
+        # It comes from training
+        # TODO: needs to be added as inference opt
+        opt.share_vocab = False
diff --git a/setup.py b/setup.py
index c086b849..ab179eec 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ setup(
             "onmt_server=onmt.bin.server:main",
             "onmt_train=onmt.bin.train:main",
             "onmt_translate=onmt.bin.translate:main",
+            "onmt_translate_dynamic=onmt.bin.translate_dynamic:main",
             "onmt_release_model=onmt.bin.release_model:main",
             "onmt_average_models=onmt.bin.average_models:main",
             "onmt_build_vocab=onmt.bin.build_vocab:main"
diff --git a/translate_dynamic.py b/translate_dynamic.py
new file mode 100644
index 00000000..19ee4023
--- /dev/null
+++ b/translate_dynamic.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+from onmt.bin.translate_dynamic import main
+
+
+if __name__ == "__main__":
+    main()