diff options
Diffstat (limited to 'examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py')
-rw-r--r-- | examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py index 37de783..a3390d2 100644 --- a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py @@ -10,7 +10,7 @@ from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission from examples.sentence_level.wmt_2018.common.util.reader import read_annotated_file, read_test_file -from examples.sentence_level.wmt_2018.en_de.smt.monotransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ +from examples.sentence_level.wmt_2018.en_de.smt.monotransquest_config import TEMP_DIRECTORY, MODEL_NAME, \ monotransquest_config, MODEL_TYPE, SEED, RESULT_FILE, SUBMISSION_FILE, RESULT_IMAGE from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel @@ -18,13 +18,14 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) - TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" TEST_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" -train = read_annotated_file(path=TRAIN_FOLDER, original_file="train.smt.src", translation_file="train.smt.mt", hter_file="train.smt.hter") -dev = read_annotated_file(path=DEV_FOLDER, original_file="dev.smt.src", translation_file="dev.smt.mt", hter_file="dev.smt.hter") +train = read_annotated_file(path=TRAIN_FOLDER, original_file="train.smt.src", translation_file="train.smt.mt", + hter_file="train.smt.hter") +dev = read_annotated_file(path=DEV_FOLDER, original_file="dev.smt.src", translation_file="dev.smt.mt", + hter_file="dev.smt.hter") test = read_test_file(path=TEST_FOLDER, original_file="test.smt.src", translation_file="test.smt.mt") train = train[['original', 'translation', 'hter']] @@ -47,15 +48,17 @@ if monotransquest_config["evaluate_during_training"]: test_preds = np.zeros((len(test), monotransquest_config["n_fold"])) for i in range(monotransquest_config["n_fold"]): - if os.path.exists(monotransquest_config['output_dir']) and os.path.isdir(monotransquest_config['output_dir']): + if os.path.exists(monotransquest_config['output_dir']) and os.path.isdir( + monotransquest_config['output_dir']): shutil.rmtree(monotransquest_config['output_dir']) model = MonoTransQuestModel(MODEL_TYPE, MODEL_NAME, num_labels=1, use_cuda=torch.cuda.is_available(), args=monotransquest_config) - train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED*i) + train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) model.train_model(train_df, eval_df=eval_df, pearson_corr=pearson_corr, spearman_corr=spearman_corr, mae=mean_absolute_error) - model = MonoTransQuestModel(MODEL_TYPE, monotransquest_config["best_model_dir"], num_labels=1, use_cuda=torch.cuda.is_available(), args=monotransquest_config) + model = MonoTransQuestModel(MODEL_TYPE, monotransquest_config["best_model_dir"], num_labels=1, + use_cuda=torch.cuda.is_available(), args=monotransquest_config) result, model_outputs, wrong_predictions = model.eval_model(dev, pearson_corr=pearson_corr, spearman_corr=spearman_corr, mae=mean_absolute_error) @@ -97,4 +100,4 @@ test = un_fit(test, 'predictions') dev.to_csv(os.path.join(TEMP_DIRECTORY, RESULT_FILE), header=True, sep='\t', index=False, encoding='utf-8') draw_scatterplot(dev, 'labels', 'predictions', os.path.join(TEMP_DIRECTORY, RESULT_IMAGE), "English-German-SMT") print_stat(dev, 'labels', 'predictions') -format_submission(df=test, index=index, method="TransQuest", path=os.path.join(TEMP_DIRECTORY, SUBMISSION_FILE))
\ No newline at end of file +format_submission(df=test, index=index, method="TransQuest", path=os.path.join(TEMP_DIRECTORY, SUBMISSION_FILE)) |