Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author	TharinduDR <rhtdranasinghe@gmail.com>	2021-03-18 16:41:04 +0300
committer	TharinduDR <rhtdranasinghe@gmail.com>	2021-03-18 16:41:04 +0300
commit	ed721ada1157616469378218da5ace7ab611d2a8 (patch)
tree	48c34c03c296084dbc2e33d6ad971418c6eba2a3 /examples
parent	05d649b3f488a2042d9906fa89735994fcba9eee (diff)
056: Code Refactoring
Diffstat (limited to 'examples')
-rw-r--r--examples/word_level/wmt_2018/en_cs/microtransquest.py81
-rw-r--r--examples/word_level/wmt_2018/en_de/nmt/microtransquest.py77
-rw-r--r--examples/word_level/wmt_2018/en_de/smt/microtransquest.py80
-rw-r--r--examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py79
-rw-r--r--examples/word_level/wmt_2018/en_lv/smt/microtransquest.py79
-rw-r--r--examples/word_level/wmt_2020/en_de/microtransquest.py34
6 files changed, 188 insertions, 242 deletions
diff --git a/examples/word_level/wmt_2018/en_cs/microtransquest.py b/examples/word_level/wmt_2018/en_cs/microtransquest.py
index 6595e8d..ebf2b3f 100644
--- a/examples/word_level/wmt_2018/en_cs/microtransquest.py
+++ b/examples/word_level/wmt_2018/en_cs/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_SOURCE_TAGS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB, DEV_SOURCE_FILE, DEV_TARGET_FILE, \
DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_PATH
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-raw_train_df = raw_train_df.head(21000)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
@@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
+ for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)):
+ # target_sentence = test_sentence.split("[SEP]")[1]
+ words = test_target_sentence.split()
# word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
+
+ for prediction_id, prediction in enumerate(target_prediction):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
+ + "gap" + "\t" + prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
+ + words[word_index] + "\t" + prediction + '\n')
word_index += 1
# Predictions for dev file
@@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
@@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f:
for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate(
- zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)):
- dev_target_sentence = dev_sentence.split("[SEP]")[1]
- words = dev_target_sentence.split()
- gold_predictions = source_gold_tag.split()
- # word_predictions = target_prediction.split()
+ zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)):
+ words = dev_sentence.split()
+ gold_predictions = dev_target_gold_tag.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction, gold_prediction) in enumerate(
- zip(words, target_prediction, gold_predictions)):
- if word_id % 2 == 0:
+
+ for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n')
word_index += 1
diff --git a/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py b/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py
index df3b99d..d2fdb2e 100644
--- a/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py
+++ b/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.en_de.nmt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_de.nmt.microtransquest_config import TRAIN_
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_SOURCE_TAGS_FILE_SUB, \
DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
@@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
+ for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)):
+ # target_sentence = test_sentence.split("[SEP]")[1]
+ words = test_target_sentence.split()
# word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
+
+ for prediction_id, prediction in enumerate(target_prediction):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
+ + "gap" + "\t" + prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
+ + words[word_index] + "\t" + prediction + '\n')
word_index += 1
# Predictions for dev file
@@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
@@ -180,22 +173,20 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f:
for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate(
- zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)):
- dev_target_sentence = dev_sentence.split("[SEP]")[1]
- words = dev_target_sentence.split()
+ zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)):
+ words = dev_sentence.split()
gold_predictions = dev_target_gold_tag.split()
- # word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction, gold_prediction) in enumerate(
- zip(words, dev_target_prediction, gold_predictions)):
- if word_id % 2 == 0:
+
+ for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n')
word_index += 1
diff --git a/examples/word_level/wmt_2018/en_de/smt/microtransquest.py b/examples/word_level/wmt_2018/en_de/smt/microtransquest.py
index cc67736..8600f4b 100644
--- a/examples/word_level/wmt_2018/en_de/smt/microtransquest.py
+++ b/examples/word_level/wmt_2018/en_de/smt/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.en_de.smt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_de.smt.microtransquest_config import TRAIN_
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_SOURCE_TAGS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB, DEV_PATH, DEV_SOURCE_FILE, \
DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
@@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
+ for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)):
+ # target_sentence = test_sentence.split("[SEP]")[1]
+ words = test_target_sentence.split()
# word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
+
+ for prediction_id, prediction in enumerate(target_prediction):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
+ + "gap" + "\t" + prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
+ + words[word_index] + "\t" + prediction + '\n')
word_index += 1
# Predictions for dev file
@@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
@@ -180,23 +173,20 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f:
for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate(
- zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)):
- dev_target_sentence = dev_sentence.split("[SEP]")[1]
- words = dev_target_sentence.split()
- gold_predictions = source_gold_tag.split()
- # word_predictions = target_prediction.split()
+ zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)):
+ words = dev_sentence.split()
+ gold_predictions = dev_target_gold_tag.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction, gold_prediction) in enumerate(
- zip(words, target_prediction, gold_predictions)):
- if word_id % 2 == 0:
+
+ for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n')
word_index += 1
-
diff --git a/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py b/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py
index 248a144..46b37af 100644
--- a/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py
+++ b/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.en_lv.nmt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_lv.nmt.microtransquest_config import TRAIN_
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_SOURCE_TAGS_FILE_SUB, \
DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
@@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
+ for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)):
+ # target_sentence = test_sentence.split("[SEP]")[1]
+ words = test_target_sentence.split()
# word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
+
+ for prediction_id, prediction in enumerate(target_prediction):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
+ + "gap" + "\t" + prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
+ + words[word_index] + "\t" + prediction + '\n')
word_index += 1
# Predictions for dev file
@@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
@@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f:
for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate(
- zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)):
- dev_target_sentence = dev_sentence.split("[SEP]")[1]
- words = dev_target_sentence.split()
- gold_predictions = source_gold_tag.split()
- # word_predictions = target_prediction.split()
+ zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)):
+ words = dev_sentence.split()
+ gold_predictions = dev_target_gold_tag.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction, gold_prediction) in enumerate(
- zip(words, target_prediction, gold_predictions)):
- if word_id % 2 == 0:
+
+ for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n')
word_index += 1
diff --git a/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py b/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py
index d485934..6dfaccc 100644
--- a/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py
+++ b/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py
@@ -3,7 +3,7 @@ import shutil
from sklearn.model_selection import train_test_split
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2018.en_lv.smt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_lv.smt.microtransquest_config import TRAIN_
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \
DEV_TARGET_GAPS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_SOURCE_TAGS_FILE_SUB, DEV_PATH, DEV_SOURCE_FILE, \
DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)
@@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
+ for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)):
+ # target_sentence = test_sentence.split("[SEP]")[1]
+ words = test_target_sentence.split()
# word_predictions = target_prediction.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
+
+ for prediction_id, prediction in enumerate(target_prediction):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
+ + "gap" + "\t" + prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
+ + words[word_index] + "\t" + prediction + '\n')
word_index += 1
# Predictions for dev file
@@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
@@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open(
os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f:
for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate(
- zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)):
- dev_target_sentence = dev_sentence.split("[SEP]")[1]
- words = dev_target_sentence.split()
- gold_predictions = source_gold_tag.split()
- # word_predictions = target_prediction.split()
+ zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)):
+ words = dev_sentence.split()
+ gold_predictions = dev_target_gold_tag.split()
gap_index = 0
word_index = 0
- for word_id, (word, word_prediction, gold_prediction) in enumerate(
- zip(words, target_prediction, gold_predictions)):
- if word_id % 2 == 0:
+
+ for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)):
+ if prediction_id % 2 == 0:
gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n')
gap_index += 1
else:
target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + "\t" + gold_prediction + '\n')
+ + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n')
word_index += 1
diff --git a/examples/word_level/wmt_2020/en_de/microtransquest.py b/examples/word_level/wmt_2020/en_de/microtransquest.py
index 4ee4184..4e29ed3 100644
--- a/examples/word_level/wmt_2020/en_de/microtransquest.py
+++ b/examples/word_level/wmt_2020/en_de/microtransquest.py
@@ -2,7 +2,7 @@ import shutil
from sklearn.model_selection import train_test_split
import os
-from examples.word_level.common.util import reader
+from examples.word_level.common.util import reader, prepare_testdata
from examples.word_level.wmt_2020.en_de.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
TRAIN_SOURCE_TAGS_FILE, \
TRAIN_TARGET_FILE, \
@@ -10,21 +10,20 @@ from examples.word_level.wmt_2020.en_de.microtransquest_config import TRAIN_PATH
TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, TEST_TARGET_TAGS_FLE, SEED, DEV_TARGET_TAGS_FLE, \
DEV_SOURCE_TAGS_FILE, DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_TARGET_TAGS_FILE_SUB, DEV_SOURCE_TAGS_FILE_SUB
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
-from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
+raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
TRAIN_TARGET_TAGS_FLE)
-raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
+raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE,
DEV_TARGET_TAGS_FLE)
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config)
+test_sentences = prepare_testdata(raw_test_df)
+dev_sentences = prepare_testdata(raw_dev_df)
fold_sources_tags = []
fold_targets_tags = []
@@ -39,27 +38,20 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
- train_df = prepare_data(raw_train, args=microtransquest_config)
- eval_df = prepare_data(raw_eval, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df, eval_df=eval_df)
- model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train, eval_data=raw_eval)
+ model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
else:
- train_df = prepare_data(raw_train_df, args=microtransquest_config)
- tags = train_df['labels'].unique().tolist()
- model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
- model.train_model(train_df)
+ model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
+ model.train_model(raw_train_df)
- predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
- sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+ sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True)
fold_sources_tags.append(sources_tags)
fold_targets_tags.append(targets_tags)
- dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True)
- dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config)
+ dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True)
dev_fold_sources_tags.append(dev_sources_tags)
dev_fold_targets_tags.append(dev_targets_tags)