Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-01-29 14:02:58 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-01-29 14:02:58 +0300
commitcb5680f8212db470c32258ea5cc37b6fbfa1d005 (patch)
tree42780f1a4a77e3106e85cf49229be3e015e67ece
parente7d1c62459ae60876eb253b7cb57a4dd5727c478 (diff)
055: Adding word level examples
-rw-r--r--examples/word_level/wmt_2018/de_en/microtransquest_zeroshot.py103
1 files changed, 56 insertions, 47 deletions
diff --git a/examples/word_level/wmt_2018/de_en/microtransquest_zeroshot.py b/examples/word_level/wmt_2018/de_en/microtransquest_zeroshot.py
index 9db2d11..1b08bc5 100644
--- a/examples/word_level/wmt_2018/de_en/microtransquest_zeroshot.py
+++ b/examples/word_level/wmt_2018/de_en/microtransquest_zeroshot.py
@@ -1,6 +1,5 @@
import os
-from examples.word_level.common.download import download_from_google_drive
from examples.word_level.common.util import reader
from examples.word_level.wmt_2018.de_en.microtransquest_config import microtransquest_config, TEST_PATH, \
TEST_SOURCE_FILE, \
@@ -11,52 +10,62 @@ from transquest.algo.word_level.microtransquest.run_model import MicroTransQuest
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
+models = {
+ "en_cs": "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/en_cs_smt_it/best_model",
+ "en_de_nmt": "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/en_de_nmt_it/best_model",
+ "en_de_smt": "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/en_de_smt_it/best_model",
+ "en_lv_nmt": "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/en_lv_nmt_pharmaceutical/best_model",
+ "en_lv_smt": "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/en_lv_smt_pharmaceutical/best_model"
+}
-model = MicroTransQuestModel(MODEL_TYPE, "/content/drive/MyDrive/TransQuestModels/MicroTransQuest/wmt2018/de_en_smt_pharmaceutical/best_model", labels=["OK", "BAD"],
- args=microtransquest_config)
+for language, path in models.items():
-if not os.path.exists(TEMP_DIRECTORY):
- os.makedirs(TEMP_DIRECTORY)
+ if not os.path.exists(TEMP_DIRECTORY, language):
+ os.makedirs(TEMP_DIRECTORY, language)
+
+ model = MicroTransQuestModel(MODEL_TYPE, path, labels=["OK", "BAD"], args=microtransquest_config)
+
+ if not os.path.exists(TEMP_DIRECTORY):
+ os.makedirs(TEMP_DIRECTORY)
+
+ raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
+
+ test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
+
+ fold_sources_tags = []
+ fold_targets_tags = []
+
+ predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
+ sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
+
+ test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
+ test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+
+ with open(os.path.join(TEMP_DIRECTORY, language, TEST_SOURCE_TAGS_FILE), 'w') as f:
+ for sentence_id, (test_source_sentence, source_prediction) in enumerate(
+ zip(test_source_sentences, sources_tags)):
+ words = test_source_sentence.split()
+ for word_id, (word, word_prediction) in enumerate(zip(words, source_prediction)):
+ f.write("MicroTransQuest" + "\t" + "source" + "\t" +
+ str(sentence_id) + "\t" + str(word_id) + "\t"
+ + word + "\t" + word_prediction + '\n')
-raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)
-
-test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
-
-fold_sources_tags = []
-fold_targets_tags = []
-
-
-predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
-sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
-
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
-
-with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
- for sentence_id, (test_source_sentence, source_prediction) in enumerate(
- zip(test_source_sentences, sources_tags)):
- words = test_source_sentence.split()
- for word_id, (word, word_prediction) in enumerate(zip(words, source_prediction)):
- f.write("MicroTransQuest" + "\t" + "source" + "\t" +
- str(sentence_id) + "\t" + str(word_id) + "\t"
- + word + "\t" + word_prediction + '\n')
-
-with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
- os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
- for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, targets_tags)):
- target_sentence = test_sentence.split("[SEP]")[1]
- words = target_sentence.split()
- # word_predictions = target_prediction.split()
- gap_index = 0
- word_index = 0
- for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
- if word_id % 2 == 0:
- gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
- str(sentence_id) + "\t" + str(gap_index) + "\t"
- + "gap" + "\t" + word_prediction + '\n')
- gap_index += 1
- else:
- target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
- str(sentence_id) + "\t" + str(word_index) + "\t"
- + word + "\t" + word_prediction + '\n')
- word_index += 1
+ with open(os.path.join(TEMP_DIRECTORY, language, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
+ os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
+ for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, targets_tags)):
+ target_sentence = test_sentence.split("[SEP]")[1]
+ words = target_sentence.split()
+ # word_predictions = target_prediction.split()
+ gap_index = 0
+ word_index = 0
+ for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)):
+ if word_id % 2 == 0:
+ gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
+ str(sentence_id) + "\t" + str(gap_index) + "\t"
+ + "gap" + "\t" + word_prediction + '\n')
+ gap_index += 1
+ else:
+ target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
+ str(sentence_id) + "\t" + str(word_index) + "\t"
+ + word + "\t" + word_prediction + '\n')
+ word_index += 1