diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-18 16:41:04 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-18 16:41:04 +0300 |
commit | ed721ada1157616469378218da5ace7ab611d2a8 (patch) | |
tree | 48c34c03c296084dbc2e33d6ad971418c6eba2a3 /examples | |
parent | 05d649b3f488a2042d9906fa89735994fcba9eee (diff) |
056: Code Refactoring
Diffstat (limited to 'examples')
6 files changed, 188 insertions, 242 deletions
diff --git a/examples/word_level/wmt_2018/en_cs/microtransquest.py b/examples/word_level/wmt_2018/en_cs/microtransquest.py index 6595e8d..ebf2b3f 100644 --- a/examples/word_level/wmt_2018/en_cs/microtransquest.py +++ b/examples/word_level/wmt_2018/en_cs/microtransquest.py @@ -3,7 +3,7 @@ import shutil from sklearn.model_selection import train_test_split -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \ DEV_SOURCE_TAGS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB, DEV_SOURCE_FILE, DEV_TARGET_FILE, \ DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_PATH -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -raw_train_df = raw_train_df.head(21000) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) + +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) @@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) target_predictions.append(majority_prediction) -test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist() -test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist() +test_source_sentences = raw_test_df["source"].tolist() +test_target_sentences = raw_test_df["target"].tolist() with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: for sentence_id, (test_source_sentence, source_prediction) in enumerate( @@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f: - for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)): - target_sentence = test_sentence.split("[SEP]")[1] - words = target_sentence.split() + for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)): + # target_sentence = test_sentence.split("[SEP]")[1] + words = test_target_sentence.split() # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)): - if word_id % 2 == 0: + + for prediction_id, prediction in enumerate(target_prediction): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + '\n') + + "gap" + "\t" + prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + '\n') + + words[word_index] + "\t" + prediction + '\n') word_index += 1 # Predictions for dev file @@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) dev_target_predictions.append(majority_prediction) -dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist() -dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist() -dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist() -dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist() +dev_source_sentences = raw_dev_df["source"].tolist() +dev_target_sentences = raw_dev_df["target"].tolist() +dev_source_gold_tags = raw_dev_df["source_tags"].tolist() +dev_target_gold_tags = raw_dev_df["target_tags"].tolist() with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate( @@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f: for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate( - zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)): - dev_target_sentence = dev_sentence.split("[SEP]")[1] - words = dev_target_sentence.split() - gold_predictions = source_gold_tag.split() - # word_predictions = target_prediction.split() + zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)): + words = dev_sentence.split() + gold_predictions = dev_target_gold_tag.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction, gold_prediction) in enumerate( - zip(words, target_prediction, gold_predictions)): - if word_id % 2 == 0: + + for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n') word_index += 1 diff --git a/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py b/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py index df3b99d..d2fdb2e 100644 --- a/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py +++ b/examples/word_level/wmt_2018/en_de/nmt/microtransquest.py @@ -3,7 +3,7 @@ import shutil from sklearn.model_selection import train_test_split -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2018.en_de.nmt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_de.nmt.microtransquest_config import TRAIN_ TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \ DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_SOURCE_TAGS_FILE_SUB, \ DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) @@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) target_predictions.append(majority_prediction) -test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist() -test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist() +test_source_sentences = raw_test_df["source"].tolist() +test_target_sentences = raw_test_df["target"].tolist() with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: for sentence_id, (test_source_sentence, source_prediction) in enumerate( @@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f: - for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)): - target_sentence = test_sentence.split("[SEP]")[1] - words = target_sentence.split() + for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)): + # target_sentence = test_sentence.split("[SEP]")[1] + words = test_target_sentence.split() # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)): - if word_id % 2 == 0: + + for prediction_id, prediction in enumerate(target_prediction): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + '\n') + + "gap" + "\t" + prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + '\n') + + words[word_index] + "\t" + prediction + '\n') word_index += 1 # Predictions for dev file @@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) dev_target_predictions.append(majority_prediction) -dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist() -dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist() -dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist() -dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist() +dev_source_sentences = raw_dev_df["source"].tolist() +dev_target_sentences = raw_dev_df["target"].tolist() +dev_source_gold_tags = raw_dev_df["source_tags"].tolist() +dev_target_gold_tags = raw_dev_df["target_tags"].tolist() with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate( @@ -180,22 +173,20 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f: for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate( - zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)): - dev_target_sentence = dev_sentence.split("[SEP]")[1] - words = dev_target_sentence.split() + zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)): + words = dev_sentence.split() gold_predictions = dev_target_gold_tag.split() - # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction, gold_prediction) in enumerate( - zip(words, dev_target_prediction, gold_predictions)): - if word_id % 2 == 0: + + for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n') word_index += 1 diff --git a/examples/word_level/wmt_2018/en_de/smt/microtransquest.py b/examples/word_level/wmt_2018/en_de/smt/microtransquest.py index cc67736..8600f4b 100644 --- a/examples/word_level/wmt_2018/en_de/smt/microtransquest.py +++ b/examples/word_level/wmt_2018/en_de/smt/microtransquest.py @@ -3,7 +3,7 @@ import shutil from sklearn.model_selection import train_test_split -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2018.en_de.smt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_de.smt.microtransquest_config import TRAIN_ TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \ DEV_SOURCE_TAGS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB, DEV_PATH, DEV_SOURCE_FILE, \ DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) @@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) target_predictions.append(majority_prediction) -test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist() -test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist() +test_source_sentences = raw_test_df["source"].tolist() +test_target_sentences = raw_test_df["target"].tolist() with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: for sentence_id, (test_source_sentence, source_prediction) in enumerate( @@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f: - for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)): - target_sentence = test_sentence.split("[SEP]")[1] - words = target_sentence.split() + for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)): + # target_sentence = test_sentence.split("[SEP]")[1] + words = test_target_sentence.split() # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)): - if word_id % 2 == 0: + + for prediction_id, prediction in enumerate(target_prediction): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + '\n') + + "gap" + "\t" + prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + '\n') + + words[word_index] + "\t" + prediction + '\n') word_index += 1 # Predictions for dev file @@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) dev_target_predictions.append(majority_prediction) -dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist() -dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist() -dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist() -dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist() +dev_source_sentences = raw_dev_df["source"].tolist() +dev_target_sentences = raw_dev_df["target"].tolist() +dev_source_gold_tags = raw_dev_df["source_tags"].tolist() +dev_target_gold_tags = raw_dev_df["target_tags"].tolist() with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate( @@ -180,23 +173,20 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f: for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate( - zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)): - dev_target_sentence = dev_sentence.split("[SEP]")[1] - words = dev_target_sentence.split() - gold_predictions = source_gold_tag.split() - # word_predictions = target_prediction.split() + zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)): + words = dev_sentence.split() + gold_predictions = dev_target_gold_tag.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction, gold_prediction) in enumerate( - zip(words, target_prediction, gold_predictions)): - if word_id % 2 == 0: + + for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n') word_index += 1 - diff --git a/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py b/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py index 248a144..46b37af 100644 --- a/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py +++ b/examples/word_level/wmt_2018/en_lv/nmt/microtransquest.py @@ -3,7 +3,7 @@ import shutil from sklearn.model_selection import train_test_split -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2018.en_lv.nmt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_lv.nmt.microtransquest_config import TRAIN_ TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \ DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE, DEV_SOURCE_TAGS_FILE_SUB, \ DEV_TARGET_TAGS_FILE_SUB, DEV_TARGET_GAPS_FILE_SUB -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) @@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) target_predictions.append(majority_prediction) -test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist() -test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist() +test_source_sentences = raw_test_df["source"].tolist() +test_target_sentences = raw_test_df["target"].tolist() with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: for sentence_id, (test_source_sentence, source_prediction) in enumerate( @@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f: - for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)): - target_sentence = test_sentence.split("[SEP]")[1] - words = target_sentence.split() + for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)): + # target_sentence = test_sentence.split("[SEP]")[1] + words = test_target_sentence.split() # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)): - if word_id % 2 == 0: + + for prediction_id, prediction in enumerate(target_prediction): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + '\n') + + "gap" + "\t" + prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + '\n') + + words[word_index] + "\t" + prediction + '\n') word_index += 1 # Predictions for dev file @@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) dev_target_predictions.append(majority_prediction) -dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist() -dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist() -dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist() -dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist() +dev_source_sentences = raw_dev_df["source"].tolist() +dev_target_sentences = raw_dev_df["target"].tolist() +dev_source_gold_tags = raw_dev_df["source_tags"].tolist() +dev_target_gold_tags = raw_dev_df["target_tags"].tolist() with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate( @@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f: for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate( - zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)): - dev_target_sentence = dev_sentence.split("[SEP]")[1] - words = dev_target_sentence.split() - gold_predictions = source_gold_tag.split() - # word_predictions = target_prediction.split() + zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)): + words = dev_sentence.split() + gold_predictions = dev_target_gold_tag.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction, gold_prediction) in enumerate( - zip(words, target_prediction, gold_predictions)): - if word_id % 2 == 0: + + for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n') word_index += 1 diff --git a/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py b/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py index d485934..6dfaccc 100644 --- a/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py +++ b/examples/word_level/wmt_2018/en_lv/smt/microtransquest.py @@ -3,7 +3,7 @@ import shutil from sklearn.model_selection import train_test_split -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2018.en_lv.smt.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -11,21 +11,20 @@ from examples.word_level.wmt_2018.en_lv.smt.microtransquest_config import TRAIN_ TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE, \ DEV_TARGET_GAPS_FILE_SUB, DEV_TARGET_TAGS_FILE_SUB, DEV_SOURCE_TAGS_FILE_SUB, DEV_PATH, DEV_SOURCE_FILE, \ DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -40,27 +39,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) @@ -96,8 +88,8 @@ for sentence_id in range(len(test_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) target_predictions.append(majority_prediction) -test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist() -test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist() +test_source_sentences = raw_test_df["source"].tolist() +test_target_sentences = raw_test_df["target"].tolist() with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: for sentence_id, (test_source_sentence, source_prediction) in enumerate( @@ -110,22 +102,23 @@ with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f: - for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)): - target_sentence = test_sentence.split("[SEP]")[1] - words = target_sentence.split() + for sentence_id, (test_target_sentence, target_prediction) in enumerate(zip(test_target_sentences, target_predictions)): + # target_sentence = test_sentence.split("[SEP]")[1] + words = test_target_sentence.split() # word_predictions = target_prediction.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction) in enumerate(zip(words, target_prediction)): - if word_id % 2 == 0: + + for prediction_id, prediction in enumerate(target_prediction): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + '\n') + + "gap" + "\t" + prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + '\n') + + words[word_index] + "\t" + prediction + '\n') word_index += 1 # Predictions for dev file @@ -161,10 +154,10 @@ for sentence_id in range(len(dev_sentences)): majority_prediction.append(max(set(word_prediction), key=word_prediction.count)) dev_target_predictions.append(majority_prediction) -dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist() -dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist() -dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist() -dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist() +dev_source_sentences = raw_dev_df["source"].tolist() +dev_target_sentences = raw_dev_df["target"].tolist() +dev_source_gold_tags = raw_dev_df["source_tags"].tolist() +dev_target_gold_tags = raw_dev_df["target_tags"].tolist() with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate( @@ -180,23 +173,21 @@ with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f: with open(os.path.join(TEMP_DIRECTORY, DEV_TARGET_TAGS_FILE_SUB), 'w') as target_f, open( os.path.join(TEMP_DIRECTORY, DEV_TARGET_GAPS_FILE_SUB), 'w') as gap_f: for sentence_id, (dev_sentence, dev_target_prediction, dev_target_gold_tag) in enumerate( - zip(dev_sentences, dev_target_predictions, dev_target_gold_tags)): - dev_target_sentence = dev_sentence.split("[SEP]")[1] - words = dev_target_sentence.split() - gold_predictions = source_gold_tag.split() - # word_predictions = target_prediction.split() + zip(dev_target_sentences, dev_target_predictions, dev_target_gold_tags)): + words = dev_sentence.split() + gold_predictions = dev_target_gold_tag.split() gap_index = 0 word_index = 0 - for word_id, (word, word_prediction, gold_prediction) in enumerate( - zip(words, target_prediction, gold_predictions)): - if word_id % 2 == 0: + + for prediction_id, (prediction, gold_prediction) in enumerate(zip(dev_target_prediction, gold_predictions)): + if prediction_id % 2 == 0: gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" + str(sentence_id) + "\t" + str(gap_index) + "\t" - + "gap" + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + "gap" + "\t" + prediction + "\t" + gold_prediction + '\n') gap_index += 1 else: target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" + str(sentence_id) + "\t" + str(word_index) + "\t" - + word + "\t" + word_prediction + "\t" + gold_prediction + '\n') + + words[word_index] + "\t" + prediction + "\t" + gold_prediction + '\n') word_index += 1 diff --git a/examples/word_level/wmt_2020/en_de/microtransquest.py b/examples/word_level/wmt_2020/en_de/microtransquest.py index 4ee4184..4e29ed3 100644 --- a/examples/word_level/wmt_2020/en_de/microtransquest.py +++ b/examples/word_level/wmt_2020/en_de/microtransquest.py @@ -2,7 +2,7 @@ import shutil from sklearn.model_selection import train_test_split import os -from examples.word_level.common.util import reader +from examples.word_level.common.util import reader, prepare_testdata from examples.word_level.wmt_2020.en_de.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \ TRAIN_SOURCE_TAGS_FILE, \ TRAIN_TARGET_FILE, \ @@ -10,21 +10,20 @@ from examples.word_level.wmt_2020.en_de.microtransquest_config import TRAIN_PATH TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, TEST_TARGET_TAGS_FLE, SEED, DEV_TARGET_TAGS_FLE, \ DEV_SOURCE_TAGS_FILE, DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_TARGET_TAGS_FILE_SUB, DEV_SOURCE_TAGS_FILE_SUB from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel -from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, +raw_train_df = reader(TRAIN_PATH, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE, TRAIN_TARGET_TAGS_FLE) -raw_dev_df = reader(DEV_PATH, microtransquest_config, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, +raw_dev_df = reader(DEV_PATH, DEV_SOURCE_FILE, DEV_TARGET_FILE, DEV_SOURCE_TAGS_FILE, DEV_TARGET_TAGS_FLE) -raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE) +raw_test_df = reader(TEST_PATH, TEST_SOURCE_FILE, TEST_TARGET_FILE) -test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config) -dev_sentences = prepare_testdata(raw_dev_df, args=microtransquest_config) +test_sentences = prepare_testdata(raw_test_df) +dev_sentences = prepare_testdata(raw_dev_df) fold_sources_tags = [] fold_targets_tags = [] @@ -39,27 +38,20 @@ for i in range(microtransquest_config["n_fold"]): if microtransquest_config["evaluate_during_training"]: raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i) - train_df = prepare_data(raw_train, args=microtransquest_config) - eval_df = prepare_data(raw_eval, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df, eval_df=eval_df) - model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags, + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train, eval_data=raw_eval) + model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"], args=microtransquest_config) else: - train_df = prepare_data(raw_train_df, args=microtransquest_config) - tags = train_df['labels'].unique().tolist() - model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config) - model.train_model(train_df) + model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config) + model.train_model(raw_train_df) - predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True) - sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config) + sources_tags, targets_tags = model.predict(test_sentences, split_on_space=True) fold_sources_tags.append(sources_tags) fold_targets_tags.append(targets_tags) - dev_predicted_labels, dev_raw_predictions = model.predict(dev_sentences, split_on_space=True) - dev_sources_tags, dev_targets_tags = post_process(dev_predicted_labels, dev_sentences, args=microtransquest_config) + dev_sources_tags, dev_targets_tags = model.predict(dev_sentences, split_on_space=True) dev_fold_sources_tags.append(dev_sources_tags) dev_fold_targets_tags.append(dev_targets_tags) |