
github.com/TharinduDR/TransQuest.git
author     TharinduDR <rhtdranasinghe@gmail.com>  2021-03-17 01:59:38 +0300
committer  TharinduDR <rhtdranasinghe@gmail.com>  2021-03-17 01:59:38 +0300
commit     e8d5492b91c81b8d83458e9571ce2c01d41ec4b2 (patch)
tree       b28eed971a7f47a6aff0d7f872f7d801dcdb7a10
parent     ecb7626875f2b8573bc08ee48b0100a2c9daabdd (diff)
056: Code Refactoring
 examples/word_level/wmt_2018/de_en/microtransquest.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/examples/word_level/wmt_2018/de_en/microtransquest.py b/examples/word_level/wmt_2018/de_en/microtransquest.py
index 82b30ea..1968657 100644
--- a/examples/word_level/wmt_2018/de_en/microtransquest.py
+++ b/examples/word_level/wmt_2018/de_en/microtransquest.py
@@ -40,7 +40,7 @@ for i in range(microtransquest_config["n_fold"]):
if microtransquest_config["evaluate_during_training"]:
raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
- model.train_model(raw_train, eval_df=raw_eval)
+ model.train_model(raw_train, eval_data=raw_eval)
model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
args=microtransquest_config)
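The only functional change in this hunk is the keyword passed to train_model: eval_df becomes eval_data. For reference, a minimal sketch of the per-fold training step this line sits in, assuming the import path used in TransQuest's word-level examples and using trimmed, illustrative placeholder values for the config and constants that the real example defines elsewhere:

from sklearn.model_selection import train_test_split
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel

# Illustrative placeholders; the example script defines its own values.
MODEL_TYPE = "xlmroberta"
MODEL_NAME = "xlm-roberta-large"
SEED = 777
microtransquest_config = {"n_fold": 3, "evaluate_during_training": True,
                          "best_model_dir": "temp/outputs/best_model"}
# raw_train_df is the WMT training data, loaded earlier in the example (assumed here).

for i in range(microtransquest_config["n_fold"]):
    if microtransquest_config["evaluate_during_training"]:
        # Hold out 10% of the fold's training data for evaluation during training.
        raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1,
                                               random_state=SEED * i)
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"],
                                     args=microtransquest_config)
        # After this commit the evaluation split is passed as eval_data, not eval_df.
        model.train_model(raw_train, eval_data=raw_eval)
        # Reload the best checkpoint saved during training before predicting.
        model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"],
                                     labels=["OK", "BAD"], args=microtransquest_config)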
@@ -88,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
target_predictions.append(majority_prediction)
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
for sentence_id, (test_source_sentence, source_prediction) in enumerate(
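Around this hunk the example combines the per-fold word-level predictions by majority vote before writing them out; the only change is that the source and target sentences are now read from the literal "source" and "target" columns instead of config-mapped column names. A short sketch of that majority-vote step, with an assumed all_predictions structure since the fold loop is not visible in this diff (the exact format written to TEST_SOURCE_TAGS_FILE is also not shown here, so the sketch stops at the aggregation):

# Assumed structure: all_predictions[fold][sentence_id] is the list of OK/BAD
# tags that one fold's model produced for that test sentence.
target_predictions = []
for sentence_id in range(len(test_sentences)):
    majority_prediction = []
    # zip(*...) groups, per word position, the tags that every fold assigned to it.
    per_word = zip(*[fold_preds[sentence_id] for fold_preds in all_predictions])
    for word_prediction in per_word:
        # Keep the tag most folds agreed on; ties are broken arbitrarily by max().
        majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
    target_predictions.append(majority_prediction)

# After this commit the sentences come straight from fixed column names.
test_source_sentences = raw_test_df["source"].tolist()
test_target_sentences = raw_test_df["target"].tolist()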
@@ -153,10 +153,10 @@ for sentence_id in range(len(dev_sentences)):
majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
dev_target_predictions.append(majority_prediction)
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
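The same column-name change is applied to the dev split, including the gold tag columns. One consequence worth noting: any dataframe fed into this example must now use the literal column names "source", "target", "source_tags" and "target_tags". A hypothetical sketch of building such a dataframe from line-aligned WMT-style files (the file names below are assumptions, not taken from the repository):

import pandas as pd

def read_lines(path):
    # One sentence (or one space-separated tag sequence) per line.
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]

# Hypothetical file names for the WMT 2018 de-en word-level dev data.
raw_dev_df = pd.DataFrame({
    "source": read_lines("dev.src"),
    "target": read_lines("dev.mt"),
    "source_tags": read_lines("dev.source_tags"),
    "target_tags": read_lines("dev.tags"),
})

# These lookups match the hardcoded column names introduced by this commit.
dev_source_sentences = raw_dev_df["source"].tolist()
dev_target_sentences = raw_dev_df["target"].tolist()
dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
dev_target_gold_tags = raw_dev_df["target_tags"].tolist()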