diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-17 01:59:38 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-17 01:59:38 +0300 |
commit | e8d5492b91c81b8d83458e9571ce2c01d41ec4b2 (patch) | |
tree | b28eed971a7f47a6aff0d7f872f7d801dcdb7a10 | |
parent | ecb7626875f2b8573bc08ee48b0100a2c9daabdd (diff) |
056: Code Refactoring
-rw-r--r-- | examples/word_level/wmt_2018/de_en/microtransquest.py | 14 |
1 file changed, 7 insertions, 7 deletions
```diff
diff --git a/examples/word_level/wmt_2018/de_en/microtransquest.py b/examples/word_level/wmt_2018/de_en/microtransquest.py
index 82b30ea..1968657 100644
--- a/examples/word_level/wmt_2018/de_en/microtransquest.py
+++ b/examples/word_level/wmt_2018/de_en/microtransquest.py
@@ -40,7 +40,7 @@ for i in range(microtransquest_config["n_fold"]):
     if microtransquest_config["evaluate_during_training"]:
         raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
         model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"], args=microtransquest_config)
-        model.train_model(raw_train, eval_df=raw_eval)
+        model.train_model(raw_train, eval_data=raw_eval)
         model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=["OK", "BAD"],
                                      args=microtransquest_config)
 
@@ -88,8 +88,8 @@ for sentence_id in range(len(test_sentences)):
         majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
     target_predictions.append(majority_prediction)
 
-test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
-test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()
+test_source_sentences = raw_test_df["source"].tolist()
+test_target_sentences = raw_test_df["target"].tolist()
 
 with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
     for sentence_id, (test_source_sentence, source_prediction) in enumerate(
@@ -153,10 +153,10 @@ for sentence_id in range(len(dev_sentences)):
         majority_prediction.append(max(set(word_prediction), key=word_prediction.count))
     dev_target_predictions.append(majority_prediction)
 
-dev_source_sentences = raw_dev_df[microtransquest_config["source_column"]].tolist()
-dev_target_sentences = raw_dev_df[microtransquest_config["target_column"]].tolist()
-dev_source_gold_tags = raw_dev_df[microtransquest_config["source_tags_column"]].tolist()
-dev_target_gold_tags = raw_dev_df[microtransquest_config["target_tags_column"]].tolist()
+dev_source_sentences = raw_dev_df["source"].tolist()
+dev_target_sentences = raw_dev_df["target"].tolist()
+dev_source_gold_tags = raw_dev_df["source_tags"].tolist()
+dev_target_gold_tags = raw_dev_df["target_tags"].tolist()
 
 with open(os.path.join(TEMP_DIRECTORY, DEV_SOURCE_TAGS_FILE_SUB), 'w') as f:
     for sentence_id, (dev_source_sentence, dev_source_prediction, source_gold_tag) in enumerate(
```