diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 18:30:43 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 18:30:43 +0300 |
commit | 2df58efe31f28973a1113755a574fb73a33d3bdd (patch) | |
tree | fa3b387eb79a2918829b7ee492b77b7980a3e73c /examples | |
parent | 30ccc5ddedcaab2d2bc08457eef093a471e0197b (diff) |
058: Code Refactoring
Diffstat (limited to 'examples')
23 files changed, 21 insertions, 216 deletions
diff --git a/examples/sentence_level/wmt_2018/common/util/download.py b/examples/sentence_level/wmt_2018/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2018/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2018/de_en/monotransquest.py b/examples/sentence_level/wmt_2018/de_en/monotransquest.py index 8b50388..0dcc5cf 100644 --- a/examples/sentence_level/wmt_2018/de_en/monotransquest.py +++ b/examples/sentence_level/wmt_2018/de_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -21,9 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" TEST_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" diff --git a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py index 9137a33..8e5a6ad 100644 --- a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py @@ -1,25 +1,22 @@ -import csv import logging -import math import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive + + from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission from examples.sentence_level.wmt_2018.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2018.de_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, \ MODEL_NAME, siamesetransquest_config, SEED, RESULT_FILE, SUBMISSION_FILE, RESULT_IMAGE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel + + logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +26,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" @@ -50,10 +45,12 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'ht dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'hter': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') - if siamesetransquest_config["evaluate_during_training"]: if siamesetransquest_config["n_fold"] > 0: dev_preds = np.zeros((len(dev), siamesetransquest_config["n_fold"])) @@ -68,75 +65,13 @@ if siamesetransquest_config["evaluate_during_training"]: siamesetransquest_config['cache_dir']): shutil.rmtree(siamesetransquest_config['cache_dir']) - os.makedirs(siamesetransquest_config['cache_dir']) - train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - train_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "train.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - eval_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "eval_df.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - dev.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "dev.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - test.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "test.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - - sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1, - score_col_idx=2, - normalize_scores=False, min_score=0, max_score=1, header=True) - - word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ - 'max_seq_length']) - - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False) - - model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model]) - train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model) - train_dataloader = DataLoader(train_data, shuffle=True, - batch_size=siamesetransquest_config['train_batch_size']) - train_loss = losses.CosineSimilarityLoss(model=model) - - eval_data = SentencesDataset(examples=sts_reader.get_examples('eval_df.tsv'), model=model) - eval_dataloader = DataLoader(eval_data, shuffle=False, - batch_size=siamesetransquest_config['train_batch_size']) - evaluator = EmbeddingSimilarityEvaluator(eval_dataloader) - - warmup_steps = math.ceil( - len(train_data) * siamesetransquest_config["num_train_epochs"] / siamesetransquest_config[ - 'train_batch_size'] * 0.1) - - model.fit(train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=siamesetransquest_config['num_train_epochs'], - evaluation_steps=100, - optimizer_params={'lr': siamesetransquest_config["learning_rate"], - 'eps': siamesetransquest_config["adam_epsilon"], - 'correct_bias': False}, - warmup_steps=warmup_steps, - output_path=siamesetransquest_config['best_model_dir']) + model = SiameseTransQuestModel(MODEL_NAME) + model.train_model(train_df, eval_df) model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir']) - - dev_data = SentencesDataset(examples=sts_reader.get_examples("dev.tsv"), model=model) - dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=8) - evaluator = EmbeddingSimilarityEvaluator(dev_dataloader) - model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) - - test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model) - test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8) - evaluator = EmbeddingSimilarityEvaluator(test_dataloader) - model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"), - verbose=False) - - with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f: - dev_preds[:, i] = list(map(float, f.read().splitlines())) - - with open(os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt")) as f: - test_preds[:, i] = list(map(float, f.read().splitlines())) + dev_preds[:, i] = model.predict(dev_sentence_pairs) + test_preds[:, i] = model.predict(test_sentence_pairs) dev['predictions'] = dev_preds.mean(axis=1) test['predictions'] = test_preds.mean(axis=1) diff --git a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py index ea2e7de..4ee9117 100644 --- a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/" diff --git a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py index d2d0369..d626166 100644 --- a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" diff --git a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py index 86e2b46..37de783 100644 --- a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" diff --git a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py index fc75ca4..948eb30 100644 --- a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" TEST_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" diff --git a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py index 13a40d5..128a01e 100644 --- a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" diff --git a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py index c14cded..b29ada7 100644 --- a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py +++ b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py @@ -7,7 +7,7 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -20,8 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) languages = { # "DE-EN": ["examples/sentence_level/wmt_2018/de_en/data/de_en/", diff --git a/examples/sentence_level/wmt_2019/common/util/download.py b/examples/sentence_level/wmt_2019/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2019/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2019/en_de/monotransquest.py b/examples/sentence_level/wmt_2019/en_de/monotransquest.py index 1261a42..ccdc5e6 100644 --- a/examples/sentence_level/wmt_2019/en_de/monotransquest.py +++ b/examples/sentence_level/wmt_2019/en_de/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file @@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/train" DEV_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/dev" diff --git a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py index e431100..2a1eede 100644 --- a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py +++ b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file @@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/train" DEV_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/dev" diff --git a/examples/sentence_level/wmt_2020/common/util/download.py b/examples/sentence_level/wmt_2020/common/util/download.py deleted file mode 100644 index 9adc98f..0000000 --- a/examples/sentence_level/wmt_2020/common/util/download.py +++ /dev/null @@ -1,7 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path=path + "/model.zip", - unzip=True) diff --git a/examples/sentence_level/wmt_2020/et_en/monotransquest.py b/examples/sentence_level/wmt_2020/et_en/monotransquest.py index defe3af..57e23d9 100644 --- a/examples/sentence_level/wmt_2020/et_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/et_en/monotransquest.py @@ -2,26 +2,22 @@ import os import shutil import numpy as np -import pandas as pd import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.et_en.monotransquest_config import TEMP_DIRECTORY, MODEL_TYPE, MODEL_NAME, monotransquest_config, SEED, \ - RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE, GOOGLE_DRIVE, DRIVE_FILE_ID + RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/train.eten.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/dev.eten.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py index 0ca785d..86495a4 100644 --- a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py +++ b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py @@ -7,7 +7,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -20,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) languages = { "EN-DE": ["examples/sentence_level/wmt_2020/en_de/data/en-de/train.ende.df.short.tsv", diff --git a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py index 6306928..8c1f929 100644 --- a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +18,7 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) + TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py index 480f75b..eadb2c0 100755 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py @@ -8,7 +8,6 @@ import time import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat @@ -17,13 +16,8 @@ from examples.sentence_level.wmt_2020.common.util.postprocess import format_subm from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.ro_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import models -from transquest.algo.sentence_level.siamesetransquest.evaluation.embedding_similarity_evaluator import \ - EmbeddingSimilarityEvaluator from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler -from transquest.algo.sentence_level.siamesetransquest.losses.cosine_similarity_loss import CosineSimilarityLoss -from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel @@ -76,61 +70,11 @@ if siamesetransquest_config["evaluate_during_training"]: siamesetransquest_config['cache_dir']): shutil.rmtree(siamesetransquest_config['cache_dir']) - os.makedirs(siamesetransquest_config['cache_dir']) - train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - - # word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ - # 'max_seq_length']) - # - # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), - # pooling_mode_mean_tokens=True, - # pooling_mode_cls_token=False, - # pooling_mode_max_tokens=False) - model = SiameseTransQuestModel(MODEL_NAME) - - train_samples = [] - eval_samples = [] - dev_samples = [] - test_samples = [] - - for index, row in train_df.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - train_samples.append(inp_example) - - for index, row in eval_df.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - eval_samples.append(inp_example) - - train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=siamesetransquest_config['train_batch_size']) - train_loss = CosineSimilarityLoss(model=model) - - evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval') - warmup_steps = math.ceil(len(train_dataloader) * siamesetransquest_config["num_train_epochs"] * 0.1) - - model.fit(train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=siamesetransquest_config['num_train_epochs'], - evaluation_steps=100, - optimizer_params={'lr': siamesetransquest_config["learning_rate"], - 'eps': siamesetransquest_config["adam_epsilon"], - 'correct_bias': False}, - warmup_steps=warmup_steps, - output_path=siamesetransquest_config['best_model_dir']) + model.train_model(train_df, eval_df) model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir']) - - for index, row in dev.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - dev_samples.append(inp_example) - - evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples) - model.evaluate(evaluator, - output_path=siamesetransquest_config['cache_dir']) dev_preds[:, i] = model.predict(dev_sentence_pairs) test_preds[:, i] = model.predict(test_sentence_pairs) diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py index 3f02ade..8d346b3 100644 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py @@ -30,13 +30,13 @@ siamesetransquest_config = { 'max_grad_norm': 1.0, 'do_lower_case': False, - 'logging_steps': 300, - 'save_steps': 300, + 'logging_steps': 100, + 'save_steps': 100, "no_cache": False, 'save_model_every_epoch': True, 'n_fold': 1, 'evaluate_during_training': True, - 'evaluate_during_training_steps': 300, + 'evaluate_during_training_steps': 100, "evaluate_during_training_verbose": True, 'use_cached_eval_features': False, 'save_eval_checkpoints': True, diff --git a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py index 2aa3c1a..c655077 100644 --- a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py @@ -6,7 +6,7 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/train.ruen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/dev.ruen.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/si_en/monotransquest.py b/examples/sentence_level/wmt_2020/si_en/monotransquest.py index 811ac6a..03751a0 100644 --- a/examples/sentence_level/wmt_2020/si_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/si_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/train.sien.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/dev.sien.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020_task2/common/util/download.py b/examples/sentence_level/wmt_2020_task2/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2020_task2/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py index c370e06..38aa05b 100644 --- a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py +++ b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/train" DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/dev" diff --git a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py index 68afd3d..a463e31 100644 --- a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py +++ b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission @@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/train" DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/dev" TEST_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/test-blind" |