github.com/TharinduDR/TransQuest.git
author    TharinduDR <rhtdranasinghe@gmail.com>  2021-04-23 18:30:43 +0300
committer TharinduDR <rhtdranasinghe@gmail.com>  2021-04-23 18:30:43 +0300
commit    2df58efe31f28973a1113755a574fb73a33d3bdd (patch)
tree      fa3b387eb79a2918829b7ee492b77b7980a3e73c
parent    30ccc5ddedcaab2d2bc08457eef093a471e0197b (diff)
058: Code Refactoring
-rw-r--r--  examples/sentence_level/wmt_2018/common/util/download.py | 6
-rw-r--r--  examples/sentence_level/wmt_2018/de_en/monotransquest.py | 4
-rw-r--r--  examples/sentence_level/wmt_2018/de_en/siamesetransquest.py | 91
-rw-r--r--  examples/sentence_level/wmt_2018/en_cs/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py | 4
-rw-r--r--  examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2018/multilingual/monotransquest.py | 4
-rw-r--r--  examples/sentence_level/wmt_2019/common/util/download.py | 6
-rw-r--r--  examples/sentence_level/wmt_2019/en_de/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2019/en_ru/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2020/common/util/download.py | 7
-rw-r--r--  examples/sentence_level/wmt_2020/et_en/monotransquest.py | 6
-rw-r--r--  examples/sentence_level/wmt_2020/multilingual/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2020/ne_en/monotransquest.py | 4
-rwxr-xr-x  examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py | 58
-rw-r--r--  examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py | 6
-rw-r--r--  examples/sentence_level/wmt_2020/ru_en/monotransquest.py | 4
-rw-r--r--  examples/sentence_level/wmt_2020/si_en/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2020_task2/common/util/download.py | 6
-rw-r--r--  examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py | 3
-rw-r--r--  examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py | 4
-rw-r--r--  transquest/algo/sentence_level/monotransquest/model_args.py | 121
-rw-r--r--  transquest/algo/sentence_level/siamesetransquest/model_args.py | 27
-rw-r--r--  transquest/algo/sentence_level/siamesetransquest/run_model.py | 163
-rw-r--r--  transquest/algo/word_level/microtransquest/model_args.py | 2
-rw-r--r--  transquest/model_args.py | 120
28 files changed, 239 insertions(+), 431 deletions(-)
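
What the refactoring amounts to, in brief: the example scripts below stop wiring up SentencesDataset, DataLoader, loss and evaluator objects by hand (and stop round-tripping predictions through TSV result files); SiameseTransQuestModel now exposes train_model() and predict() directly, mirroring MonoTransQuestModel. A minimal sketch of the new usage, pieced together from the example scripts in this commit — the model name and the toy data frames are illustrative placeholders, not repository contents:

import pandas as pd

from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel

# Toy data in the column layout the refactored examples use:
# 'text_a' = source sentence, 'text_b' = translation, 'labels' = quality score.
train_df = pd.DataFrame({
    "text_a": ["Das ist ein Test.", "Guten Morgen."],
    "text_b": ["This is a test.", "Good morning."],
    "labels": [0.10, 0.05],
})
eval_df = train_df.copy()

# train_model() now builds the dataloader, cosine-similarity loss and
# evaluator internally and saves the best checkpoint to args.best_model_dir.
model = SiameseTransQuestModel("xlm-roberta-base")  # placeholder model name
model.train_model(train_df, eval_df)

# predict() takes plain [source, translation] pairs and returns scores,
# replacing the old write-TSV / read-result-file round trip.
model = SiameseTransQuestModel("outputs/best_model")  # default best_model_dir
pairs = list(map(list, zip(eval_df["text_a"].tolist(), eval_df["text_b"].tolist())))
scores = model.predict(pairs)
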
diff --git a/examples/sentence_level/wmt_2018/common/util/download.py b/examples/sentence_level/wmt_2018/common/util/download.py
deleted file mode 100644
index feacb5a..0000000
--- a/examples/sentence_level/wmt_2018/common/util/download.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from google_drive_downloader import GoogleDriveDownloader as gdd
-
-def download_from_google_drive(file_id, path):
- gdd.download_file_from_google_drive(file_id=file_id,
- dest_path= path + "/model.zip",
- unzip=True)
\ No newline at end of file
diff --git a/examples/sentence_level/wmt_2018/de_en/monotransquest.py b/examples/sentence_level/wmt_2018/de_en/monotransquest.py
index 8b50388..0dcc5cf 100644
--- a/examples/sentence_level/wmt_2018/de_en/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/de_en/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -21,9 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
-
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/"
TEST_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/"
diff --git a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py
index 9137a33..8e5a6ad 100644
--- a/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2018/de_en/siamesetransquest.py
@@ -1,25 +1,22 @@
-import csv
import logging
-import math
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
+
+
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2018.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2018.de_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, \
MODEL_NAME, siamesetransquest_config, SEED, RESULT_FILE, SUBMISSION_FILE, RESULT_IMAGE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
- SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
+
+
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +26,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/de_en/data/de_en/"
@@ -50,10 +45,12 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'hter': 'labels'}).dropna()
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'hter': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
-
if siamesetransquest_config["evaluate_during_training"]:
if siamesetransquest_config["n_fold"] > 0:
dev_preds = np.zeros((len(dev), siamesetransquest_config["n_fold"]))
@@ -68,75 +65,13 @@ if siamesetransquest_config["evaluate_during_training"]:
siamesetransquest_config['cache_dir']):
shutil.rmtree(siamesetransquest_config['cache_dir'])
- os.makedirs(siamesetransquest_config['cache_dir'])
-
train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
- train_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "train.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- eval_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "eval_df.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- dev.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "dev.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- test.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "test.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
-
- sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1,
- score_col_idx=2,
- normalize_scores=False, min_score=0, max_score=1, header=True)
-
- word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
- 'max_seq_length'])
-
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
- pooling_mode_mean_tokens=True,
- pooling_mode_cls_token=False,
- pooling_mode_max_tokens=False)
-
- model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model])
- train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model)
- train_dataloader = DataLoader(train_data, shuffle=True,
- batch_size=siamesetransquest_config['train_batch_size'])
- train_loss = losses.CosineSimilarityLoss(model=model)
-
- eval_data = SentencesDataset(examples=sts_reader.get_examples('eval_df.tsv'), model=model)
- eval_dataloader = DataLoader(eval_data, shuffle=False,
- batch_size=siamesetransquest_config['train_batch_size'])
- evaluator = EmbeddingSimilarityEvaluator(eval_dataloader)
-
- warmup_steps = math.ceil(
- len(train_data) * siamesetransquest_config["num_train_epochs"] / siamesetransquest_config[
- 'train_batch_size'] * 0.1)
-
- model.fit(train_objectives=[(train_dataloader, train_loss)],
- evaluator=evaluator,
- epochs=siamesetransquest_config['num_train_epochs'],
- evaluation_steps=100,
- optimizer_params={'lr': siamesetransquest_config["learning_rate"],
- 'eps': siamesetransquest_config["adam_epsilon"],
- 'correct_bias': False},
- warmup_steps=warmup_steps,
- output_path=siamesetransquest_config['best_model_dir'])
+ model = SiameseTransQuestModel(MODEL_NAME)
+ model.train_model(train_df, eval_df)
model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir'])
-
- dev_data = SentencesDataset(examples=sts_reader.get_examples("dev.tsv"), model=model)
- dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=8)
- evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
- model.evaluate(evaluator,
- result_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt"))
-
- test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model)
- test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8)
- evaluator = EmbeddingSimilarityEvaluator(test_dataloader)
- model.evaluate(evaluator,
- result_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"),
- verbose=False)
-
- with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f:
- dev_preds[:, i] = list(map(float, f.read().splitlines()))
-
- with open(os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt")) as f:
- test_preds[:, i] = list(map(float, f.read().splitlines()))
+ dev_preds[:, i] = model.predict(dev_sentence_pairs)
+ test_preds[:, i] = model.predict(test_sentence_pairs)
dev['predictions'] = dev_preds.mean(axis=1)
test['predictions'] = test_preds.mean(axis=1)
diff --git a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py
index ea2e7de..4ee9117 100644
--- a/examples/sentence_level/wmt_2018/en_cs/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_cs/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_cs/data/en_cs/"
diff --git a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py
index d2d0369..d626166 100644
--- a/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_de/nmt/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
diff --git a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
index 86e2b46..37de783 100644
--- a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
diff --git a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py
index fc75ca4..948eb30 100644
--- a/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_lv/nmt/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
-
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/"
TEST_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/"
diff --git a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py
index 13a40d5..128a01e 100644
--- a/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_lv/smt/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_lv/data/en_lv/"
diff --git a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py
index c14cded..b29ada7 100644
--- a/examples/sentence_level/wmt_2018/multilingual/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/multilingual/monotransquest.py
@@ -7,7 +7,7 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2018.common.util.download import download_from_google_drive
+
from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
@@ -20,8 +20,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
languages = {
# "DE-EN": ["examples/sentence_level/wmt_2018/de_en/data/de_en/",
diff --git a/examples/sentence_level/wmt_2019/common/util/download.py b/examples/sentence_level/wmt_2019/common/util/download.py
deleted file mode 100644
index feacb5a..0000000
--- a/examples/sentence_level/wmt_2019/common/util/download.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from google_drive_downloader import GoogleDriveDownloader as gdd
-
-def download_from_google_drive(file_id, path):
- gdd.download_file_from_google_drive(file_id=file_id,
- dest_path= path + "/model.zip",
- unzip=True)
\ No newline at end of file
diff --git a/examples/sentence_level/wmt_2019/en_de/monotransquest.py b/examples/sentence_level/wmt_2019/en_de/monotransquest.py
index 1261a42..ccdc5e6 100644
--- a/examples/sentence_level/wmt_2019/en_de/monotransquest.py
+++ b/examples/sentence_level/wmt_2019/en_de/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file
@@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/train"
DEV_FOLDER = "examples/sentence_level/wmt_2019/en_de/data/en-de/dev"
diff --git a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py
index e431100..2a1eede 100644
--- a/examples/sentence_level/wmt_2019/en_ru/monotransquest.py
+++ b/examples/sentence_level/wmt_2019/en_ru/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2019.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2019.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2019.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2019.common.util.reader import read_annotated_file, read_test_file
@@ -18,8 +17,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/train"
DEV_FOLDER = "examples/sentence_level/wmt_2019/en_ru/data/en-ru/dev"
diff --git a/examples/sentence_level/wmt_2020/common/util/download.py b/examples/sentence_level/wmt_2020/common/util/download.py
deleted file mode 100644
index 9adc98f..0000000
--- a/examples/sentence_level/wmt_2020/common/util/download.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from google_drive_downloader import GoogleDriveDownloader as gdd
-
-
-def download_from_google_drive(file_id, path):
- gdd.download_file_from_google_drive(file_id=file_id,
- dest_path=path + "/model.zip",
- unzip=True)
diff --git a/examples/sentence_level/wmt_2020/et_en/monotransquest.py b/examples/sentence_level/wmt_2020/et_en/monotransquest.py
index defe3af..57e23d9 100644
--- a/examples/sentence_level/wmt_2020/et_en/monotransquest.py
+++ b/examples/sentence_level/wmt_2020/et_en/monotransquest.py
@@ -2,26 +2,22 @@ import os
import shutil
import numpy as np
-import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.et_en.monotransquest_config import TEMP_DIRECTORY, MODEL_TYPE, MODEL_NAME, monotransquest_config, SEED, \
- RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE, GOOGLE_DRIVE, DRIVE_FILE_ID
+ RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr
from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/train.eten.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/dev.eten.df.short.tsv"
diff --git a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py
index 0ca785d..86495a4 100644
--- a/examples/sentence_level/wmt_2020/multilingual/monotransquest.py
+++ b/examples/sentence_level/wmt_2020/multilingual/monotransquest.py
@@ -7,7 +7,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
@@ -20,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
languages = {
"EN-DE": ["examples/sentence_level/wmt_2020/en_de/data/en-de/train.ende.df.short.tsv",
diff --git a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py
index 6306928..8c1f929 100644
--- a/examples/sentence_level/wmt_2020/ne_en/monotransquest.py
+++ b/examples/sentence_level/wmt_2020/ne_en/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
@@ -19,8 +18,7 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
+
TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv"
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index 480f75b..eadb2c0 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -8,7 +8,6 @@ import time
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
@@ -17,13 +16,8 @@ from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.ro_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import models
-from transquest.algo.sentence_level.siamesetransquest.evaluation.embedding_similarity_evaluator import \
- EmbeddingSimilarityEvaluator
from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
-from transquest.algo.sentence_level.siamesetransquest.losses.cosine_similarity_loss import CosineSimilarityLoss
-from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample
from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
@@ -76,61 +70,11 @@ if siamesetransquest_config["evaluate_during_training"]:
siamesetransquest_config['cache_dir']):
shutil.rmtree(siamesetransquest_config['cache_dir'])
- os.makedirs(siamesetransquest_config['cache_dir'])
-
train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
-
- # word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
- # 'max_seq_length'])
- #
- # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
- # pooling_mode_mean_tokens=True,
- # pooling_mode_cls_token=False,
- # pooling_mode_max_tokens=False)
-
model = SiameseTransQuestModel(MODEL_NAME)
-
- train_samples = []
- eval_samples = []
- dev_samples = []
- test_samples = []
-
- for index, row in train_df.iterrows():
- score = float(row["labels"])
- inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
- train_samples.append(inp_example)
-
- for index, row in eval_df.iterrows():
- score = float(row["labels"])
- inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
- eval_samples.append(inp_example)
-
- train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=siamesetransquest_config['train_batch_size'])
- train_loss = CosineSimilarityLoss(model=model)
-
- evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval')
- warmup_steps = math.ceil(len(train_dataloader) * siamesetransquest_config["num_train_epochs"] * 0.1)
-
- model.fit(train_objectives=[(train_dataloader, train_loss)],
- evaluator=evaluator,
- epochs=siamesetransquest_config['num_train_epochs'],
- evaluation_steps=100,
- optimizer_params={'lr': siamesetransquest_config["learning_rate"],
- 'eps': siamesetransquest_config["adam_epsilon"],
- 'correct_bias': False},
- warmup_steps=warmup_steps,
- output_path=siamesetransquest_config['best_model_dir'])
+ model.train_model(train_df, eval_df)
model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir'])
-
- for index, row in dev.iterrows():
- score = float(row["labels"])
- inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
- dev_samples.append(inp_example)
-
- evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples)
- model.evaluate(evaluator,
- output_path=siamesetransquest_config['cache_dir'])
dev_preds[:, i] = model.predict(dev_sentence_pairs)
test_preds[:, i] = model.predict(test_sentence_pairs)
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py
index 3f02ade..8d346b3 100644
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest_config.py
@@ -30,13 +30,13 @@ siamesetransquest_config = {
'max_grad_norm': 1.0,
'do_lower_case': False,
- 'logging_steps': 300,
- 'save_steps': 300,
+ 'logging_steps': 100,
+ 'save_steps': 100,
"no_cache": False,
'save_model_every_epoch': True,
'n_fold': 1,
'evaluate_during_training': True,
- 'evaluate_during_training_steps': 300,
+ 'evaluate_during_training_steps': 100,
"evaluate_during_training_verbose": True,
'use_cached_eval_features': False,
'save_eval_checkpoints': True,
diff --git a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py
index 2aa3c1a..c655077 100644
--- a/examples/sentence_level/wmt_2020/ru_en/monotransquest.py
+++ b/examples/sentence_level/wmt_2020/ru_en/monotransquest.py
@@ -6,7 +6,7 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
+
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
@@ -19,8 +19,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/train.ruen.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/dev.ruen.df.short.tsv"
diff --git a/examples/sentence_level/wmt_2020/si_en/monotransquest.py b/examples/sentence_level/wmt_2020/si_en/monotransquest.py
index 811ac6a..03751a0 100644
--- a/examples/sentence_level/wmt_2020/si_en/monotransquest.py
+++ b/examples/sentence_level/wmt_2020/si_en/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/train.sien.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/si_en/data/si-en/dev.sien.df.short.tsv"
diff --git a/examples/sentence_level/wmt_2020_task2/common/util/download.py b/examples/sentence_level/wmt_2020_task2/common/util/download.py
deleted file mode 100644
index feacb5a..0000000
--- a/examples/sentence_level/wmt_2020_task2/common/util/download.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from google_drive_downloader import GoogleDriveDownloader as gdd
-
-def download_from_google_drive(file_id, path):
- gdd.download_file_from_google_drive(file_id=file_id,
- dest_path= path + "/model.zip",
- unzip=True)
\ No newline at end of file
diff --git a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py
index c370e06..38aa05b 100644
--- a/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py
+++ b/examples/sentence_level/wmt_2020_task2/en_de/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission
@@ -19,8 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/train"
DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_de/data/en-de/dev"
diff --git a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py
index 68afd3d..a463e31 100644
--- a/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py
+++ b/examples/sentence_level/wmt_2020_task2/en_zh/monotransquest.py
@@ -6,7 +6,6 @@ import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
-from examples.sentence_level.wmt_2020_task2.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020_task2.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020_task2.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020_task2.common.util.postprocess import format_submission
@@ -19,9 +18,6 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
-
TRAIN_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/train"
DEV_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/dev"
TEST_FOLDER = "examples/sentence_level/wmt_2020_task2/en_zh/data/en-zh/test-blind"
diff --git a/transquest/algo/sentence_level/monotransquest/model_args.py b/transquest/algo/sentence_level/monotransquest/model_args.py
index 47ccd15..658c3c7 100644
--- a/transquest/algo/sentence_level/monotransquest/model_args.py
+++ b/transquest/algo/sentence_level/monotransquest/model_args.py
@@ -1,123 +1,6 @@
-import json
-import os
-import sys
-from dataclasses import asdict, dataclass, field
-from multiprocessing import cpu_count
+from dataclasses import dataclass, field
-
-def get_default_process_count():
- process_count = cpu_count() - 2 if cpu_count() > 2 else 1
- if sys.platform == "win32":
- process_count = min(process_count, 61)
-
- return process_count
-
-
-def get_special_tokens():
- return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
-
-
-@dataclass
-class TransQuestArgs:
- adam_epsilon: float = 1e-8
- best_model_dir: str = "outputs/best_model"
- cache_dir: str = "cache_dir/"
- config: dict = field(default_factory=dict)
- cosine_schedule_num_cycles: float = 0.5
- custom_layer_parameters: list = field(default_factory=list)
- custom_parameter_groups: list = field(default_factory=list)
- dataloader_num_workers: int = 0
- do_lower_case: bool = False
- dynamic_quantize: bool = False
- early_stopping_consider_epochs: bool = False
- early_stopping_delta: float = 0
- early_stopping_metric: str = "eval_loss"
- early_stopping_metric_minimize: bool = True
- early_stopping_patience: int = 3
- encoding: str = None
- adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
- adafactor_clip_threshold: float = 1.0
- adafactor_decay_rate: float = -0.8
- adafactor_beta1: float = None
- adafactor_scale_parameter: bool = True
- adafactor_relative_step: bool = True
- adafactor_warmup_init: bool = True
- eval_batch_size: int = 8
- evaluate_during_training: bool = False
- evaluate_during_training_silent: bool = True
- evaluate_during_training_steps: int = 2000
- evaluate_during_training_verbose: bool = False
- evaluate_each_epoch: bool = True
- fp16: bool = True
- gradient_accumulation_steps: int = 1
- learning_rate: float = 4e-5
- local_rank: int = -1
- logging_steps: int = 50
- manual_seed: int = None
- max_grad_norm: float = 1.0
- max_seq_length: int = 128
- model_name: str = None
- model_type: str = None
- multiprocessing_chunksize: int = 500
- n_gpu: int = 1
- no_cache: bool = False
- no_save: bool = False
- not_saved_args: list = field(default_factory=list)
- num_train_epochs: int = 1
- optimizer: str = "AdamW"
- output_dir: str = "outputs/"
- overwrite_output_dir: bool = False
- process_count: int = field(default_factory=get_default_process_count)
- polynomial_decay_schedule_lr_end: float = 1e-7
- polynomial_decay_schedule_power: float = 1.0
- quantized_model: bool = False
- reprocess_input_data: bool = True
- save_best_model: bool = True
- save_eval_checkpoints: bool = True
- save_model_every_epoch: bool = True
- save_optimizer_and_scheduler: bool = True
- save_recent_only: bool = True
- save_steps: int = 2000
- scheduler: str = "linear_schedule_with_warmup"
- silent: bool = False
- skip_special_tokens: bool = True
- tensorboard_dir: str = None
- thread_count: int = None
- train_batch_size: int = 8
- train_custom_parameters_only: bool = False
- use_cached_eval_features: bool = False
- use_early_stopping: bool = False
- use_multiprocessing: bool = True
- wandb_kwargs: dict = field(default_factory=dict)
- wandb_project: str = None
- warmup_ratio: float = 0.06
- warmup_steps: int = 0
- weight_decay: float = 0.0
-
- def update_from_dict(self, new_values):
- if isinstance(new_values, dict):
- for key, value in new_values.items():
- setattr(self, key, value)
- else:
- raise (TypeError(f"{new_values} is not a Python dict."))
-
- def get_args_for_saving(self):
- args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args}
- return args_for_saving
-
- def save(self, output_dir):
- os.makedirs(output_dir, exist_ok=True)
- with open(os.path.join(output_dir, "model_args.json"), "w") as f:
- json.dump(self.get_args_for_saving(), f)
-
- def load(self, input_dir):
- if input_dir:
- model_args_file = os.path.join(input_dir, "model_args.json")
- if os.path.isfile(model_args_file):
- with open(model_args_file, "r") as f:
- model_args = json.load(f)
-
- self.update_from_dict(model_args)
+from transquest.model_args import TransQuestArgs
@dataclass
diff --git a/transquest/algo/sentence_level/siamesetransquest/model_args.py b/transquest/algo/sentence_level/siamesetransquest/model_args.py
new file mode 100644
index 0000000..aecce25
--- /dev/null
+++ b/transquest/algo/sentence_level/siamesetransquest/model_args.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass, field
+
+from transquest.model_args import TransQuestArgs
+
+
+@dataclass
+class SiameseTransQuestArgs(TransQuestArgs):
+ """
+ Model args for a SiameseTransQuestModel
+ """
+
+ model_class: str = "SiameseTransQuestModel"
+ labels_list: list = field(default_factory=list)
+ labels_map: dict = field(default_factory=dict)
+ lazy_delimiter: str = "\t"
+ lazy_labels_column: int = 1
+ lazy_loading: bool = False
+ lazy_loading_start_line: int = 1
+ lazy_text_a_column: bool = None
+ lazy_text_b_column: bool = None
+ lazy_text_column: int = 0
+ onnx: bool = False
+ regression: bool = True
+ sliding_window: bool = False
+ special_tokens_list: list = field(default_factory=list)
+ stride: float = 0.8
+ tie_value: int = 1
\ No newline at end of file
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index a271420..aeeb956 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -1,6 +1,7 @@
import json
import logging
import os
+import random
import shutil
from collections import OrderedDict
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable
@@ -20,12 +21,16 @@ from tqdm.autonotebook import trange
import math
import queue
-from . import __DOWNLOAD_SERVER__, models
+
from . import __version__
from transquest.algo.sentence_level.siamesetransquest.util import http_get, import_from_string, batch_to_device
from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator
from transquest.algo.sentence_level.siamesetransquest.models import Transformer, Pooling
+from .evaluation.embedding_similarity_evaluator import EmbeddingSimilarityEvaluator
+from .losses.cosine_similarity_loss import CosineSimilarityLoss
+from .model_args import SiameseTransQuestArgs
+from .readers.input_example import InputExample
logger = logging.getLogger(__name__)
@@ -35,103 +40,33 @@ class SiameseTransQuestModel(nn.Sequential):
Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.
:param model_name_or_path: If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from Huggingface models repository with that name.
- :param modules: This parameter can be used to create custom SentenceTransformer models from scratch.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
"""
- def __init__(self, model_name_or_path: str = None, device: str = None):
- save_model_to = None
+ def __init__(self, model_name: str = None, args=None, device: str = None):
+
+ self.args = self._load_model_args(model_name)
+
+ if isinstance(args, dict):
+ self.args.update_from_dict(args)
+ elif isinstance(args, SiameseTransQuestArgs):
+ self.args = args
+
+ if self.args.thread_count:
+ torch.set_num_threads(self.args.thread_count)
- transformer_model = Transformer(model_name_or_path, max_seq_length=80)
+ if self.args.manual_seed:
+ random.seed(self.args.manual_seed)
+ np.random.seed(self.args.manual_seed)
+ torch.manual_seed(self.args.manual_seed)
+ if self.args.n_gpu > 0:
+ torch.cuda.manual_seed_all(self.args.manual_seed)
+
+ transformer_model = Transformer(model_name, max_seq_length=80)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
modules = [transformer_model, pooling_model]
- # if model_name_or_path is not None and model_name_or_path != "":
- # logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))
- # model_path = model_name_or_path
- #
- # if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'):
- # logger.info("Did not find folder {}".format(model_path))
- #
- # if '\\' in model_path or model_path.count('/') > 1:
- # raise AttributeError("Path {} not found".format(model_path))
- #
- # model_path = __DOWNLOAD_SERVER__ + model_path + '.zip'
- # logger.info("Search model on server: {}".format(model_path))
- #
- # if model_path.startswith('http://') or model_path.startswith('https://'):
- # model_url = model_path
- # folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end
- #
- # cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
- # if cache_folder is None:
- # try:
- # from torch.hub import _get_torch_home
- # torch_cache_home = _get_torch_home()
- # except ImportError:
- # torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
- #
- # cache_folder = os.path.join(torch_cache_home, 'sentence_transformers')
- #
- # model_path = os.path.join(cache_folder, folder_name)
- #
- # if not os.path.exists(model_path) or not os.listdir(model_path):
- # if os.path.exists(model_path):
- # os.remove(model_path)
- #
- # model_url = model_url.rstrip("/")
- # logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))
- #
- # model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part"
- # try:
- # zip_save_path = os.path.join(model_path_tmp, 'model.zip')
- # http_get(model_url, zip_save_path)
- # with ZipFile(zip_save_path, 'r') as zip:
- # zip.extractall(model_path_tmp)
- # os.remove(zip_save_path)
- # os.rename(model_path_tmp, model_path)
- # except requests.exceptions.HTTPError as e:
- # shutil.rmtree(model_path_tmp)
- # if e.response.status_code == 429:
- # raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path))
- #
- # if e.response.status_code == 404:
- # logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url))
- # logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path))
- #
- # save_model_to = model_path
- # model_path = None
- # transformer_model = Transformer(model_name_or_path)
- # pooling_model = Pooling(transformer_model.get_word_embedding_dimension())
- # modules = [transformer_model, pooling_model]
- # else:
- # raise e
- # except Exception as e:
- # shutil.rmtree(model_path)
- # raise e
- #
- #
- # # #### Load from disk
- # if model_path is not None:
- # logger.info("Load SentenceTransformer from folder: {}".format(model_path))
- #
- # if os.path.exists(os.path.join(model_path, 'config.json')):
- # with open(os.path.join(model_path, 'config.json')) as fIn:
- # config = json.load(fIn)
- # if config['__version__'] > __version__:
- # logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__))
- #
- # with open(os.path.join(model_path, 'modules.json')) as fIn:
- # contained_modules = json.load(fIn)
- #
- # modules = OrderedDict()
- # for module_config in contained_modules:
- # module_class = import_from_string(module_config['type'])
- # module = module_class.load(os.path.join(model_path, module_config['path']))
- # modules[module_config['name']] = module
- #
- #
if modules is not None and not isinstance(modules, OrderedDict):
modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
@@ -142,10 +77,6 @@ class SiameseTransQuestModel(nn.Sequential):
self._target_device = torch.device(device)
- #We created a new model from scratch based on a Transformer model. Save the SBERT model in the cache folder
- if save_model_to is not None:
- self.save(save_model_to)
-
def encode(self, sentences: Union[str, List[str], List[int]],
batch_size: int = 32,
show_progress_bar: bool = None,
@@ -345,7 +276,6 @@ class SiameseTransQuestModel(nn.Sequential):
except queue.Empty:
break
-
def get_max_seq_length(self):
"""
Returns the maximal sequence length for input the model accepts. Longer inputs will be truncated
@@ -450,6 +380,40 @@ class SiameseTransQuestModel(nn.Sequential):
else:
return sum([len(t) for t in text]) #Sum of length of individual strings
+ def train_model(self, train_df, eval_df, args=None, output_dir=None, verbose=True):
+
+ train_samples = []
+ for index, row in train_df.iterrows():
+ score = float(row["labels"])
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ train_samples.append(inp_example)
+
+ eval_samples = []
+ for index, row in eval_df.iterrows():
+ score = float(row["labels"])
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ eval_samples.append(inp_example)
+
+ train_dataloader = DataLoader(train_samples, shuffle=True,
+ batch_size=self.args.train_batch_size)
+ train_loss = CosineSimilarityLoss(model=self)
+
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval')
+ warmup_steps = math.ceil(len(train_dataloader) * self.args.num_train_epochs * 0.1)
+
+ self.fit(train_objectives=[(train_dataloader, train_loss)],
+ evaluator=evaluator,
+ epochs=self.args.num_train_epochs,
+ evaluation_steps=self.args.evaluate_during_training_steps,
+ optimizer_params={'lr': self.args.learning_rate,
+ 'eps': self.args.adam_epsilon,
+ 'correct_bias': False},
+ warmup_steps=warmup_steps,
+ weight_decay=self.args.weight_decay,
+ max_grad_norm=self.args.max_grad_norm,
+ output_path=self.args.best_model_dir)
+
+
def fit(self,
train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
evaluator: SentenceEvaluator = None,
@@ -671,6 +635,15 @@ class SiameseTransQuestModel(nn.Sequential):
first_tuple = next(gen)
return first_tuple[1].device
+ def save_model_args(self, output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ self.args.save(output_dir)
+
+ def _load_model_args(self, input_dir):
+ args = SiameseTransQuestArgs()
+ args.load(input_dir)
+ return args
+
@property
def tokenizer(self):
"""
diff --git a/transquest/algo/word_level/microtransquest/model_args.py b/transquest/algo/word_level/microtransquest/model_args.py
index 4a5ca66..d2ef919 100644
--- a/transquest/algo/word_level/microtransquest/model_args.py
+++ b/transquest/algo/word_level/microtransquest/model_args.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
-from transquest.algo.sentence_level.monotransquest.model_args import TransQuestArgs
+from transquest.model_args import TransQuestArgs
@dataclass
diff --git a/transquest/model_args.py b/transquest/model_args.py
new file mode 100644
index 0000000..af8be51
--- /dev/null
+++ b/transquest/model_args.py
@@ -0,0 +1,120 @@
+import json
+import os
+import sys
+from dataclasses import dataclass, field, asdict
+from multiprocessing import cpu_count
+
+
+def get_default_process_count():
+ process_count = cpu_count() - 2 if cpu_count() > 2 else 1
+ if sys.platform == "win32":
+ process_count = min(process_count, 61)
+
+ return process_count
+
+
+def get_special_tokens():
+ return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
+
+
+@dataclass
+class TransQuestArgs:
+ adam_epsilon: float = 1e-8
+ best_model_dir: str = "outputs/best_model"
+ cache_dir: str = "cache_dir/"
+ config: dict = field(default_factory=dict)
+ cosine_schedule_num_cycles: float = 0.5
+ custom_layer_parameters: list = field(default_factory=list)
+ custom_parameter_groups: list = field(default_factory=list)
+ dataloader_num_workers: int = 0
+ do_lower_case: bool = False
+ dynamic_quantize: bool = False
+ early_stopping_consider_epochs: bool = False
+ early_stopping_delta: float = 0
+ early_stopping_metric: str = "eval_loss"
+ early_stopping_metric_minimize: bool = True
+ early_stopping_patience: int = 3
+ encoding: str = None
+ adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
+ adafactor_clip_threshold: float = 1.0
+ adafactor_decay_rate: float = -0.8
+ adafactor_beta1: float = None
+ adafactor_scale_parameter: bool = True
+ adafactor_relative_step: bool = True
+ adafactor_warmup_init: bool = True
+ eval_batch_size: int = 8
+ evaluate_during_training: bool = False
+ evaluate_during_training_silent: bool = True
+ evaluate_during_training_steps: int = 2000
+ evaluate_during_training_verbose: bool = False
+ evaluate_each_epoch: bool = True
+ fp16: bool = True
+ gradient_accumulation_steps: int = 1
+ learning_rate: float = 4e-5
+ local_rank: int = -1
+ logging_steps: int = 50
+ manual_seed: int = None
+ max_grad_norm: float = 1.0
+ max_seq_length: int = 128
+ model_name: str = None
+ model_type: str = None
+ multiprocessing_chunksize: int = 500
+ n_gpu: int = 1
+ no_cache: bool = False
+ no_save: bool = False
+ not_saved_args: list = field(default_factory=list)
+ num_train_epochs: int = 1
+ optimizer: str = "AdamW"
+ output_dir: str = "outputs/"
+ overwrite_output_dir: bool = False
+ process_count: int = field(default_factory=get_default_process_count)
+ polynomial_decay_schedule_lr_end: float = 1e-7
+ polynomial_decay_schedule_power: float = 1.0
+ quantized_model: bool = False
+ reprocess_input_data: bool = True
+ save_best_model: bool = True
+ save_eval_checkpoints: bool = True
+ save_model_every_epoch: bool = True
+ save_optimizer_and_scheduler: bool = True
+ save_recent_only: bool = True
+ save_steps: int = 2000
+ scheduler: str = "linear_schedule_with_warmup"
+ silent: bool = False
+ skip_special_tokens: bool = True
+ tensorboard_dir: str = None
+ thread_count: int = None
+ train_batch_size: int = 8
+ train_custom_parameters_only: bool = False
+ use_cached_eval_features: bool = False
+ use_early_stopping: bool = False
+ use_multiprocessing: bool = True
+ wandb_kwargs: dict = field(default_factory=dict)
+ wandb_project: str = None
+ warmup_ratio: float = 0.06
+ warmup_steps: int = 0
+ weight_decay: float = 0.0
+
+ def update_from_dict(self, new_values):
+ if isinstance(new_values, dict):
+ for key, value in new_values.items():
+ setattr(self, key, value)
+ else:
+ raise (TypeError(f"{new_values} is not a Python dict."))
+
+ def get_args_for_saving(self):
+ args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args}
+ return args_for_saving
+
+ def save(self, output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ with open(os.path.join(output_dir, "model_args.json"), "w") as f:
+ json.dump(self.get_args_for_saving(), f)
+
+ def load(self, input_dir):
+ if input_dir:
+ model_args_file = os.path.join(input_dir, "model_args.json")
+ if os.path.isfile(model_args_file):
+ with open(model_args_file, "r") as f:
+ model_args = json.load(f)
+
+ self.update_from_dict(model_args)
\ No newline at end of file
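
With the shared dataclass in place, the per-model args files above reduce to subclasses of TransQuestArgs. A quick round trip through the helpers it defines (update_from_dict, save, load), assuming the package is importable; the override values are illustrative:

import tempfile

from transquest.model_args import TransQuestArgs

# Override a couple of defaults from a plain dict, as the run models do
# when they receive args as a dictionary.
args = TransQuestArgs()
args.update_from_dict({"learning_rate": 2e-5, "num_train_epochs": 3})

with tempfile.TemporaryDirectory() as out_dir:
    # save() serialises everything except keys in not_saved_args to model_args.json.
    args.save(out_dir)

    # load() reads model_args.json back and merges it via update_from_dict();
    # this is what _load_model_args() relies on when a model directory is given.
    restored = TransQuestArgs()
    restored.load(out_dir)
    assert restored.learning_rate == 2e-5 and restored.num_train_epochs == 3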