github.com/OpenNMT/OpenNMT-py.git
author     anderleich <andercorral95@gmail.com>   2021-09-09 13:31:03 +0300
committer  GitHub <noreply@github.com>            2021-09-09 13:31:03 +0300
commit     7e8c7c2a6ac1d28265f27c861fc63ef72c1f0fae (patch)
tree       868ee1ef4b129b24a83c95bd7ef3f28495244054
parent     4cd9978564c77e3c3768cd58119ad81d5dfd8b73 (diff)
Source features support for V2.0 (#2090)
-rw-r--r--  .github/workflows/push.yml           30
-rw-r--r--  data/data_features/src-test.feat0     1
-rw-r--r--  data/data_features/src-test.txt       1
-rw-r--r--  data/data_features/src-train.feat0    3
-rw-r--r--  data/data_features/src-train.txt      3
-rw-r--r--  data/data_features/src-val.feat0      1
-rw-r--r--  data/data_features/src-val.txt        1
-rw-r--r--  data/data_features/tgt-train.txt      3
-rw-r--r--  data/data_features/tgt-val.txt        1
-rw-r--r--  data/features_data.yaml              11
-rw-r--r--  docs/source/FAQ.md                   70
-rw-r--r--  onmt/bin/build_vocab.py               7
-rwxr-xr-x  onmt/bin/translate.py                16
-rw-r--r--  onmt/constants.py                     1
-rw-r--r--  onmt/inputters/corpus.py             65
-rw-r--r--  onmt/inputters/dataset_base.py        8
-rw-r--r--  onmt/inputters/fields.py             13
-rw-r--r--  onmt/inputters/inputter.py           12
-rw-r--r--  onmt/inputters/text_dataset.py       76
-rw-r--r--  onmt/opts.py                          8
-rwxr-xr-x  onmt/tests/pull_request_chk.sh       46
-rw-r--r--  onmt/tests/test_subword_marker.py    33
-rw-r--r--  onmt/tests/test_text_dataset.py      26
-rw-r--r--  onmt/tests/test_transform.py         22
-rw-r--r--  onmt/transforms/features.py          90
-rw-r--r--  onmt/translate/translator.py          8
-rw-r--r--  onmt/utils/alignment.py              42
-rw-r--r--  onmt/utils/parse.py                  27
28 files changed, 549 insertions, 76 deletions
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 9780df63..66d892ef 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -42,6 +42,16 @@ jobs:
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
&& rm -rf /tmp/sample
+ - name: Test vocabulary build with features
+ run: |
+ python onmt/bin/build_vocab.py \
+ -config data/features_data.yaml \
+ -save_data /tmp/onmt_feat \
+ -src_vocab /tmp/onmt_feat.vocab.src \
+ -tgt_vocab /tmp/onmt_feat.vocab.tgt \
+ -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
+ -n_sample -1 \
+ && rm -rf /tmp/sample
- name: Test field/transform dump
run: |
# The dumped fields are used later when testing tools
@@ -169,6 +179,26 @@ jobs:
-state_dim 256 \
-n_steps 10 \
-n_node 64
+ - name: Testing training with features
+ run: |
+ python onmt/bin/train.py \
+ -config data/features_data.yaml \
+ -src_vocab /tmp/onmt_feat.vocab.src \
+ -tgt_vocab /tmp/onmt_feat.vocab.tgt \
+ -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
+ -src_vocab_size 1000 -tgt_vocab_size 1000 \
+ -rnn_size 2 -batch_size 10 \
+ -word_vec_size 5 -rnn_size 10 \
+ -report_every 5 -train_steps 10 \
+ -save_model /tmp/onmt.model \
+ -save_checkpoint_steps 10
+ - name: Testing translation with features
+ run: |
+ python translate.py \
+ -model /tmp/onmt.model_step_10.pt \
+ -src data/data_features/src-test.txt \
+ -src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \
+ -verbose
- name: Test RNN translation
run: |
head data/src-test.txt > /tmp/src-test.txt
diff --git a/data/data_features/src-test.feat0 b/data/data_features/src-test.feat0
new file mode 100644
index 00000000..4ab4a9e6
--- /dev/null
+++ b/data/data_features/src-test.feat0
@@ -0,0 +1 @@
+C B A B
\ No newline at end of file
diff --git a/data/data_features/src-test.txt b/data/data_features/src-test.txt
new file mode 100644
index 00000000..0cc723ce
--- /dev/null
+++ b/data/data_features/src-test.txt
@@ -0,0 +1 @@
+she is a hard-working.
\ No newline at end of file
diff --git a/data/data_features/src-train.feat0 b/data/data_features/src-train.feat0
new file mode 100644
index 00000000..7e189f2c
--- /dev/null
+++ b/data/data_features/src-train.feat0
@@ -0,0 +1,3 @@
+A A A A B A A A C
+A B C D E
+C B A B
\ No newline at end of file
diff --git a/data/data_features/src-train.txt b/data/data_features/src-train.txt
new file mode 100644
index 00000000..8a3ec35c
--- /dev/null
+++ b/data/data_features/src-train.txt
@@ -0,0 +1,3 @@
+however, according to the logs, she is a hard-working.
+however, according to the logs,
+she is a hard-working.
\ No newline at end of file
diff --git a/data/data_features/src-val.feat0 b/data/data_features/src-val.feat0
new file mode 100644
index 00000000..4ab4a9e6
--- /dev/null
+++ b/data/data_features/src-val.feat0
@@ -0,0 +1 @@
+C B A B
\ No newline at end of file
diff --git a/data/data_features/src-val.txt b/data/data_features/src-val.txt
new file mode 100644
index 00000000..0cc723ce
--- /dev/null
+++ b/data/data_features/src-val.txt
@@ -0,0 +1 @@
+she is a hard-working.
\ No newline at end of file
diff --git a/data/data_features/tgt-train.txt b/data/data_features/tgt-train.txt
new file mode 100644
index 00000000..8a3ec35c
--- /dev/null
+++ b/data/data_features/tgt-train.txt
@@ -0,0 +1,3 @@
+however, according to the logs, she is a hard-working.
+however, according to the logs,
+she is a hard-working.
\ No newline at end of file
diff --git a/data/data_features/tgt-val.txt b/data/data_features/tgt-val.txt
new file mode 100644
index 00000000..0cc723ce
--- /dev/null
+++ b/data/data_features/tgt-val.txt
@@ -0,0 +1 @@
+she is a hard-working.
\ No newline at end of file
diff --git a/data/features_data.yaml b/data/features_data.yaml
new file mode 100644
index 00000000..fa9b665f
--- /dev/null
+++ b/data/features_data.yaml
@@ -0,0 +1,11 @@
+# Corpus opts:
+data:
+ corpus_1:
+ path_src: data/data_features/src-train.txt
+ path_tgt: data/data_features/tgt-train.txt
+ src_feats:
+ feat0: data/data_features/src-train.feat0
+ transforms: [filterfeats, inferfeats]
+ valid:
+ path_src: data/data_features/src-val.txt
+ path_tgt: data/data_features/tgt-val.txt
diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index f40fada2..8f618f6c 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -477,3 +477,73 @@ Training options to perform vocabulary update are:
* `-update_vocab`: set this option
* `-reset_optim`: set the value to "states"
* `-train_from`: checkpoint path
+
+
+## How can I use source word features?
+
+Extra information can be added to the words in the source sentences by defining word features.
+
+Features should be defined in a separate file, using blank spaces as separators, with each row corresponding to a source sentence. Example input files:
+
+data.src
+```
+however, according to the logs, she is hard-working.
+```
+
+feat0.txt
+```
+A C C C C A A B
+```
+
+**Notes**
+- Prior tokenization is not necessary; features for subword units are inferred by the `FeatInferTransform` transform.
+- Both `FilterFeatsTransform` and `FeatInferTransform` are required for source features to work.
+- Shared embeddings are not possible (at least with the `feat_merge: concat` method).
+
+Sample config file:
+
+```
+data:
+ dummy:
+ path_src: data/train/data.src
+ path_tgt: data/train/data.tgt
+ src_feats:
+ feat_0: data/train/data.src.feat_0
+ feat_1: data/train/data.src.feat_1
+ transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong]
+ weight: 1
+ valid:
+ path_src: data/valid/data.src
+ path_tgt: data/valid/data.tgt
+ src_feats:
+ feat_0: data/valid/data.src.feat_0
+ feat_1: data/valid/data.src.feat_1
+ transforms: [filterfeats, onmt_tokenize, inferfeats]
+
+# Vocab opts
+src_vocab: exp/data.vocab.src
+tgt_vocab: exp/data.vocab.tgt
+src_feats_vocab:
+ feat_0: exp/data.vocab.feat_0
+ feat_1: exp/data.vocab.feat_1
+feat_merge: "sum"
+
+```
+
+During inference you can pass features using the `--src_feats` argument. `src_feats` is expected to be a Python-like dict mapping each feature name to its data file.
+
+```
+{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}
+```
+
+**Important note!** During inference, input sentences are expected to be already tokenized. Therefore, feature inference must be handled prior to running the translate command. Example:
+
+```bash
+python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
+```
+
+When using the Transformer architecture, make sure the following options are set appropriately:
+
+- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
+- `feat_merge`: how to merge feature embeddings with token embeddings (`sum`, `concat`, or `mlp`)
+- `feat_vec_size` and maybe `feat_vec_exponent`
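
For illustration only (not part of this commit), a minimal sketch of how those options might look alongside the sample config above; the values are placeholders, and `feat_vec_exponent` is shown commented out as an alternative to a fixed `feat_vec_size`:

```
# Hypothetical embedding settings for a Transformer with source features
word_vec_size: 512        # token embedding size on both source and target side
feat_merge: sum           # how feature embeddings are combined with token embeddings
feat_vec_size: 512        # per-feature embedding size (mainly relevant for "concat"/"mlp")
# feat_vec_exponent: 0.7  # alternative: derive the feature embedding size from the vocab size
```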
diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py
index e106d921..ed510f09 100644
--- a/onmt/bin/build_vocab.py
+++ b/onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, fields)
logger.info(f"Counter vocab from {opts.n_sample} samples.")
- src_counter, tgt_counter = build_vocab(
+ src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)
logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
+ for feat_name, feat_counter in src_feats_counter.items():
+ logger.info(f"Counters {feat_name}:{len(feat_counter)}")
def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def build_vocab_main(opts):
else:
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)
+
+ for k, v in src_feats_counter.items():
+ save_counter(v, opts.src_feats_vocab[k])
def _get_parser():
diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py
index 0b5434f8..4e3e126a 100755
--- a/onmt/bin/translate.py
+++ b/onmt/bin/translate.py
@@ -6,6 +6,7 @@ from onmt.translate.translator import build_translator
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
+from collections import defaultdict
def translate(opt):
@@ -15,12 +16,21 @@ def translate(opt):
translator = build_translator(opt, logger=logger, report_score=True)
src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = split_corpus(opt.tgt, opt.shard_size)
- shard_pairs = zip(src_shards, tgt_shards)
-
- for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
+ features_shards = []
+ features_names = []
+ for feat_name, feat_path in opt.src_feats.items():
+ features_shards.append(split_corpus(feat_path, opt.shard_size))
+ features_names.append(feat_name)
+ shard_pairs = zip(src_shards, tgt_shards, *features_shards)
+
+ for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
+ features_shard_ = defaultdict(list)
+ for j, x in enumerate(features_shard):
+ features_shard_[features_names[j]] = x
logger.info("Translating shard %d." % i)
translator.translate(
src=src_shard,
+ src_feats=features_shard_,
tgt=tgt_shard,
batch_size=opt.batch_size,
batch_type=opt.batch_type,
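
Stepping outside the patch for a moment, the shard handling added above can be summarized with a small self-contained sketch; plain lists stand in for the `split_corpus` generators, and all names and values are illustrative:

```python
from collections import defaultdict

# Stand-ins for split_corpus(...) output: one list of lines per shard
src_shards = [["sentence one", "sentence two"], ["sentence three"]]
tgt_shards = [[None, None], [None]]
src_feats = {"feat0": [["A B C", "D E F"], ["G H I"]]}  # feature name -> list of shards

features_names = list(src_feats.keys())
features_shards = [src_feats[name] for name in features_names]

# Each iteration yields one src/tgt shard plus the matching shard of every feature
for i, (src_shard, tgt_shard, *features_shard) in enumerate(
        zip(src_shards, tgt_shards, *features_shards)):
    feats = defaultdict(list)
    for j, shard in enumerate(features_shard):
        feats[features_names[j]] = shard
    print(i, src_shard, dict(feats))
```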
diff --git a/onmt/constants.py b/onmt/constants.py
index fb6afb02..2d586413 100644
--- a/onmt/constants.py
+++ b/onmt/constants.py
@@ -22,6 +22,7 @@ class CorpusName(object):
class SubwordMarker(object):
SPACER = '▁'
JOINER = '■'
+ CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"]
class ModelTask(object):
diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py
index c8a559f9..87da6513 100644
--- a/onmt/inputters/corpus.py
+++ b/onmt/inputters/corpus.py
@@ -7,10 +7,11 @@ from onmt.inputters.dataset_base import _dynamic_dict
from torchtext.data import Dataset as TorchtextDataset, \
Example as TorchtextExample
-from collections import Counter
+from collections import Counter, defaultdict
from contextlib import contextmanager
import multiprocessing as mp
+from collections import defaultdict
@contextmanager
@@ -70,10 +71,20 @@ class DatasetAdapter(object):
example, is_train=is_train, corpus_name=cid)
if maybe_example is None:
return None
- maybe_example['src'] = ' '.join(maybe_example['src'])
- maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
+
+ maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}
+
+ # Make features part of src as in TextMultiField
+ # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
+ if 'src_feats' in maybe_example:
+ for feat_name, feat_value in maybe_example['src_feats'].items():
+ maybe_example['src'][feat_name] = ' '.join(feat_value)
+ del maybe_example["src_feats"]
+
+ maybe_example['tgt'] = {"tgt": ' '.join(maybe_example['tgt'])}
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
+
return maybe_example
def _maybe_add_dynamic_dict(self, example, fields):
@@ -107,12 +118,13 @@ class DatasetAdapter(object):
class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""
- def __init__(self, name, src, tgt, align=None):
+ def __init__(self, name, src, tgt, align=None, src_feats=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
+ self.src_feats = src_feats
def load(self, offset=0, stride=1):
"""
@@ -120,10 +132,18 @@ class ParallelCorpus(object):
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
+ if self.src_feats:
+ features_names = []
+ features_files = []
+ for feat_name, feat_path in self.src_feats.items():
+ features_names.append(feat_name)
+ features_files.append(open(feat_path, mode='rb'))
+ else:
+ features_files = []
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
- for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
+ for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
tline = tline.decode('utf-8')
@@ -133,12 +153,18 @@ class ParallelCorpus(object):
}
if align is not None:
example['align'] = align.decode('utf-8')
+ if features:
+ example["src_feats"] = dict()
+ for j, feat in enumerate(features):
+ example["src_feats"][features_names[j]] = feat.decode("utf-8")
yield example
+ for f in features_files:
+ f.close()
def __str__(self):
cls_name = type(self).__name__
- return '{}({}, {}, align={})'.format(
- cls_name, self.src, self.tgt, self.align)
+ return '{}({}, {}, align={}, src_feats={})'.format(
+ cls_name, self.src, self.tgt, self.align, self.src_feats)
def get_corpora(opts, is_train=False):
@@ -150,14 +176,16 @@ def get_corpora(opts, is_train=False):
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
- corpus_dict["path_align"])
+ corpus_dict["path_align"],
+ corpus_dict["src_feats"])
else:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
CorpusName.VALID,
opts.data[CorpusName.VALID]["path_src"],
opts.data[CorpusName.VALID]["path_tgt"],
- opts.data[CorpusName.VALID]["path_align"])
+ opts.data[CorpusName.VALID]["path_align"],
+ opts.data[CorpusName.VALID]["src_feats"])
else:
return None
return corpora_dict
@@ -193,6 +221,9 @@ class ParallelCorpusIterator(object):
example['src'], example['tgt'] = src, tgt
if 'align' in example:
example['align'] = example['align'].strip('\n').split()
+ if 'src_feats' in example:
+ for k in example['src_feats'].keys():
+ example['src_feats'][k] = example['src_feats'][k].strip('\n').split()
yield example
def _transform(self, stream):
@@ -286,6 +317,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
+ sub_counter_src_feats = defaultdict(Counter)
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -297,7 +329,10 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("blank")
continue
- src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+ src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt']
+ for feat_name, feat_line in maybe_example["src"].items():
+ if feat_name != "src":
+ sub_counter_src_feats[feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
@@ -309,7 +344,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
- return sub_counter_src, sub_counter_tgt
+ return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
def init_pool(queues):
@@ -333,6 +368,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
+ counter_src_feats = defaultdict(Counter)
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -349,13 +385,14 @@ def build_vocab(opts, transforms, n_sample=3):
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
- for sub_counter_src, sub_counter_tgt in p.imap(
+ for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
+ counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
- return counter_src, counter_tgt
+ return counter_src, counter_tgt, counter_src_feats
def save_transformed_sample(opts, transforms, n_sample=3):
@@ -387,7 +424,7 @@ def save_transformed_sample(opts, transforms, n_sample=3):
maybe_example = DatasetAdapter._process(item, is_train=True)
if maybe_example is None:
continue
- src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+ src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt']
f_src.write(src_line + '\n')
f_tgt.write(tgt_line + '\n')
if n_sample > 0 and i >= n_sample:
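
As a reading aid (not part of the patch), this is roughly the example layout that `DatasetAdapter._process` now produces, with features folded into `src` the way `TextMultiField` expects them; the tokens and feature tags below are made up:

```python
# Before this change: {'src': 'she is hard ■-■ working ■.', 'tgt': '...', ...}
# After this change, with a single feature named "feat0":
maybe_example = {
    "src": {
        "src": "she is hard ■-■ working ■.",
        "feat0": "C B A <null> A <null>",
    },
    "tgt": {"tgt": "she is hard ■-■ working ■."},
}
```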
diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py
index aeec428a..65322d9a 100644
--- a/onmt/inputters/dataset_base.py
+++ b/onmt/inputters/dataset_base.py
@@ -41,7 +41,7 @@ def _dynamic_dict(example, src_field, tgt_field):
``example``, changed as described.
"""
- src = src_field.tokenize(example["src"])
+ src = src_field.tokenize(example["src"]["src"])
# make a small vocab containing just the tokens in the source sequence
unk = src_field.unk_token
pad = src_field.pad_token
@@ -60,7 +60,7 @@ def _dynamic_dict(example, src_field, tgt_field):
example["src_ex_vocab"] = src_ex_vocab
if "tgt" in example:
- tgt = tgt_field.tokenize(example["tgt"])
+ tgt = tgt_field.tokenize(example["tgt"]["tgt"])
mask = torch.LongTensor(
[unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
example["alignment"] = mask
@@ -116,7 +116,7 @@ class Dataset(TorchtextDataset):
self.sort_key = sort_key
can_copy = 'src_map' in fields and 'alignment' in fields
- read_iters = [r.read(dat[1], dat[0]) for r, dat in zip(readers, data)]
+ read_iters = [r.read(dat, name, feats) for r, (name, dat, feats) in zip(readers, data)]
# self.src_vocabs is used in collapse_copy_scores and Translator.py
self.src_vocabs = []
@@ -162,5 +162,5 @@ class Dataset(TorchtextDataset):
for name, field in fields:
if field["data"] is not None:
readers.append(field["reader"])
- data.append((name, field["data"]))
+ data.append((name, field["data"], field["features"]))
return readers, data
diff --git a/onmt/inputters/fields.py b/onmt/inputters/fields.py
index 50c4e6c1..5f41a3a0 100644
--- a/onmt/inputters/fields.py
+++ b/onmt/inputters/fields.py
@@ -8,11 +8,10 @@ from onmt.inputters.inputter import get_fields, _load_vocab, \
def _get_dynamic_fields(opts):
- # NOTE: not support nfeats > 0 yet
- src_nfeats = 0
- tgt_nfeats = 0
+ # NOTE: not support tgt feats yet
+ tgt_feats = None
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
- fields = get_fields('text', src_nfeats, tgt_nfeats,
+ fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +32,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
opts.src_vocab, 'src', counters,
min_freq=opts.src_words_min_frequency)
+ if opts.src_feats_vocab:
+ for feat_name, filepath in opts.src_feats_vocab.items():
+ _, _ = _load_vocab(
+ filepath, feat_name, counters,
+ min_freq=0)
+
if opts.tgt_vocab:
_tgt_vocab, _tgt_vocab_size = _load_vocab(
opts.tgt_vocab, 'tgt', counters,
diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py
index f6b5c747..ffd8c77f 100644
--- a/onmt/inputters/inputter.py
+++ b/onmt/inputters/inputter.py
@@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):
def get_fields(
src_data_type,
- n_src_feats,
- n_tgt_feats,
+ src_feats,
+ tgt_feats,
pad=DefaultTokens.PAD,
bos=DefaultTokens.BOS,
eos=DefaultTokens.EOS,
@@ -125,11 +125,11 @@ def get_fields(
"""
Args:
src_data_type: type of the source input. Options are [text].
- n_src_feats (int): the number of source features (not counting tokens)
+ src_feats (Optional[Dict]): dictionary of source feature names
to create a :class:`torchtext.data.Field` for. (If
``src_data_type=="text"``, these fields are stored together
as a ``TextMultiField``).
- n_tgt_feats (int): See above.
+ tgt_feats (Optional[Dict]): See above.
pad (str): Special pad symbol. Used on src and tgt side.
bos (str): Special beginning of sequence symbol. Only relevant
for tgt.
@@ -158,7 +158,7 @@ def get_fields(
task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)
src_field_kwargs = {
- "n_feats": n_src_feats,
+ "feats": src_feats,
"include_lengths": True,
"pad": task_spec_tokens["src"]["pad"],
"bos": task_spec_tokens["src"]["bos"],
@@ -169,7 +169,7 @@ def get_fields(
fields["src"] = fields_getters[src_data_type](**src_field_kwargs)
tgt_field_kwargs = {
- "n_feats": n_tgt_feats,
+ "feats": tgt_feats,
"include_lengths": False,
"pad": task_spec_tokens["tgt"]["pad"],
"bos": task_spec_tokens["tgt"]["bos"],
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index a0621f64..a55d2593 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -9,7 +9,7 @@ from onmt.inputters.datareader_base import DataReaderBase
class TextDataReader(DataReaderBase):
- def read(self, sequences, side):
+ def read(self, sequences, side, features={}):
"""Read text data from disk.
Args:
@@ -17,6 +17,9 @@ class TextDataReader(DataReaderBase):
path to text file or iterable of the actual text data.
side (str): Prefix used in return dict. Usually
``"src"`` or ``"tgt"``.
+ features (Dict[str, str or Iterable[str]]):
+ dictionary mapping feature names to the path of the feature
+ file or an iterable of the actual feature data.
Yields:
dictionaries whose keys are the names of fields and whose
@@ -25,10 +28,25 @@ class TextDataReader(DataReaderBase):
"""
if isinstance(sequences, str):
sequences = DataReaderBase._read_file(sequences)
- for i, seq in enumerate(sequences):
+
+ features_names = []
+ features_values = []
+ for feat_name, v in features.items():
+ features_names.append(feat_name)
+ if isinstance(v, str):
+ features_values.append(DataReaderBase._read_file(v))
+ else:
+ features_values.append(v)
+ for i, (seq, *feats) in enumerate(zip(sequences, *features_values)):
+ ex_dict = {}
if isinstance(seq, bytes):
seq = seq.decode("utf-8")
- yield {side: seq, "indices": i}
+ ex_dict[side] = seq
+ for j, f in enumerate(feats):
+ if isinstance(f, bytes):
+ f = f.decode("utf-8")
+ ex_dict[features_names[j]] = f
+ yield {side: ex_dict, "indices": i}
def text_sort_key(ex):
@@ -38,6 +56,7 @@ def text_sort_key(ex):
return len(ex.src[0])
+# Legacy function. Currently it only truncates input if truncate is set.
# mix this with partial
def _feature_tokenize(
string, layer=0, tok_delim=None, feat_delim=None, truncate=None):
@@ -140,8 +159,7 @@ class TextMultiField(RawField):
lists of tokens/feature tags for the sentence. The output
is ordered like ``self.fields``.
"""
-
- return [f.preprocess(x) for _, f in self.fields]
+ return [f.preprocess(x[fn]) for fn, f in self.fields]
def __getitem__(self, item):
return self.fields[item]
@@ -152,7 +170,7 @@ def text_fields(**kwargs):
Args:
base_name (str): Name associated with the field.
- n_feats (int): Number of word level feats (not counting the tokens)
+ feats (Optional[Dict]): Word-level feature names
include_lengths (bool): Optionally return the sequence lengths.
pad (str, optional): Defaults to ``"<blank>"``.
bos (str or NoneType, optional): Defaults to ``"<s>"``.
@@ -163,7 +181,7 @@ def text_fields(**kwargs):
TextMultiField
"""
- n_feats = kwargs["n_feats"]
+ feats = kwargs["feats"]
include_lengths = kwargs["include_lengths"]
base_name = kwargs["base_name"]
pad = kwargs.get("pad", DefaultTokens.PAD)
@@ -171,20 +189,36 @@ def text_fields(**kwargs):
eos = kwargs.get("eos", DefaultTokens.EOS)
truncate = kwargs.get("truncate", None)
fields_ = []
- feat_delim = u"│" if n_feats > 0 else None
- for i in range(n_feats + 1):
- name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
- tokenize = partial(
- _feature_tokenize,
- layer=i,
- truncate=truncate,
- feat_delim=feat_delim)
- use_len = i == 0 and include_lengths
- feat = Field(
- init_token=bos, eos_token=eos,
- pad_token=pad, tokenize=tokenize,
- include_lengths=use_len)
- fields_.append((name, feat))
+
+ feat_delim = None #u"│" if n_feats > 0 else None
+
+ # Base field
+ tokenize = partial(
+ _feature_tokenize,
+ layer=None,
+ truncate=truncate,
+ feat_delim=feat_delim)
+ feat = Field(
+ init_token=bos, eos_token=eos,
+ pad_token=pad, tokenize=tokenize,
+ include_lengths=include_lengths)
+ fields_.append((base_name, feat))
+
+ # Feats fields
+ if feats:
+ for feat_name in feats.keys():
+ # Legacy function; it is not really necessary here
+ tokenize = partial(
+ _feature_tokenize,
+ layer=None,
+ truncate=truncate,
+ feat_delim=feat_delim)
+ feat = Field(
+ init_token=bos, eos_token=eos,
+ pad_token=pad, tokenize=tokenize,
+ include_lengths=False)
+ fields_.append((feat_name, feat))
+
assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return field
diff --git a/onmt/opts.py b/onmt/opts.py
index ec66f14e..4c37ab95 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False):
group.add("-share_vocab", "--share_vocab", action="store_true",
help="Share source and target vocabulary.")
+ group.add("-src_feats_vocab", "--src_feats_vocab",
+ help=("List of paths to save" if build_vocab_only else "List of paths to")
+ + " src features vocabulary files. "
+ "Files format: one <word> or <word>\t<count> per line.")
+
if not build_vocab_only:
group.add("-src_vocab_size", "--src_vocab_size",
type=int, default=50000,
@@ -755,6 +760,9 @@ def translate_opts(parser):
group.add('--src', '-src', required=True,
help="Source sequence to decode (one line per "
"sequence)")
+ group.add("-src_feats", "--src_feats", required=False,
+ help="Source sequence features (dict format). "
+ "Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}")
group.add('--tgt', '-tgt',
help='True target sequence (optional)')
group.add('--tgt_prefix', '-tgt_prefix', action='store_true',
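
For reference, a `-src_feats_vocab` file follows the format described in the help text above (one `<word>` or `<word>\t<count>` per line); a hypothetical `feat0` vocabulary could look like this:

```
A	12
B	7
C	3
```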
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index b282cc7f..70cd7682 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -67,10 +67,22 @@ PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \
-save_data $TMP_OUT_DIR/onmt \
-src_vocab $TMP_OUT_DIR/onmt.vocab.src \
-tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
- -n_sample 5000 >> ${LOG_FILE} 2>&1
+ -n_sample 5000 -overwrite >> ${LOG_FILE} 2>&1
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}
-rm -r $TMP_OUT_DIR/sample
+rm -f -r $TMP_OUT_DIR/sample
+
+echo -n "[+] Testing vocabulary building with features..."
+PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \
+ -config ${DATA_DIR}/features_data.yaml \
+ -save_data $TMP_OUT_DIR/onmt_feat \
+ -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \
+ -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \
+ -src_feats_vocab "{'feat0': '${TMP_OUT_DIR}/onmt_feat.vocab.feat0'}" \
+ -n_sample -1 -overwrite >> ${LOG_FILE} 2>&1
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -f -r $TMP_OUT_DIR/sample
#
# Training test
@@ -254,8 +266,24 @@ ${PYTHON} onmt/bin/train.py \
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}
-rm $TMP_OUT_DIR/onmt.vocab*
-rm $TMP_OUT_DIR/onmt.model*
+echo -n " [+] Testing training with features..."
+${PYTHON} onmt/bin/train.py \
+ -config ${DATA_DIR}/features_data.yaml \
+ -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \
+ -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \
+ -src_feats_vocab "{'feat0': '${TMP_OUT_DIR}/onmt_feat.vocab.feat0'}" \
+ -src_vocab_size 1000 -tgt_vocab_size 1000 \
+ -rnn_size 2 -batch_size 10 \
+ -word_vec_size 5 -rnn_size 10 \
+ -report_every 5 -train_steps 10 \
+ -save_model $TMP_OUT_DIR/onmt.features.model \
+ -save_checkpoint_steps 10 >> ${LOG_FILE} 2>&1
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+
+rm -f $TMP_OUT_DIR/onmt.vocab*
+rm -f $TMP_OUT_DIR/onmt.model*
+rm -f $TMP_OUT_DIR/onmt_feat.vocab.*
#
# Translation test
@@ -269,6 +297,16 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt -src $TMP_OUT_DIR/src-te
echo "Succeeded" | tee -a ${LOG_FILE}
rm $TMP_OUT_DIR/src-test.txt
+echo -n " [+] Testing NMT translation with features..."
+${PYTHON} translate.py \
+ -model ${TMP_OUT_DIR}/onmt.features.model_step_10.pt \
+ -src ${DATA_DIR}/data_features/src-test.txt \
+ -src_feats "{'feat0': '${DATA_DIR}/data_features/src-test.feat0'}" \
+ -verbose >> ${LOG_FILE} 2>&1
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -f $TMP_OUT_DIR/onmt.features.model*
+
echo -n " [+] Testing NMT ensemble translation..."
head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt
${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt ${TEST_DIR}/test_model.pt \
diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py
index e827d52f..1b8337b5 100644
--- a/onmt/tests/test_subword_marker.py
+++ b/onmt/tests/test_subword_marker.py
@@ -2,6 +2,7 @@ import unittest
from onmt.transforms.bart import word_start_finder
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
+from onmt.constants import DefaultTokens, SubwordMarker
class TestWordStartFinder(unittest.TestCase):
@@ -37,7 +38,25 @@ class TestSubwordGroup(unittest.TestCase):
def test_subword_group_joiner(self):
data_in = ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'] # noqa: E501
true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7]
- out = subword_map_by_joiner(data_in)
+ out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
+ self.assertEqual(out, true_out)
+
+ def test_subword_group_joiner_with_case_markup(self):
+ data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] # noqa: E501
+ true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
+ out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
+ self.assertEqual(out, true_out)
+
+ def test_subword_group_joiner_with_new_joiner(self):
+ data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■', ',', 'according', 'to', 'the', 'logs', '■', ',', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■', '-', '■', 'working', '■', '.', '⦅mrk_end_case_region_U⦆'] # noqa: E501
+ true_out = [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7]
+ out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
+ self.assertEqual(out, true_out)
+
+ def test_subword_group_naive(self):
+ data_in = ['however', ',', 'according', 'to', 'the', 'logs', ',', 'she', 'is', 'hard', '-', 'working', '.'] # noqa: E501
+ true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+ out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
self.assertEqual(out, true_out)
def test_subword_group_spacer(self):
@@ -50,6 +69,18 @@ class TestSubwordGroup(unittest.TestCase):
no_dummy_out = subword_map_by_spacer(no_dummy)
self.assertEqual(no_dummy_out, true_out)
+ def test_subword_group_spacer_with_case_markup(self):
+ data_in = ['⦅mrk_case_modifier_C⦆', '▁however', ',', '▁according', '▁to', '▁the', '▁logs', ',', '▁⦅mrk_begin_case_region_U⦆', '▁she', '▁is', '▁hard', '-', 'working', '.', '▁⦅mrk_end_case_region_U⦆'] # noqa: E501
+ true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
+ out = subword_map_by_spacer(data_in)
+ self.assertEqual(out, true_out)
+
+ def test_subword_group_spacer_with_spacer_new(self):
+ data_in = ['⦅mrk_case_modifier_C⦆', '▁', 'however', ',', '▁', 'according', '▁', 'to', '▁', 'the', '▁', 'logs', ',', '▁', '⦅mrk_begin_case_region_U⦆', '▁', 'she', '▁', 'is', '▁', 'hard', '-', 'working', '.', '▁', '⦅mrk_end_case_region_U⦆'] # noqa: E501
+ true_out = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7]
+ out = subword_map_by_spacer(data_in)
+ self.assertEqual(out, true_out)
+
if __name__ == '__main__':
unittest.main()
diff --git a/onmt/tests/test_text_dataset.py b/onmt/tests/test_text_dataset.py
index e4d22e9c..4477bca7 100644
--- a/onmt/tests/test_text_dataset.py
+++ b/onmt/tests/test_text_dataset.py
@@ -79,7 +79,8 @@ class TestTextMultiField(unittest.TestCase):
self.INIT_CASES, self.PARAMS):
init_case = self.initialize_case(init_case, params)
mf = TextMultiField(**init_case)
- sample_str = "dummy input here ."
+
+ sample_str = {"base_field": "dummy input here .", "a": "A A B D", "r": "C C C C", "b": "D F E D", "zbase_field": "another dummy input ."}
proc = mf.preprocess(sample_str)
self.assertEqual(len(proc), len(init_case["feats_fields"]) + 1)
@@ -147,7 +148,7 @@ class TestTextDataReader(unittest.TestCase):
]
rdr = TextDataReader()
for i, ex in enumerate(rdr.read(strings, "src")):
- self.assertEqual(ex["src"], strings[i].decode("utf-8"))
+ self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8")})
class TestTextDataReaderFromFS(unittest.TestCase):
@@ -174,4 +175,23 @@ class TestTextDataReaderFromFS(unittest.TestCase):
def test_read(self):
rdr = TextDataReader()
for i, ex in enumerate(rdr.read(self.FILE_NAME, "src")):
- self.assertEqual(ex["src"], self.STRINGS[i].decode("utf-8"))
+ self.assertEqual(ex["src"], {"src": self.STRINGS[i].decode("utf-8")})
+
+class TestTextDataReaderWithFeatures(unittest.TestCase):
+ def test_read(self):
+ strings = [
+ "hello world".encode("utf-8"),
+ "this's a string with punctuation .".encode("utf-8"),
+ "ThIs Is A sTrInG wItH oDD CapitALIZAtion".encode("utf-8")
+ ]
+ features = {
+ "feat_0": [
+ "A A".encode("utf-8"),
+ "A A B B C".encode("utf-8"),
+ "A A D D E E".encode("utf-8")
+ ]
+ }
+
+ rdr = TextDataReader()
+ for i, ex in enumerate(rdr.read(strings, "src", features)):
+ self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8"), "feat_0": features["feat_0"][i].decode("utf-8")})
\ No newline at end of file
diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py
index 4bfa8be3..d99bc607 100644
--- a/onmt/tests/test_transform.py
+++ b/onmt/tests/test_transform.py
@@ -509,3 +509,25 @@ class TestBARTNoising(unittest.TestCase):
# n_masked = math.ceil(n_words * bart_noise.mask_ratio)
# print(f"Text Span Infilling: {infillied} / {tokens}")
# print(n_words, n_masked)
+
+class TestFeaturesTransform(unittest.TestCase):
+ def test_inferfeats(self):
+ inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"]
+ opt = Namespace(reversible_tokenization="joiner")
+ inferfeats_transform = inferfeats_cls(opt)
+
+ ex_in = {
+ "src": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'],
+ "tgt": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.']
+ }
+ ex_out = inferfeats_transform.apply(ex_in)
+ self.assertIs(ex_out, ex_in)
+
+ ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]}
+ ex_out = inferfeats_transform.apply(ex_in)
+ self.assertEqual(ex_out["src_feats"]["feat_0"], ["A", "<null>", "A", "A", "A", "B", "<null>", "A", "A", "C", "<null>", "C", "<null>"])
+
+ ex_in["src"] = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆']
+ ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]}
+ ex_out = inferfeats_transform.apply(ex_in)
+ self.assertEqual(ex_out["src_feats"]["feat_0"], ["<null>", "A", "<null>", "A", "A", "A", "B", "<null>", "<null>", "A", "A", "C", "<null>", "C", "<null>", "<null>"])
diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py
new file mode 100644
index 00000000..24f02e30
--- /dev/null
+++ b/onmt/transforms/features.py
@@ -0,0 +1,90 @@
+from onmt.utils.logging import logger
+from onmt.transforms import register_transform
+from .transform import Transform, ObservableStats
+from onmt.constants import DefaultTokens, SubwordMarker
+from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
+import re
+from collections import defaultdict
+
+
+@register_transform(name='filterfeats')
+class FilterFeatsTransform(Transform):
+ """Filter out examples with a mismatch between source and features."""
+
+ def __init__(self, opts):
+ super().__init__(opts)
+
+ @classmethod
+ def add_options(cls, parser):
+ pass
+
+ def _parse_opts(self):
+ pass
+
+ def apply(self, example, is_train=False, stats=None, **kwargs):
+ """Return None if mismatch"""
+
+ if 'src_feats' not in example:
+ # Do nothing
+ return example
+
+ for feat_name, feat_values in example['src_feats'].items():
+ if len(example['src']) != len(feat_values):
+ logger.warning(f"Skipping example due to mismatch between source and feature {feat_name}")
+ return None
+ return example
+
+ def _repr_args(self):
+ return ''
+
+
+@register_transform(name='inferfeats')
+class InferFeatsTransform(Transform):
+ """Infer features for subword tokenization."""
+
+ def __init__(self, opts):
+ super().__init__(opts)
+
+ @classmethod
+ def add_options(cls, parser):
+ """Avalilable options related to this Transform."""
+ group = parser.add_argument_group("Transform/InferFeats")
+ group.add("--reversible_tokenization", "-reversible_tokenization", default="joiner",
+ choices=["joiner", "spacer"], help="Type of reversible tokenization applied on the tokenizer.")
+
+ def _parse_opts(self):
+ super()._parse_opts()
+ self.reversible_tokenization = self.opts.reversible_tokenization
+
+ def apply(self, example, is_train=False, stats=None, **kwargs):
+
+ if "src_feats" not in example:
+ # Do nothing
+ return example
+
+ if self.reversible_tokenization == "joiner":
+ word_to_subword_mapping = subword_map_by_joiner(example["src"])
+ else: #Spacer
+ word_to_subword_mapping = subword_map_by_spacer(example["src"])
+
+ inferred_feats = defaultdict(list)
+ for subword, word_id in zip(example["src"], word_to_subword_mapping):
+ for feat_name, feat_values in example["src_feats"].items():
+ # If case markup placeholder
+ if subword in SubwordMarker.CASE_MARKUP:
+ inferred_feat = "<null>"
+ # Punctuation only (assumes joiner is also some punctuation token)
+ elif not re.sub(r'(\W)+', '', subword).strip():
+ inferred_feat = "<null>"
+ else:
+ inferred_feat = feat_values[word_id]
+
+ inferred_feats[feat_name].append(inferred_feat)
+
+ for feat_name, feat_values in inferred_feats.items():
+ example["src_feats"][feat_name] = feat_values
+
+ return example
+
+ def _repr_args(self):
+ return ''
\ No newline at end of file
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 9bdd4ee4..4d37e982 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -333,6 +333,7 @@ class Inference(object):
def translate(
self,
src,
+ src_feats={},
tgt=None,
batch_size=None,
batch_type="sents",
@@ -345,6 +346,7 @@ class Inference(object):
Args:
src: See :func:`self.src_reader.read()`.
tgt: See :func:`self.tgt_reader.read()`.
+ src_feats: See :func:`self.src_reader.read()`.
batch_size (int): size of examples per mini-batch
attn_debug (bool): enables the attention logging
align_debug (bool): enables the word alignment logging
@@ -363,8 +365,8 @@ class Inference(object):
if self.tgt_prefix and tgt is None:
raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")
- src_data = {"reader": self.src_reader, "data": src}
- tgt_data = {"reader": self.tgt_reader, "data": tgt}
+ src_data = {"reader": self.src_reader, "data": src, "features": src_feats}
+ tgt_data = {"reader": self.tgt_reader, "data": tgt, "features": {}}
_readers, _data = inputters.Dataset.config(
[("src", src_data), ("tgt", tgt_data)]
)
@@ -925,6 +927,7 @@ class GeneratorLM(Inference):
def translate(
self,
src,
+ src_feats={},
tgt=None,
batch_size=None,
batch_type="sents",
@@ -945,6 +948,7 @@ class GeneratorLM(Inference):
return super(GeneratorLM, self).translate(
src,
+ src_feats,
tgt,
batch_size=1,
batch_type=batch_type,
diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py
index 0a70edb3..d775cf92 100644
--- a/onmt/utils/alignment.py
+++ b/onmt/utils/alignment.py
@@ -120,25 +120,43 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
return " ".join(word_align)
-def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER):
+def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP):
"""Return word id for each subword token (annotate by joiner)."""
- flags = [0] * len(subwords)
+ flags = [1] * len(subwords)
for i, tok in enumerate(subwords):
- if tok.endswith(marker):
- flags[i] = 1
- if tok.startswith(marker):
- assert i >= 1 and flags[i-1] != 1, \
+ if tok.endswith(marker) or (tok in case_markup and tok.find("end")<0):
+ flags[i] = 0
+ if tok.startswith(marker) or (tok in case_markup and tok.find("end")>=0):
+ assert i >= 1 and flags[i-1] != 0, \
"Sentence `{}` not correct!".format(" ".join(subwords))
- flags[i-1] = 1
- marker_acc = list(accumulate([0] + flags[:-1]))
- word_group = [(i - maker_sofar) for i, maker_sofar
- in enumerate(marker_acc)]
+ flags[i-1] = 0
+ word_group = list(accumulate([0] + flags[:-1]))
return word_group
-def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
+def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER, case_markup=SubwordMarker.CASE_MARKUP):
"""Return word id for each subword token (annotate by spacer)."""
- word_group = list(accumulate([int(marker in x) for x in subwords]))
+ flags = [0] * len(subwords)
+ for i, tok in enumerate(subwords):
+ if marker in tok:
+ if tok.replace(marker, "") in case_markup:
+ if i < len(subwords)-1:
+ flags[i] = 1
+ else:
+ if i > 0:
+ previous = subwords[i-1].replace(marker, "")
+ if previous not in case_markup:
+ flags[i] = 1
+
+ # In case there is a final case_markup when new_spacer is on
+ for i in range(1,len(subwords)-1):
+ if subwords[-i] in case_markup:
+ flags[-i] = 0
+ elif subwords[-i] == marker:
+ flags[-i] = 0
+ break
+
+ word_group = list(accumulate(flags))
if word_group[0] == 1: # when dummy prefix is set
word_group = [item - 1 for item in word_group]
return word_group
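
The reworked joiner-based mapping above is easier to follow outside the diff; the sketch below (a standalone reimplementation, not the library code itself) reproduces the same accumulate logic and prints the word id assigned to each subword of a small example:

```python
from itertools import accumulate

JOINER = "■"
CASE_MARKUP = ["⦅mrk_case_modifier_C⦆",
               "⦅mrk_begin_case_region_U⦆",
               "⦅mrk_end_case_region_U⦆"]

def subword_map_by_joiner(subwords, marker=JOINER, case_markup=CASE_MARKUP):
    # flags[i] == 1 means subword i ends a word; joiners and case markup
    # zero out the flag of the token they are glued to.
    flags = [1] * len(subwords)
    for i, tok in enumerate(subwords):
        if tok.endswith(marker) or (tok in case_markup and "end" not in tok):
            flags[i] = 0          # glued to the following token
        if tok.startswith(marker) or (tok in case_markup and "end" in tok):
            assert i >= 1 and flags[i - 1] != 0, "unexpected leading joiner"
            flags[i - 1] = 0      # glued to the previous token
    # Shift the flags by one and take a running sum: each subword gets its word id
    return list(accumulate([0] + flags[:-1]))

tokens = ['however', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.']
print(subword_map_by_joiner(tokens))  # -> [0, 0, 1, 2, 3, 3, 3, 3]
```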
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 2f4f1e1c..4a12a5fe 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -75,6 +75,19 @@ class DataOptsCheckerMixin(object):
logger.warning(f"Corpus {cname}'s weight should be given."
" We default it to 1 for you.")
corpus['weight'] = 1
+
+ # Check features
+ src_feats = corpus.get("src_feats", None)
+ if src_feats is not None:
+ for feature_name, feature_file in src_feats.items():
+ cls._validate_file(feature_file, info=f'{cname}/path_{feature_name}')
+ if 'inferfeats' not in corpus["transforms"]:
+ raise ValueError(f"'inferfeats' transform is required when setting source features")
+ if 'filterfeats' not in corpus["transforms"]:
+ raise ValueError(f"'filterfeats' transform is required when setting source features")
+ else:
+ corpus["src_feats"] = None
+
logger.info(f"Parsed {len(corpora)} corpora from -data.")
opt.data = corpora
@@ -107,6 +120,18 @@ class DataOptsCheckerMixin(object):
@classmethod
def _validate_fields_opts(cls, opt, build_vocab_only=False):
"""Check options relate to vocab and fields."""
+
+ for cname, corpus in opt.data.items():
+ if cname != CorpusName.VALID and corpus["src_feats"] is not None:
+ assert opt.src_feats_vocab, \
+ "-src_feats_vocab is required if using source features."
+ import yaml
+ opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab)
+
+ for feature in corpus["src_feats"].keys():
+ assert feature in opt.src_feats_vocab, \
+ f"No vocab file set for feature {feature}"
+
if build_vocab_only:
if not opt.share_vocab:
assert opt.tgt_vocab, \
@@ -295,4 +320,4 @@ class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin):
@classmethod
def validate_translate_opts(cls, opt):
- pass
+ opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}