author     anderleich <andercorral95@gmail.com>     2022-02-09 19:15:17 +0300
committer  GitHub <noreply@github.com>               2022-02-09 19:15:17 +0300
commit     fc3409ce08785915728dac9e9c9fb49f73027917 (patch)
tree       68d4ff764cc330073623b77865c236b37aa01c19
parent     908881444a814f7ab05c77c14316a2f05d135960 (diff)
Dynamic data loading for inference (#2145)
-rw-r--r--  onmt/bin/translate_dynamic.py   |  61
-rw-r--r--  onmt/inputters/__init__.py      |   4
-rw-r--r--  onmt/inputters/dataset_base.py  |  54
-rw-r--r--  onmt/inputters/text_dataset.py  | 129
-rw-r--r--  onmt/opts.py                    |  10
-rw-r--r--  onmt/transforms/tokenize.py     |  10
-rw-r--r--  onmt/transforms/transform.py    |   8
-rw-r--r--  onmt/translate/translator.py    |  66
-rw-r--r--  onmt/utils/parse.py             |  10
-rw-r--r--  setup.py                        |   1
-rw-r--r--  translate_dynamic.py            |   6
11 files changed, 353 insertions(+), 6 deletions(-)
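
This patch registers a new console script, onmt_translate_dynamic (see setup.py below), backed by onmt/bin/translate_dynamic.py. For orientation, here is a minimal sketch of driving that entry point programmatically; the checkpoint and file paths are placeholders, onmt_tokenize is just one registered transform choice, and any options the chosen transform itself needs (e.g. tokenizer models) must be supplied as well:

    # Hedged sketch: option names mirror translate_opts()/translate() in the
    # diff below; "model.pt", "input.txt" and "pred.txt" are hypothetical.
    from onmt.bin.translate_dynamic import _get_parser, translate

    parser = _get_parser()
    opt = parser.parse_args([
        "-model", "model.pt",
        "-src", "input.txt",
        "-output", "pred.txt",
        "-transforms", "onmt_tokenize",  # new flag added by this patch
    ])
    translate(opt)  # shards the input, translates, reverses the transforms
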
diff --git a/onmt/bin/translate_dynamic.py b/onmt/bin/translate_dynamic.py
new file mode 100644
index 00000000..2ff217f0
--- /dev/null
+++ b/onmt/bin/translate_dynamic.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from onmt.utils.logging import init_logger
+from onmt.translate.translator import build_translator
+from onmt.inputters.text_dataset import InferenceDataReader
+from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
+
+import onmt.opts as opts
+from onmt.utils.parse import ArgumentParser
+
+
+def translate(opt):
+    ArgumentParser.validate_translate_opts(opt)
+    ArgumentParser._get_all_transform_translate(opt)
+    ArgumentParser._validate_transforms_opts(opt)
+    ArgumentParser.validate_translate_opts_dynamic(opt)
+    logger = init_logger(opt.log_file)
+
+    translator = build_translator(opt, logger=logger, report_score=True)
+
+    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)
+
+    # Build transforms
+    transforms_cls = get_transforms_cls(opt._all_transform)
+    transforms = make_transforms(opt, transforms_cls, translator.fields)
+    data_transform = [
+        transforms[name] for name in opt.transforms if name in transforms
+    ]
+    transform = TransformPipe.build_from(data_transform)
+
+    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
+        logger.info("Translating shard %d." % i)
+        translator.translate_dynamic(
+            src=src_shard,
+            transform=transform,
+            src_feats=feats_shard,
+            tgt=tgt_shard,
+            batch_size=opt.batch_size,
+            batch_type=opt.batch_type,
+            attn_debug=opt.attn_debug,
+            align_debug=opt.align_debug
+            )
+
+
+def _get_parser():
+    parser = ArgumentParser(description='translate_dynamic.py')
+
+    opts.config_opts(parser)
+    opts.translate_opts(parser, dynamic=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+
+    opt = parser.parse_args()
+    translate(opt)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py
index ed9aba89..7403426a 100644
--- a/onmt/inputters/__init__.py
+++ b/onmt/inputters/__init__.py
@@ -5,7 +5,7 @@ e.g., from a line of text to a sequence of embeddings.
 """
 from onmt.inputters.inputter import get_fields, build_vocab, filter_example
 from onmt.inputters.iterator import max_tok_len, OrderedIterator
-from onmt.inputters.dataset_base import Dataset
+from onmt.inputters.dataset_base import Dataset, DynamicDataset
 from onmt.inputters.text_dataset import text_sort_key, TextDataReader
 from onmt.inputters.datareader_base import DataReaderBase
@@ -17,4 +17,4 @@ str2sortkey = {
 
 __all__ = ['Dataset', 'get_fields', 'DataReaderBase', 'filter_example',
            'build_vocab', 'OrderedIterator', 'max_tok_len',
-           'text_sort_key', 'TextDataReader']
+           'text_sort_key', 'TextDataReader', 'DynamicDataset']
diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py
index 574ee7ad..dda21389 100644
--- a/onmt/inputters/dataset_base.py
+++ b/onmt/inputters/dataset_base.py
@@ -165,3 +165,57 @@ class Dataset(TorchtextDataset):
             readers.append(field["reader"])
             data.append((name, field["data"], field.get("features", {})))
         return readers, data
+
+
+class DynamicDataset(Dataset):
+
+    def __init__(self, fields, data, sort_key, filter_pred=None):
+        self.sort_key = sort_key
+        can_copy = 'src_map' in fields and 'alignment' in fields
+
+        # self.src_vocabs is used in collapse_copy_scores and Translator.py
+        self.src_vocabs = []
+        examples = []
+        for ex_dict in data:
+            if can_copy:
+                src_field = fields['src']
+                tgt_field = fields['tgt']
+                # this assumes src_field and tgt_field are both text
+                ex_dict = _dynamic_dict(
+                    ex_dict, src_field.base_field, tgt_field.base_field)
+                self.src_vocabs.append(ex_dict["src_ex_vocab"])
+            ex_fields = {k: [(k, v)] for k, v in fields.items() if
+                         k in ex_dict}
+            ex = Example.fromdict(ex_dict, ex_fields)
+            examples.append(ex)
+
+        # fields needs to have only keys that examples have as attrs
+        fields = []
+        for _, nf_list in ex_fields.items():
+            assert len(nf_list) == 1
+            fields.append(nf_list[0])
+
+        super(Dataset, self).__init__(examples, fields, filter_pred)
+
+    def __getattr__(self, attr):
+        # avoid infinite recursion when fields isn't defined
+        if 'fields' not in vars(self):
+            raise AttributeError
+        if attr in self.fields:
+            return (getattr(x, attr) for x in self.examples)
+        else:
+            raise AttributeError
+
+    def save(self, path, remove_fields=True):
+        if remove_fields:
+            self.fields = []
+        torch.save(self, path)
+
+    @staticmethod
+    def config(fields):
+        readers, data = [], []
+        for name, field in fields:
+            if field["data"] is not None:
+                readers.append(field["reader"])
+                data.append((name, field["data"], field.get("features", {})))
+        return readers, data
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index ddb51d5c..da6a1584 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -1,18 +1,19 @@
 # -*- coding: utf-8 -*-
 from functools import partial
+from itertools import repeat
 
 import torch
 from torchtext.data import Field, RawField
 
 from onmt.constants import DefaultTokens
 from onmt.inputters.datareader_base import DataReaderBase
+from onmt.utils.misc import split_corpus
 
 
 class TextDataReader(DataReaderBase):
     def read(self, sequences, side, features={}):
         """Read text data from disk.
-
-        Args:
+        Args:
             sequences (str or Iterable[str]):
                 path to text file or iterable of the actual text data.
             side (str): Prefix used in return dict. Usually
@@ -20,7 +21,6 @@ class TextDataReader(DataReaderBase):
             features: (Dict[str or Iterable[str]]):
                 dictionary mapping feature names with the path to feature
                 file or iterable of the actual feature data.
-
         Yields:
             dictionaries whose keys are the names of fields and whose
             values are more or less the result of tokenizing with those
@@ -49,6 +49,129 @@ class TextDataReader(DataReaderBase):
             yield {side: ex_dict, "indices": i}
 
 
+class InferenceDataReader(object):
+    """It handles inference data reading from disk in shards.
+
+    Args:
+        src (str): path to the source file
+        tgt (str or NoneType): path to the target file
+        src_feats (Dict[str]): paths to the features files
+        shard_size (int): divides files into smaller files of size shard_size
+
+    Returns:
+        Tuple[List[str], List[str], Dict[List[str]]]
+    """
+
+    def __init__(self, src, tgt, src_feats={}, shard_size=10000):
+        self.src = src
+        self.tgt = tgt
+        self.src_feats = src_feats
+        self.shard_size = shard_size
+
+    def __iter__(self):
+        src_shards = split_corpus(self.src, self.shard_size)
+        tgt_shards = split_corpus(self.tgt, self.shard_size)
+
+        if not self.src_feats:
+            features_shards = [repeat(None)]
+        else:
+            features_shards = []
+            features_names = []
+            for feat_name, feat_path in self.src_feats.items():
+                features_shards.append(
+                    split_corpus(feat_path, self.shard_size))
+                features_names.append(feat_name)
+
+        shard_pairs = zip(src_shards, tgt_shards, *features_shards)
+        for i, shard in enumerate(shard_pairs):
+            src_shard, tgt_shard, *features_shard = shard
+            if features_shard[0] is not None:
+                features_shard_ = dict()
+                for j, x in enumerate(features_shard):
+                    features_shard_[features_names[j]] = x
+            else:
+                features_shard_ = None
+            yield src_shard, tgt_shard, features_shard_
+
+
+class InferenceDataIterator(object):
+
+    def __init__(self, src, tgt, src_feats, transform):
+        self.src = src
+        self.tgt = tgt
+        self.src_feats = src_feats
+        self.transform = transform
+
+    def _tokenize(self, example):
+        example['src'] = example['src'].decode("utf-8").strip('\n').split()
+        example['tgt'] = example['tgt'].decode("utf-8").strip('\n').split() \
+            if example["tgt"] is not None else None
+        example['src_original'] = example['src']
+        example['tgt_original'] = example['tgt']
+        if 'src_feats' in example:
+            for k in example['src_feats'].keys():
+                example['src_feats'][k] = example['src_feats'][k] \
+                    .decode("utf-8").strip('\n').split() \
+                    if example['src_feats'][k] is not None else None
+        return example
+
+    def _transform(self, example, remove_tgt=False):
+        maybe_example = self.transform.apply(
+            example, is_train=False, corpus_name="translate")
+        assert maybe_example is not None, \
+            "Transformation on example skipped the example. " \
+            "Please check the transforms."
+        return maybe_example
+
+    def _process(self, example, remove_tgt=False):
+        example['src'] = {"src": ' '.join(example['src'])}
+        example['tgt'] = {"tgt": ' '.join(example['tgt'])}
+
+        # Make features part of src as in TextMultiField
+        # {'src': {'src': ..., 'feat1': ..., 'feat2': ...}}
+        if 'src_feats' in example:
+            for feat_name, feat_value in example['src_feats'].items():
+                example['src'][feat_name] = ' '.join(feat_value)
+            del example["src_feats"]
+
+        # Cleanup
+        if remove_tgt:
+            del example["tgt"]
+            del example["tgt_original"]
+            del example["src_original"]
+
+        return example
+
+    def __iter__(self):
+        tgt = self.tgt if self.tgt is not None else repeat(None)
+
+        if self.src_feats is not None:
+            features_names = []
+            features_values = []
+            for feat_name, values in self.src_feats.items():
+                features_names.append(feat_name)
+                features_values.append(values)
+        else:
+            features_values = [repeat(None)]
+
+        for i, (src, tgt, *src_feats) in enumerate(zip(
+                self.src, tgt, *features_values)):
+            ex = {
+                "src": src,
+                "tgt": tgt if tgt is not None else b""
+            }
+            if src_feats[0] is not None:
+                src_feats_ = {}
+                for j, x in enumerate(src_feats):
+                    src_feats_[features_names[j]] = x
+                ex["src_feats"] = src_feats_
+            ex = self._tokenize(ex)
+            ex = self._transform(ex)
+            ex = self._process(ex, remove_tgt=self.tgt is None)
+            ex["indices"] = i
+            yield ex
+
+
 def text_sort_key(ex):
     """Sort using the number of tokens in the sequence."""
     if hasattr(ex, "tgt"):
diff --git a/onmt/opts.py b/onmt/opts.py
index ef12a9d7..28504d21 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -734,7 +734,7 @@ def _add_decoding_opts(parser):
               "the table), then it will copy the source token.")
 
 
-def translate_opts(parser):
+def translate_opts(parser, dynamic=False):
     """ Translation / inference options """
     group = parser.add_argument_group('Model')
     group.add('--model', '-model', dest='models', metavar='MODEL',
@@ -801,6 +801,14 @@ def translate_opts(parser):
     group.add('--gpu', '-gpu', type=int, default=-1,
               help="Device to run on")
 
+    if dynamic:
+        group.add("-transforms", "--transforms", default=[], nargs="+",
+                  choices=AVAILABLE_TRANSFORMS.keys(),
+                  help="Default transform pipeline to apply to data.")
+
+        # Adding options related to Transforms
+        _add_dynamic_transform_opts(parser)
+
 
 # Copyright 2016 The Chromium Authors. All rights reserved.
 # Use of this source code is governed by a BSD-style license that can be
diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py
index d4b92cb0..cf519b05 100644
--- a/onmt/transforms/tokenize.py
+++ b/onmt/transforms/tokenize.py
@@ -410,6 +410,12 @@ class ONMTTokenizerTransform(TokenizerTransform):
         segmented, _ = tokenizer.tokenize(sentence)
         return segmented
 
+    def _detokenize(self, tokens, side='src', is_train=False):
+        """Do OpenNMT Tokenizer's detokenize."""
+        tokenizer = self.load_models[side]
+        detokenized = tokenizer.detokenize(tokens)
+        return detokenized
+
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply OpenNMT Tokenizer to src & tgt."""
         src_out = self._tokenize(example['src'], 'src')
@@ -421,6 +427,10 @@ class ONMTTokenizerTransform(TokenizerTransform):
         example['src'], example['tgt'] = src_out, tgt_out
         return example
 
+    def apply_reverse(self, translated):
+        """Detokenize a translated string with the OpenNMT Tokenizer."""
+        return self._detokenize(translated.split(), 'tgt')
+
     def _repr_args(self):
         """Return str represent key arguments for class."""
         repr_str = '{}={}'.format('share_vocab', self.share_vocab)
diff --git a/onmt/transforms/transform.py b/onmt/transforms/transform.py
index 6dd5869d..88b6f40c 100644
--- a/onmt/transforms/transform.py
+++ b/onmt/transforms/transform.py
@@ -60,6 +60,9 @@ class Transform(object):
         """
         raise NotImplementedError
 
+    def apply_reverse(self, translated):
+        return translated
+
     def __getstate__(self):
         """Pickling following for rebuild."""
         state = {"opts": self.opts}
@@ -193,6 +196,11 @@ class TransformPipe(Transform):
                 break
         return example
 
+    def apply_reverse(self, translated):
+        for transform in self.transforms:
+            translated = transform.apply_reverse(translated)
+        return translated
+
     def __getstate__(self):
         """Pickling following for rebuild."""
         return (self.opts, self.transforms, self.statistics)
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 81a98926..69892cec 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -12,6 +12,7 @@ from onmt.constants import DefaultTokens
 import onmt.model_builder
 import onmt.inputters as inputters
 import onmt.decoders.ensemble
+from onmt.inputters.text_dataset import InferenceDataIterator
 from onmt.translate.beam_search import BeamSearch, BeamSearchLM
 from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
 from onmt.utils.misc import tile, set_random_seed, report_matrix
@@ -330,6 +331,45 @@ class Inference(object):
             gs = [0] * batch_size
         return gs
 
+    def translate_dynamic(
+        self,
+        src,
+        transform,
+        src_feats={},
+        tgt=None,
+        batch_size=None,
+        batch_type="sents",
+        attn_debug=False,
+        align_debug=False,
+        phrase_table=""
+    ):
+
+        if batch_size is None:
+            raise ValueError("batch_size must be set")
+
+        if self.tgt_prefix and tgt is None:
+            raise ValueError("Prefix should be fed to tgt if -tgt_prefix.")
+
+        data_iter = InferenceDataIterator(src, tgt, src_feats, transform)
+
+        data = inputters.DynamicDataset(
+            self.fields,
+            data=data_iter,
+            sort_key=inputters.str2sortkey[self.data_type],
+            filter_pred=self._filter_pred,
+        )
+
+        return self._translate(
+            data,
+            tgt=tgt,
+            batch_size=batch_size,
+            batch_type=batch_type,
+            attn_debug=attn_debug,
+            align_debug=align_debug,
+            phrase_table=phrase_table,
+            dynamic=True,
+            transform=transform)
+
     def translate(
         self,
         src,
@@ -387,6 +427,28 @@ class Inference(object):
             filter_pred=self._filter_pred,
         )
 
+        return self._translate(
+            data,
+            tgt=tgt,
+            batch_size=batch_size,
+            batch_type=batch_type,
+            attn_debug=attn_debug,
+            align_debug=align_debug,
+            phrase_table=phrase_table)
+
+    def _translate(
+        self,
+        data,
+        tgt=None,
+        batch_size=None,
+        batch_type="sents",
+        attn_debug=False,
+        align_debug=False,
+        phrase_table="",
+        dynamic=False,
+        transform=None
+    ):
+
         data_iter = inputters.OrderedIterator(
             dataset=data,
             device=self._dev,
@@ -448,6 +510,10 @@ class Inference(object):
                         n_best_preds, n_best_preds_align
                     )
                 ]
+
+            if dynamic:
+                n_best_preds = [transform.apply_reverse(x)
+                                for x in n_best_preds]
             all_predictions += [n_best_preds]
             self.out_file.write("\n".join(n_best_preds) + "\n")
             self.out_file.flush()
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 35f398af..edc61f44 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -123,6 +123,10 @@ class DataOptsCheckerMixin(object):
         opt._all_transform = all_transforms
 
     @classmethod
+    def _get_all_transform_translate(cls, opt):
+        opt._all_transform = opt.transforms
+
+    @classmethod
     def _validate_fields_opts(cls, opt, build_vocab_only=False):
         """Check options relate to vocab and fields."""
 
@@ -327,3 +331,9 @@ class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin):
     @classmethod
     def validate_translate_opts(cls, opt):
         opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}
+
+    @classmethod
+    def validate_translate_opts_dynamic(cls, opt):
+        # It comes from training
+        # TODO: needs to be added as inference opt
+        opt.share_vocab = False
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ setup(
             "onmt_server=onmt.bin.server:main",
             "onmt_train=onmt.bin.train:main",
             "onmt_translate=onmt.bin.translate:main",
+            "onmt_translate_dynamic=onmt.bin.translate_dynamic:main",
             "onmt_release_model=onmt.bin.release_model:main",
            "onmt_average_models=onmt.bin.average_models:main",
             "onmt_build_vocab=onmt.bin.build_vocab:main"
diff --git a/translate_dynamic.py b/translate_dynamic.py
new file mode 100644
index 00000000..19ee4023
--- /dev/null
+++ b/translate_dynamic.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+from onmt.bin.translate_dynamic import main
+
+
+if __name__ == "__main__":
+    main()
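
For reference, the reading side added above can be exercised on its own: InferenceDataReader wraps split_corpus and yields (src_shard, tgt_shard, feats_shard) tuples of raw byte lines, which translate_dynamic() pushes through the TransformPipe; after decoding, TransformPipe.apply_reverse runs each member transform's apply_reverse (detokenization, in the case of ONMTTokenizerTransform) over every prediction. A minimal sketch, with a hypothetical input file and no target or features:

    # Assumes "input.txt" exists; with tgt=None and src_feats={}, the reader
    # yields tgt_shard=None and feats_shard=None for every shard.
    from onmt.inputters.text_dataset import InferenceDataReader

    reader = InferenceDataReader(src="input.txt", tgt=None, src_feats={},
                                 shard_size=10000)
    for i, (src_shard, tgt_shard, feats_shard) in enumerate(reader):
        # each src_shard is a list of at most shard_size raw byte lines
        print("shard %d: %d source lines" % (i, len(src_shard)))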