diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 18:30:43 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 18:30:43 +0300 |
commit | 2df58efe31f28973a1113755a574fb73a33d3bdd (patch) | |
tree | fa3b387eb79a2918829b7ee492b77b7980a3e73c | |
parent | 30ccc5ddedcaab2d2bc08457eef093a471e0197b (diff) |
058: Code Refactoring
28 files changed, 239 insertions, 431 deletions
diff --git a/examples/sentence_level/wmt_2018/common/util/download.py b/examples/sentence_level/wmt_2018/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2018/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2018/de_en/monotransquest.py b/examples/sentence_level/wmt_2018/de_en/monotransquest.py index 8b50388..0dcc5cf 100644 --- a/examples/sentence_level/wmt_2018/de_en/monotransquest.py +++ b/examples/sentence_level/wmt_2018/de_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -21,9 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" TEST_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" diff --git a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py index 9137a33..8e5a6ad 100644 --- a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py @@ -1,25 +1,22 @@ -import csv import logging -import math import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive + + from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission from examples.sentence_level.wmt_2018.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2018.de_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, \ MODEL_NAME, siamesetransquest_config, SEED, RESULT_FILE, SUBMISSION_FILE, RESULT_IMAGE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel + + logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +26,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/" @@ -50,10 +45,12 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'ht dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'hter': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') - if siamesetransquest_config["evaluate_during_training"]: if siamesetransquest_config["n_fold"] > 0: dev_preds = np.zeros((len(dev), siamesetransquest_config["n_fold"])) @@ -68,75 +65,13 @@ if siamesetransquest_config["evaluate_during_training"]: siamesetransquest_config['cache_dir']): shutil.rmtree(siamesetransquest_config['cache_dir']) - os.makedirs(siamesetransquest_config['cache_dir']) - train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - train_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "train.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - eval_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "eval_df.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - dev.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "dev.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - test.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "test.tsv"), header=True, sep='\t', - index=False, quoting=csv.QUOTE_NONE) - - sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1, - score_col_idx=2, - normalize_scores=False, min_score=0, max_score=1, header=True) - - word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ - 'max_seq_length']) - - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False) - - model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model]) - train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model) - train_dataloader = DataLoader(train_data, shuffle=True, - batch_size=siamesetransquest_config['train_batch_size']) - train_loss = losses.CosineSimilarityLoss(model=model) - - eval_data = SentencesDataset(examples=sts_reader.get_examples('eval_df.tsv'), model=model) - eval_dataloader = DataLoader(eval_data, shuffle=False, - batch_size=siamesetransquest_config['train_batch_size']) - evaluator = EmbeddingSimilarityEvaluator(eval_dataloader) - - warmup_steps = math.ceil( - len(train_data) * siamesetransquest_config["num_train_epochs"] / siamesetransquest_config[ - 'train_batch_size'] * 0.1) - - model.fit(train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=siamesetransquest_config['num_train_epochs'], - evaluation_steps=100, - optimizer_params={'lr': siamesetransquest_config["learning_rate"], - 'eps': siamesetransquest_config["adam_epsilon"], - 'correct_bias': False}, - warmup_steps=warmup_steps, - output_path=siamesetransquest_config['best_model_dir']) + model = SiameseTransQuestModel(MODEL_NAME) + model.train_model(train_df, eval_df) model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir']) - - dev_data = SentencesDataset(examples=sts_reader.get_examples("dev.tsv"), model=model) - dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=8) - evaluator = EmbeddingSimilarityEvaluator(dev_dataloader) - model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) - - test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model) - test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8) - evaluator = EmbeddingSimilarityEvaluator(test_dataloader) - model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"), - verbose=False) - - with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f: - dev_preds[:, i] = list(map(float, f.read().splitlines())) - - with open(os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt")) as f: - test_preds[:, i] = list(map(float, f.read().splitlines())) + dev_preds[:, i] = model.predict(dev_sentence_pairs) + test_preds[:, i] = model.predict(test_sentence_pairs) dev['predictions'] = dev_preds.mean(axis=1) test['predictions'] = test_preds.mean(axis=1) diff --git a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py index ea2e7de..4ee9117 100644 --- a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/" diff --git a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py index d2d0369..d626166 100644 --- a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" diff --git a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py index 86e2b46..37de783 100644 --- a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/" diff --git a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py index fc75ca4..948eb30 100644 --- a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" TEST_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" diff --git a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py index 13a40d5..128a01e 100644 --- a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py +++ b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/" diff --git a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py index c14cded..b29ada7 100644 --- a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py +++ b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py @@ -7,7 +7,7 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission @@ -20,8 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) languages = { # "DE-EN": ["examples/sentence_level/wmt_2018/de_en/data/de_en/", diff --git a/examples/sentence_level/wmt_2019/common/util/download.py b/examples/sentence_level/wmt_2019/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2019/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2019/en_de/monotransquest.py b/examples/sentence_level/wmt_2019/en_de/monotransquest.py index 1261a42..ccdc5e6 100644 --- a/examples/sentence_level/wmt_2019/en_de/monotransquest.py +++ b/examples/sentence_level/wmt_2019/en_de/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file @@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/train" DEV_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/dev" diff --git a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py index e431100..2a1eede 100644 --- a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py +++ b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file @@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/train" DEV_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/dev" diff --git a/examples/sentence_level/wmt_2020/common/util/download.py b/examples/sentence_level/wmt_2020/common/util/download.py deleted file mode 100644 index 9adc98f..0000000 --- a/examples/sentence_level/wmt_2020/common/util/download.py +++ /dev/null @@ -1,7 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path=path + "/model.zip", - unzip=True) diff --git a/examples/sentence_level/wmt_2020/et_en/monotransquest.py b/examples/sentence_level/wmt_2020/et_en/monotransquest.py index defe3af..57e23d9 100644 --- a/examples/sentence_level/wmt_2020/et_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/et_en/monotransquest.py @@ -2,26 +2,22 @@ import os import shutil import numpy as np -import pandas as pd import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.et_en.monotransquest_config import TEMP_DIRECTORY, MODEL_TYPE, MODEL_NAME, monotransquest_config, SEED, \ - RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE, GOOGLE_DRIVE, DRIVE_FILE_ID + RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/train.eten.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/dev.eten.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py index 0ca785d..86495a4 100644 --- a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py +++ b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py @@ -7,7 +7,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -20,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) languages = { "EN-DE": ["examples/sentence_level/wmt_2020/en_de/data/en-de/train.ende.df.short.tsv", diff --git a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py index 6306928..8c1f929 100644 --- a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +18,7 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) + TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py index 480f75b..eadb2c0 100755 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py @@ -8,7 +8,6 @@ import time import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat @@ -17,13 +16,8 @@ from examples.sentence_level.wmt_2020.common.util.postprocess import format_subm from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.ro_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import models -from transquest.algo.sentence_level.siamesetransquest.evaluation.embedding_similarity_evaluator import \ - EmbeddingSimilarityEvaluator from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler -from transquest.algo.sentence_level.siamesetransquest.losses.cosine_similarity_loss import CosineSimilarityLoss -from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel @@ -76,61 +70,11 @@ if siamesetransquest_config["evaluate_during_training"]: siamesetransquest_config['cache_dir']): shutil.rmtree(siamesetransquest_config['cache_dir']) - os.makedirs(siamesetransquest_config['cache_dir']) - train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - - # word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ - # 'max_seq_length']) - # - # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), - # pooling_mode_mean_tokens=True, - # pooling_mode_cls_token=False, - # pooling_mode_max_tokens=False) - model = SiameseTransQuestModel(MODEL_NAME) - - train_samples = [] - eval_samples = [] - dev_samples = [] - test_samples = [] - - for index, row in train_df.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - train_samples.append(inp_example) - - for index, row in eval_df.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - eval_samples.append(inp_example) - - train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=siamesetransquest_config['train_batch_size']) - train_loss = CosineSimilarityLoss(model=model) - - evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval') - warmup_steps = math.ceil(len(train_dataloader) * siamesetransquest_config["num_train_epochs"] * 0.1) - - model.fit(train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=siamesetransquest_config['num_train_epochs'], - evaluation_steps=100, - optimizer_params={'lr': siamesetransquest_config["learning_rate"], - 'eps': siamesetransquest_config["adam_epsilon"], - 'correct_bias': False}, - warmup_steps=warmup_steps, - output_path=siamesetransquest_config['best_model_dir']) + model.train_model(train_df, eval_df) model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir']) - - for index, row in dev.iterrows(): - score = float(row["labels"]) - inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) - dev_samples.append(inp_example) - - evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples) - model.evaluate(evaluator, - output_path=siamesetransquest_config['cache_dir']) dev_preds[:, i] = model.predict(dev_sentence_pairs) test_preds[:, i] = model.predict(test_sentence_pairs) diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py index 3f02ade..8d346b3 100644 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py @@ -30,13 +30,13 @@ siamesetransquest_config = { 'max_grad_norm': 1.0, 'do_lower_case': False, - 'logging_steps': 300, - 'save_steps': 300, + 'logging_steps': 100, + 'save_steps': 100, "no_cache": False, 'save_model_every_epoch': True, 'n_fold': 1, 'evaluate_during_training': True, - 'evaluate_during_training_steps': 300, + 'evaluate_during_training_steps': 100, "evaluate_during_training_verbose": True, 'use_cached_eval_features': False, 'save_eval_checkpoints': True, diff --git a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py index 2aa3c1a..c655077 100644 --- a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py @@ -6,7 +6,7 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/train.ruen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/dev.ruen.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020/si_en/monotransquest.py b/examples/sentence_level/wmt_2020/si_en/monotransquest.py index 811ac6a..03751a0 100644 --- a/examples/sentence_level/wmt_2020/si_en/monotransquest.py +++ b/examples/sentence_level/wmt_2020/si_en/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/train.sien.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/dev.sien.df.short.tsv" diff --git a/examples/sentence_level/wmt_2020_task2/common/util/download.py b/examples/sentence_level/wmt_2020_task2/common/util/download.py deleted file mode 100644 index feacb5a..0000000 --- a/examples/sentence_level/wmt_2020_task2/common/util/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from google_drive_downloader import GoogleDriveDownloader as gdd - -def download_from_google_drive(file_id, path): - gdd.download_file_from_google_drive(file_id=file_id, - dest_path= path + "/model.zip", - unzip=True)
\ No newline at end of file diff --git a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py index c370e06..38aa05b 100644 --- a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py +++ b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission @@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/train" DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/dev" diff --git a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py index 68afd3d..a463e31 100644 --- a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py +++ b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py @@ -6,7 +6,6 @@ import torch from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split -from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission @@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) - TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/train" DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/dev" TEST_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/test-blind" diff --git a/transquest/algo/sentence_level/monotransquest/model_args.py b/transquest/algo/sentence_level/monotransquest/model_args.py index 47ccd15..658c3c7 100644 --- a/transquest/algo/sentence_level/monotransquest/model_args.py +++ b/transquest/algo/sentence_level/monotransquest/model_args.py @@ -1,123 +1,6 @@ -import json -import os -import sys -from dataclasses import asdict, dataclass, field -from multiprocessing import cpu_count +from dataclasses import dataclass, field - -def get_default_process_count(): - process_count = cpu_count() - 2 if cpu_count() > 2 else 1 - if sys.platform == "win32": - process_count = min(process_count, 61) - - return process_count - - -def get_special_tokens(): - return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"] - - -@dataclass -class TransQuestArgs: - adam_epsilon: float = 1e-8 - best_model_dir: str = "outputs/best_model" - cache_dir: str = "cache_dir/" - config: dict = field(default_factory=dict) - cosine_schedule_num_cycles: float = 0.5 - custom_layer_parameters: list = field(default_factory=list) - custom_parameter_groups: list = field(default_factory=list) - dataloader_num_workers: int = 0 - do_lower_case: bool = False - dynamic_quantize: bool = False - early_stopping_consider_epochs: bool = False - early_stopping_delta: float = 0 - early_stopping_metric: str = "eval_loss" - early_stopping_metric_minimize: bool = True - early_stopping_patience: int = 3 - encoding: str = None - adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3)) - adafactor_clip_threshold: float = 1.0 - adafactor_decay_rate: float = -0.8 - adafactor_beta1: float = None - adafactor_scale_parameter: bool = True - adafactor_relative_step: bool = True - adafactor_warmup_init: bool = True - eval_batch_size: int = 8 - evaluate_during_training: bool = False - evaluate_during_training_silent: bool = True - evaluate_during_training_steps: int = 2000 - evaluate_during_training_verbose: bool = False - evaluate_each_epoch: bool = True - fp16: bool = True - gradient_accumulation_steps: int = 1 - learning_rate: float = 4e-5 - local_rank: int = -1 - logging_steps: int = 50 - manual_seed: int = None - max_grad_norm: float = 1.0 - max_seq_length: int = 128 - model_name: str = None - model_type: str = None - multiprocessing_chunksize: int = 500 - n_gpu: int = 1 - no_cache: bool = False - no_save: bool = False - not_saved_args: list = field(default_factory=list) - num_train_epochs: int = 1 - optimizer: str = "AdamW" - output_dir: str = "outputs/" - overwrite_output_dir: bool = False - process_count: int = field(default_factory=get_default_process_count) - polynomial_decay_schedule_lr_end: float = 1e-7 - polynomial_decay_schedule_power: float = 1.0 - quantized_model: bool = False - reprocess_input_data: bool = True - save_best_model: bool = True - save_eval_checkpoints: bool = True - save_model_every_epoch: bool = True - save_optimizer_and_scheduler: bool = True - save_recent_only: bool = True - save_steps: int = 2000 - scheduler: str = "linear_schedule_with_warmup" - silent: bool = False - skip_special_tokens: bool = True - tensorboard_dir: str = None - thread_count: int = None - train_batch_size: int = 8 - train_custom_parameters_only: bool = False - use_cached_eval_features: bool = False - use_early_stopping: bool = False - use_multiprocessing: bool = True - wandb_kwargs: dict = field(default_factory=dict) - wandb_project: str = None - warmup_ratio: float = 0.06 - warmup_steps: int = 0 - weight_decay: float = 0.0 - - def update_from_dict(self, new_values): - if isinstance(new_values, dict): - for key, value in new_values.items(): - setattr(self, key, value) - else: - raise (TypeError(f"{new_values} is not a Python dict.")) - - def get_args_for_saving(self): - args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args} - return args_for_saving - - def save(self, output_dir): - os.makedirs(output_dir, exist_ok=True) - with open(os.path.join(output_dir, "model_args.json"), "w") as f: - json.dump(self.get_args_for_saving(), f) - - def load(self, input_dir): - if input_dir: - model_args_file = os.path.join(input_dir, "model_args.json") - if os.path.isfile(model_args_file): - with open(model_args_file, "r") as f: - model_args = json.load(f) - - self.update_from_dict(model_args) +from transquest.model_args import TransQuestArgs @dataclass diff --git a/transquest/algo/sentence_level/siamesetransquest/model_args.py b/transquest/algo/sentence_level/siamesetransquest/model_args.py new file mode 100644 index 0000000..aecce25 --- /dev/null +++ b/transquest/algo/sentence_level/siamesetransquest/model_args.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass, field + +from transquest.model_args import TransQuestArgs + + +@dataclass +class SiameseTransQuestArgs(TransQuestArgs): + """ + Model args for a SiameseTransQuest + """ + + model_class: str = "SiameseTransQuestModel" + labels_list: list = field(default_factory=list) + labels_map: dict = field(default_factory=dict) + lazy_delimiter: str = "\t" + lazy_labels_column: int = 1 + lazy_loading: bool = False + lazy_loading_start_line: int = 1 + lazy_text_a_column: bool = None + lazy_text_b_column: bool = None + lazy_text_column: int = 0 + onnx: bool = False + regression: bool = True + sliding_window: bool = False + special_tokens_list: list = field(default_factory=list) + stride: float = 0.8 + tie_value: int = 1
\ No newline at end of file diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index a271420..aeeb956 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -1,6 +1,7 @@ import json import logging import os +import random import shutil from collections import OrderedDict from typing import List, Dict, Tuple, Iterable, Type, Union, Callable @@ -20,12 +21,16 @@ from tqdm.autonotebook import trange import math import queue -from . import __DOWNLOAD_SERVER__, models + from . import __version__ from transquest.algo.sentence_level.siamesetransquest.util import http_get, import_from_string, batch_to_device from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator from transquest.algo.sentence_level.siamesetransquest.models import Transformer, Pooling +from .evaluation.embedding_similarity_evaluator import EmbeddingSimilarityEvaluator +from .losses.cosine_similarity_loss import CosineSimilarityLoss +from .model_args import SiameseTransQuestArgs +from .readers.input_example import InputExample logger = logging.getLogger(__name__) @@ -35,103 +40,33 @@ class SiameseTransQuestModel(nn.Sequential): Loads or create a SentenceTransformer model, that can be used to map sentences / text to embeddings. :param model_name_or_path: If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from Huggingface models repository with that name. - :param modules: This parameter can be used to create custom SentenceTransformer models from scratch. :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. """ - def __init__(self, model_name_or_path: str = None, device: str = None): - save_model_to = None + def __init__(self, model_name: str = None, args=None, device: str = None): + + self.args = self._load_model_args(model_name) + + if isinstance(args, dict): + self.args.update_from_dict(args) + elif isinstance(args, SiameseTransQuestArgs): + self.args = args + + if self.args.thread_count: + torch.set_num_threads(self.args.thread_count) - transformer_model = Transformer(model_name_or_path, max_seq_length=80) + if self.args.manual_seed: + random.seed(self.args.manual_seed) + np.random.seed(self.args.manual_seed) + torch.manual_seed(self.args.manual_seed) + if self.args.n_gpu > 0: + torch.cuda.manual_seed_all(self.args.manual_seed) + + transformer_model = Transformer(model_name, max_seq_length=80) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True, pooling_mode_cls_token=False, pooling_mode_max_tokens=False) modules = [transformer_model, pooling_model] - # if model_name_or_path is not None and model_name_or_path != "": - # logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path)) - # model_path = model_name_or_path - # - # if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'): - # logger.info("Did not find folder {}".format(model_path)) - # - # if '\\' in model_path or model_path.count('/') > 1: - # raise AttributeError("Path {} not found".format(model_path)) - # - # model_path = __DOWNLOAD_SERVER__ + model_path + '.zip' - # logger.info("Search model on server: {}".format(model_path)) - # - # if model_path.startswith('http://') or model_path.startswith('https://'): - # model_url = model_path - # folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end - # - # cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME') - # if cache_folder is None: - # try: - # from torch.hub import _get_torch_home - # torch_cache_home = _get_torch_home() - # except ImportError: - # torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) - # - # cache_folder = os.path.join(torch_cache_home, 'sentence_transformers') - # - # model_path = os.path.join(cache_folder, folder_name) - # - # if not os.path.exists(model_path) or not os.listdir(model_path): - # if os.path.exists(model_path): - # os.remove(model_path) - # - # model_url = model_url.rstrip("/") - # logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path)) - # - # model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part" - # try: - # zip_save_path = os.path.join(model_path_tmp, 'model.zip') - # http_get(model_url, zip_save_path) - # with ZipFile(zip_save_path, 'r') as zip: - # zip.extractall(model_path_tmp) - # os.remove(zip_save_path) - # os.rename(model_path_tmp, model_path) - # except requests.exceptions.HTTPError as e: - # shutil.rmtree(model_path_tmp) - # if e.response.status_code == 429: - # raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path)) - # - # if e.response.status_code == 404: - # logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url)) - # logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path)) - # - # save_model_to = model_path - # model_path = None - # transformer_model = Transformer(model_name_or_path) - # pooling_model = Pooling(transformer_model.get_word_embedding_dimension()) - # modules = [transformer_model, pooling_model] - # else: - # raise e - # except Exception as e: - # shutil.rmtree(model_path) - # raise e - # - # - # # #### Load from disk - # if model_path is not None: - # logger.info("Load SentenceTransformer from folder: {}".format(model_path)) - # - # if os.path.exists(os.path.join(model_path, 'config.json')): - # with open(os.path.join(model_path, 'config.json')) as fIn: - # config = json.load(fIn) - # if config['__version__'] > __version__: - # logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__)) - # - # with open(os.path.join(model_path, 'modules.json')) as fIn: - # contained_modules = json.load(fIn) - # - # modules = OrderedDict() - # for module_config in contained_modules: - # module_class = import_from_string(module_config['type']) - # module = module_class.load(os.path.join(model_path, module_config['path'])) - # modules[module_config['name']] = module - # - # if modules is not None and not isinstance(modules, OrderedDict): modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)]) @@ -142,10 +77,6 @@ class SiameseTransQuestModel(nn.Sequential): self._target_device = torch.device(device) - #We created a new model from scratch based on a Transformer model. Save the SBERT model in the cache folder - if save_model_to is not None: - self.save(save_model_to) - def encode(self, sentences: Union[str, List[str], List[int]], batch_size: int = 32, show_progress_bar: bool = None, @@ -345,7 +276,6 @@ class SiameseTransQuestModel(nn.Sequential): except queue.Empty: break - def get_max_seq_length(self): """ Returns the maximal sequence length for input the model accepts. Longer inputs will be truncated @@ -450,6 +380,40 @@ class SiameseTransQuestModel(nn.Sequential): else: return sum([len(t) for t in text]) #Sum of length of individual strings + def train_model(self, train_df, eval_df, args=None, output_dir=None, verbose=True): + + train_samples = [] + for index, row in train_df.iterrows(): + score = float(row["labels"]) + inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) + train_samples.append(inp_example) + + eval_samples = [] + for index, row in eval_df.iterrows(): + score = float(row["labels"]) + inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score) + eval_samples.append(inp_example) + + train_dataloader = DataLoader(train_samples, shuffle=True, + batch_size=self.args.train_batch_size) + train_loss = CosineSimilarityLoss(model=self) + + evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval') + warmup_steps = math.ceil(len(train_dataloader) * self.args.num_train_epochs * 0.1) + + self.fit(train_objectives=[(train_dataloader, train_loss)], + evaluator=evaluator, + epochs=self.args.num_train_epochs, + evaluation_steps=self.args.evaluate_during_training_steps, + optimizer_params={'lr': self.args.learning_rate, + 'eps': self.args.adam_epsilon, + 'correct_bias': False}, + warmup_steps=warmup_steps, + weight_decay=self.args.weight_decay, + max_grad_norm=self.args.max_grad_norm, + output_path=self.args.best_model_dir) + + def fit(self, train_objectives: Iterable[Tuple[DataLoader, nn.Module]], evaluator: SentenceEvaluator = None, @@ -671,6 +635,15 @@ class SiameseTransQuestModel(nn.Sequential): first_tuple = next(gen) return first_tuple[1].device + def save_model_args(self, output_dir): + os.makedirs(output_dir, exist_ok=True) + self.args.save(output_dir) + + def _load_model_args(self, input_dir): + args = SiameseTransQuestArgs() + args.load(input_dir) + return args + @property def tokenizer(self): """ diff --git a/transquest/algo/word_level/microtransquest/model_args.py b/transquest/algo/word_level/microtransquest/model_args.py index 4a5ca66..d2ef919 100644 --- a/transquest/algo/word_level/microtransquest/model_args.py +++ b/transquest/algo/word_level/microtransquest/model_args.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from transquest.algo.sentence_level.monotransquest.model_args import TransQuestArgs +from transquest.model_args import TransQuestArgs @dataclass diff --git a/transquest/model_args.py b/transquest/model_args.py new file mode 100644 index 0000000..af8be51 --- /dev/null +++ b/transquest/model_args.py @@ -0,0 +1,120 @@ +import json +import os +import sys +from dataclasses import dataclass, field, asdict +from multiprocessing import cpu_count + + +def get_default_process_count(): + process_count = cpu_count() - 2 if cpu_count() > 2 else 1 + if sys.platform == "win32": + process_count = min(process_count, 61) + + return process_count + + +def get_special_tokens(): + return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"] + + +@dataclass +class TransQuestArgs: + adam_epsilon: float = 1e-8 + best_model_dir: str = "outputs/best_model" + cache_dir: str = "cache_dir/" + config: dict = field(default_factory=dict) + cosine_schedule_num_cycles: float = 0.5 + custom_layer_parameters: list = field(default_factory=list) + custom_parameter_groups: list = field(default_factory=list) + dataloader_num_workers: int = 0 + do_lower_case: bool = False + dynamic_quantize: bool = False + early_stopping_consider_epochs: bool = False + early_stopping_delta: float = 0 + early_stopping_metric: str = "eval_loss" + early_stopping_metric_minimize: bool = True + early_stopping_patience: int = 3 + encoding: str = None + adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3)) + adafactor_clip_threshold: float = 1.0 + adafactor_decay_rate: float = -0.8 + adafactor_beta1: float = None + adafactor_scale_parameter: bool = True + adafactor_relative_step: bool = True + adafactor_warmup_init: bool = True + eval_batch_size: int = 8 + evaluate_during_training: bool = False + evaluate_during_training_silent: bool = True + evaluate_during_training_steps: int = 2000 + evaluate_during_training_verbose: bool = False + evaluate_each_epoch: bool = True + fp16: bool = True + gradient_accumulation_steps: int = 1 + learning_rate: float = 4e-5 + local_rank: int = -1 + logging_steps: int = 50 + manual_seed: int = None + max_grad_norm: float = 1.0 + max_seq_length: int = 128 + model_name: str = None + model_type: str = None + multiprocessing_chunksize: int = 500 + n_gpu: int = 1 + no_cache: bool = False + no_save: bool = False + not_saved_args: list = field(default_factory=list) + num_train_epochs: int = 1 + optimizer: str = "AdamW" + output_dir: str = "outputs/" + overwrite_output_dir: bool = False + process_count: int = field(default_factory=get_default_process_count) + polynomial_decay_schedule_lr_end: float = 1e-7 + polynomial_decay_schedule_power: float = 1.0 + quantized_model: bool = False + reprocess_input_data: bool = True + save_best_model: bool = True + save_eval_checkpoints: bool = True + save_model_every_epoch: bool = True + save_optimizer_and_scheduler: bool = True + save_recent_only: bool = True + save_steps: int = 2000 + scheduler: str = "linear_schedule_with_warmup" + silent: bool = False + skip_special_tokens: bool = True + tensorboard_dir: str = None + thread_count: int = None + train_batch_size: int = 8 + train_custom_parameters_only: bool = False + use_cached_eval_features: bool = False + use_early_stopping: bool = False + use_multiprocessing: bool = True + wandb_kwargs: dict = field(default_factory=dict) + wandb_project: str = None + warmup_ratio: float = 0.06 + warmup_steps: int = 0 + weight_decay: float = 0.0 + + def update_from_dict(self, new_values): + if isinstance(new_values, dict): + for key, value in new_values.items(): + setattr(self, key, value) + else: + raise (TypeError(f"{new_values} is not a Python dict.")) + + def get_args_for_saving(self): + args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args} + return args_for_saving + + def save(self, output_dir): + os.makedirs(output_dir, exist_ok=True) + with open(os.path.join(output_dir, "model_args.json"), "w") as f: + json.dump(self.get_args_for_saving(), f) + + def load(self, input_dir): + if input_dir: + model_args_file = os.path.join(input_dir, "model_args.json") + if os.path.isfile(model_args_file): + with open(model_args_file, "r") as f: + model_args = json.load(f) + + self.update_from_dict(model_args)
\ No newline at end of file |