github.com/TharinduDR/TransQuest.git
author TharinduDR <rhtdranasinghe@gmail.com> 2021-03-16 16:01:59 +0300
committer TharinduDR <rhtdranasinghe@gmail.com> 2021-03-16 16:01:59 +0300
commit 17802d537280be12c09d7ba3b4aa8884b4029084
tree 78d9842d7da7784fedd14423b7d2f31918df99ce
parent c3f7e1347f9218f05bd95f06ada98c60f711371c
056: Code Refactoring
-rw-r--r-- docs/architectures/sentence_level_architectures.md 27
-rw-r--r-- docs/architectures/word_level_architecture.md 46
-rw-r--r-- docs/examples/sentence_level_examples.md 2
-rw-r--r-- docs/examples/word_level_examples.md 0
-rw-r--r-- docs/index.md 2
-rw-r--r-- docs/models/word_level_pretrained.md 0
-rw-r--r-- examples/word_level/common/util.py 12
-rw-r--r-- examples/word_level/wmt_2018/de_en/microtransquest.py 24
-rw-r--r-- transquest/algo/word_level/microtransquest/format.py 11
-rwxr-xr-x transquest/algo/word_level/microtransquest/run_model.py 10
10 files changed, 99 insertions, 35 deletions
diff --git a/docs/architectures/sentence_level_architectures.md b/docs/architectures/sentence_level_architectures.md
index d142791..ce94585 100644
--- a/docs/architectures/sentence_level_architectures.md
+++ b/docs/architectures/sentence_level_architectures.md
@@ -1,17 +1,26 @@
-# TransQuest Architectures
+# Sentence Level TransQuest Architectures
We have introduced two architectures for the sentence level QE in the TransQuest framework, both relies on the XLM-R transformer model.
-##MonoTransQuest
+### Data Preparation
+First read your data into a pandas dataframe and format it so that it has three columns with headers text_a, text_b and labels. text_a is the source text, text_b is the target text and labels are the quality scores, as in the following table.
-The first architecture proposed uses a single XLM-R transformer model. The input of this model is a concatenation of the original sentence and its translation, separated by the *[SEP]* token. Then the output of the
+| text_a | text_b | labels |
+| ------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------|--------|
+| නමුත් 1170 සිට 1270 දක්වා රජය පාලනය කරන ලද්දේ යුධ නායකයින් විසිනි. | But from 1170 to 1270 the government was controlled by warlords. | 0.8833 |
+| ව්‍යංගයෙන් ගිවිසුමක් යනු කොන්දේසි වචනයෙන් විස්තර නොකරන ලද එක් අවස්ථාවකි. | A contract from the constitution is one of the occasions in which the term is not described. | 0.6667 |
+
+
+Now, you can consider the following architectures to build the QE model.
+
+## MonoTransQuest
+
+The first architecture proposed uses a single XLM-R transformer model. The input of this model is a concatenation of the original sentence and its translation, separated by the *[SEP]* token. Then the output of the *[CLS]* token is passed through a softmax layer to reflect the quality scores.
![MonoTransQuest Architecture](../images/TransQuest.png)
### Minimal Start for a MonoTransQuest Model
-First read your data in to a pandas dataframe and format it so that it has three columns with headers text_a, text_b and labels. text_a is the source text, text_b is the target text and labels are the quality scores.
-
-Then initiate and train the model like in the following code. train_df and eval_df are the pandas dataframes prepared with the above instructions.
+Initiate and train the model as in the following code. train_df and eval_df are the pandas dataframes prepared with the instructions in the Data Preparation section.
```python
from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr
@@ -24,7 +33,7 @@ model = MonoTransQuestModel("xlmroberta", "xlm-roberta-large", num_labels=1, use
model.train_model(train_df, eval_df=eval_df, pearson_corr=pearson_corr, spearman_corr=spearman_corr,
mae=mean_absolute_error)
```
-An example monotransquest_config is available [here.](https://github.com/TharinduDR/TransQuest/blob/master/examples/wmt_2020/ro_en/transformer_config.py). The best model will be saved to the path specified in the "best_model_dir" in monotransquest_config. Then you can load it and do the predictions like this.
+An example monotransquest_config is available [here.](https://github.com/TharinduDR/TransQuest/blob/master/examples/sentence_level/wmt_2020/ro_en/monotransquest_config.py). The best model will be saved to the path specified in the "best_model_dir" in monotransquest_config. Then you can load it and do the predictions like this.
```python
from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
@@ -48,7 +57,7 @@ Then the output of all the word embeddings goes through a mean pooling layer. Af
### Minimal Start for a SiameseTransQuest Model
-First save your train/dev csv files in a single folder. We refer the path to that folder as "path" in the code below. You have to provide the indices of source, target and quality labels when reading with the QEDataReader class.
+First save your train/dev pandas dataframes to csv files in a single folder. We refer the path to that folder as "path" in the code below. You have to provide the indices of source, target and quality labels when reading with the QEDataReader class.
```python
from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
@@ -112,4 +121,4 @@ test_data = SentencesDataset(examples=qe_reader.get_examples("test.tsv", test_fi
verbose=False)
```
-You will find the predictions in the test_result.txt file in the siamesetransquest_config['cache_dir'] folder. You will find more examples in [here.](https://tharindudr.github.io/TransQuest/examples/sentence_level)
\ No newline at end of file
+You will find the predictions in the test_result.txt file in the siamesetransquest_config['cache_dir'] folder. You can find more examples in [here.](https://tharindudr.github.io/TransQuest/examples/sentence_level)
\ No newline at end of file
diff --git a/docs/architectures/word_level_architecture.md b/docs/architectures/word_level_architecture.md
new file mode 100644
index 0000000..372ec84
--- /dev/null
+++ b/docs/architectures/word_level_architecture.md
@@ -0,0 +1,46 @@
+# Word Level TransQuest Architecture
+We have one architecture capable of providing word level quality estimation models: MicroTransQuest.
+
+### Data Preparation
+Please have your data as a pandas dataframe in this format.
+
+| source_column | target_column | source_tags_column | target_tags_column |
+| ----------------------------------------| ----------------------------------|--------------------|-------------------------------------|
+| 52 mg wasserfreie Lactose . | 52 mg anhydrous lactose . | [OK OK OK OK OK] | [OK OK OK OK OK OK OK OK OK OK OK] |
+| România sanofi-aventis România S.R.L. | Sanofi-Aventis România S. R. L. | [BAD OK OK OK] | [BAD BAD OK OK OK OK OK OK OK OK OK]|
+
+Please note that target_tags_column has word level quality labels for gaps in the target too. Therefore, it has 2*N+1 labels, where N is the total number of tokens in the target. For more information, please have a look at the WMT word level quality estimation task.
+
+Now, you can consider MicroTransQuest to build the QE model.
+
+## MicroTransQuest
+The input of this model is a concatenation of the original sentence and its translation, separated by the *[SEP]* token. As shown in the figure, the target sentence contains gaps too. The output of each token is then passed through a softmax layer to reflect the quality scores.
+
+
+![MicroTransQuest Architecture](../images/MicroTransQuest.png)
+
+### Minimal Start for a MicroTransQuest Model
+
+Initiate and train the model as in the following code. train_df and eval_df are the pandas dataframes prepared with the instructions in the Data Preparation section.
+
+```python
+from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
+from transquest.algo.word_level.microtransquest.format import prepare_data
+import torch
+
+model = MicroTransQuestModel("xlmroberta", "xlm-roberta-large", labels=["OK", "BAD"], use_cuda=torch.cuda.is_available(), args=microtransquest_config)
+model.train_model(prepare_data(train_df, microtransquest_config) , eval_df=(prepare_data(eval_df, microtransquest_config)))
+```
+
+An example microtransquest_config is available [here.](https://github.com/TharinduDR/TransQuest/blob/master/examples/word_level/wmt_2018/en_de/microtransquest_config.py). The best model will be saved to the path specified in the "best_model_dir" in microtransquest_config. Then you can load it and do the predictions like this.
+
+```python
+from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
+
+model = MicroTransQuestModel("xlmroberta", microtransquest_config["best_model_dir"],
+ use_cuda=torch.cuda.is_available(), args=microtransquest_config)
+
+sources_tags, targets_tags = model.predict([[source, target]])
+print(sources_tags, targets_tags)
+
+```
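The 2*N+1 rule stated in the Data Preparation section of word_level_architecture.md can be checked before training. The following is a minimal sketch, not part of this commit: `check_tag_counts` is a hypothetical helper, and it assumes whitespace tokenisation and tags already split into lists.

```python
def check_tag_counts(source, target, source_tags, target_tags):
    """Return True when both tag lists have the expected lengths.

    Hypothetical helper for illustration; it is not part of TransQuest.
    """
    n_source = len(source.split())
    n_target = len(target.split())
    # One tag per source token. Target tags cover every token plus the gaps
    # before, between and after the tokens, which gives 2*N + 1 tags.
    return len(source_tags) == n_source and len(target_tags) == 2 * n_target + 1


# First row of the example table: 5 source tokens, 5 target tokens, 11 target tags.
row_is_valid = check_tag_counts(
    "52 mg wasserfreie Lactose .",
    "52 mg anhydrous lactose .",
    ["OK"] * 5,
    ["OK"] * 11,
)
print(row_is_valid)
```

Running a check like this over every row catches misaligned tag files early, before they surface as shape errors during training.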
diff --git a/docs/examples/sentence_level_examples.md b/docs/examples/sentence_level_examples.md
index 1e70731..1972b37 100644
--- a/docs/examples/sentence_level_examples.md
+++ b/docs/examples/sentence_level_examples.md
@@ -1,4 +1,4 @@
-# Examples
+# Sentence Level Examples
We have provided several examples on how to use TransQuest in recent WMT sentence-level quality estimation shared tasks. They are included in the repository but are not shipped with the library. Therefore, if you need to run the examples, please clone the repository.
!!! note
diff --git a/docs/examples/word_level_examples.md b/docs/examples/word_level_examples.md
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/docs/examples/word_level_examples.md
diff --git a/docs/index.md b/docs/index.md
index 84935cd..ada2601 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,7 +1,7 @@
# TransQuest: Translation Quality Estimation with Cross-lingual Transformers
The goal of quality estimation (QE) is to evaluate the quality of a translation without having access to a reference translation. High-accuracy QE that can be easily deployed for a number of language pairs is the missing piece in many commercial translation workflows as they have numerous potential uses. They can be employed to select the best translation when several translation engines are available or can inform the end user about the reliability of automatically translated content. In addition, QE systems can be used to decide whether a translation can be published as it is in a given context, or whether it requires human post-editing before publishing or translation from scratch by a human. The quality estimation can be done at different levels: document level, sentence level and word level.
-With TransQuest, we have opensourced our research in sentence-level quality estimation which also won the sentence-level direct assessment quality estimation shared task in [WMT 2020](http://www.statmt.org/wmt20/quality-estimation-task.html). TransQuest outperforms current open-source quality estimation frameworks such as [OpenKiwi](https://github.com/Unbabel/OpenKiwi) and [DeepQuest](https://github.com/sheffieldnlp/deepQuest).
+With TransQuest, we have opensourced our research in translation quality estimation which also won the sentence-level direct assessment quality estimation shared task in [WMT 2020](http://www.statmt.org/wmt20/quality-estimation-task.html). TransQuest outperforms current open-source quality estimation frameworks such as [OpenKiwi](https://github.com/Unbabel/OpenKiwi) and [DeepQuest](https://github.com/sheffieldnlp/deepQuest).
## Installation
diff --git a/docs/models/word_level_pretrained.md b/docs/models/word_level_pretrained.md
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/docs/models/word_level_pretrained.md
diff --git a/examples/word_level/common/util.py b/examples/word_level/common/util.py
index 416f698..94965d6 100644
--- a/examples/word_level/common/util.py
+++ b/examples/word_level/common/util.py
@@ -37,6 +37,18 @@ def reader(path, args, source_file, target_file, source_tags_file=None, target_t
return df
+def prepare_testdata(raw_df, args):
+ source_sentences = raw_df[args["source_column"]].tolist()
+ target_sentences = raw_df[args["target_column"]].tolist()
+
+ test_sentences = []
+ for source_sentence, target_sentence in zip(source_sentences, target_sentences):
+ test_sentences.append([source_sentence, target_sentence])
+
+ return test_sentences
+
+
+
diff --git a/examples/word_level/wmt_2018/de_en/microtransquest.py b/examples/word_level/wmt_2018/de_en/microtransquest.py
index 1e4c9e8..1d6490e 100644
--- a/examples/word_level/wmt_2018/de_en/microtransquest.py
+++ b/examples/word_level/wmt_2018/de_en/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.de_en.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,7 +11,6 @@ from examples.word_level.wmt_2018.de_en.microtransquest_config import TRAIN_PATH
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_SOURCE_TAGS_FILE_SUB, \
DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_df=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
diff --git a/transquest/algo/word_level/microtransquest/format.py b/transquest/algo/word_level/microtransquest/format.py
index 607c2c6..b7409b2 100644
--- a/transquest/algo/word_level/microtransquest/format.py
+++ b/transquest/algo/word_level/microtransquest/format.py
@@ -30,17 +30,14 @@ def prepare_data(raw_df, args):
return pd.DataFrame(data, columns=['sentence_id', 'words', 'labels'])
-def prepare_testdata(raw_df, args):
- source_sentences = raw_df[args["source_column"]].tolist()
- target_sentences = raw_df[args["target_column"]].tolist()
-
+def format_to_test(to_test, args):
test_sentences = []
- for source_sentence, target_sentence in zip(source_sentences, target_sentences):
+ for source_sentence, target_sentence in to_test:
test_sentence = source_sentence + " " + "[SEP]"
target_words = target_sentence.split()
for target_word in target_words:
- test_sentence = test_sentence + " " + args["tag"] + " " + target_word
- test_sentence = test_sentence + " " + args["tag"]
+ test_sentence = test_sentence + " " + args.tag + " " + target_word
+ test_sentence = test_sentence + " " + args.tag
test_sentences.append(test_sentence)
return test_sentences
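To make the string layout produced by `format_to_test` easier to see, here is a standalone sketch that mirrors the loop in the hunk above. `pairs_to_model_input` is a hypothetical name, and the gap tag value `"_"` is an assumption chosen for illustration; the real value comes from `args.tag` in the model configuration.

```python
from types import SimpleNamespace


def pairs_to_model_input(pairs, tag):
    """Join each [source, target] pair into one string with gap tags.

    Standalone illustration of the loop in format_to_test; not library code.
    """
    formatted = []
    for source_sentence, target_sentence in pairs:
        sentence = source_sentence + " [SEP]"
        # A gap tag goes before every target word ...
        for target_word in target_sentence.split():
            sentence += " " + tag + " " + target_word
        # ... and one more after the last word, giving 2*N + 1 target positions.
        formatted.append(sentence + " " + tag)
    return formatted


args = SimpleNamespace(tag="_")  # assumed placeholder, not the real config value
example = pairs_to_model_input([["ein Haus", "a house"]], args.tag)
print(example[0])  # ein Haus [SEP] _ a _ house _
```

With this commit, callers pass plain `[source, target]` pairs (as built by `prepare_testdata` in the examples) and the formatting step runs inside `predict`.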
diff --git a/transquest/algo/word_level/microtransquest/run_model.py b/transquest/algo/word_level/microtransquest/run_model.py
index d75b184..245a80f 100755
--- a/transquest/algo/word_level/microtransquest/run_model.py
+++ b/transquest/algo/word_level/microtransquest/run_model.py
@@ -44,6 +44,7 @@ from transformers.optimization import (
get_polynomial_decay_schedule_with_warmup,
)
+from transquest.algo.word_level.microtransquest.format import post_process, prepare_data, format_to_test
from transquest.algo.word_level.microtransquest.model_args import MicroTransQuestArgs
from transquest.algo.word_level.microtransquest.utils import sweep_config_to_sweep_values, InputExample, \
read_examples_from_file, get_examples_from_df, convert_examples_to_features, LazyQEDataset
@@ -249,6 +250,9 @@ class MicroTransQuestModel:
training_details: Average training loss if evaluate_during_training is False or full training progress scores if evaluate_during_training is True
""" # noqa: ignore flake8"
+ train_data = prepare_data(train_data, args)
+ eval_data = prepare_data(eval_data, args)
+
if args:
self.args.update_from_dict(args)
@@ -955,6 +959,8 @@ class MicroTransQuestModel:
pad_token_label_id = self.pad_token_label_id
preds = None
+ to_predict = format_to_test(to_predict, self.args)
+
if split_on_space:
if self.args.model_type == "layoutlm":
predict_examples = [
@@ -1104,7 +1110,9 @@ class MicroTransQuestModel:
for i, sentence in enumerate(to_predict)
]
- return preds, model_outputs
+ sources_tags, targets_tags = post_process(preds, to_predict, args=self.args)
+
+ return sources_tags, targets_tags
def _convert_tokens_to_word_logits(self, input_ids, label_ids, attention_mask, logits):