Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-25 21:23:05 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-25 21:23:05 +0300
commitd24e8435c1166905bb355470823baa69e16f115c (patch)
treee468bb971b0db3427764cdc49caad170ba4201fd
parentfc1830aedeb9849cbe086fcaa707a67cebc8850d (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py10
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py12
2 files changed, 12 insertions, 10 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
index f82e337..7834e35 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
@@ -18,6 +18,7 @@ from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator
+from transquest.algo.sentence_level.siamesetransquest.model_args import SiameseTransQuestArgs
from transquest.algo.sentence_level.siamesetransquest.models import Transformer, Pooling
from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device
@@ -577,6 +578,15 @@ class SiameseTransformer(nn.Sequential):
first_tuple = next(gen)
return first_tuple[1].device
+ def save_model_args(self, output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ self.args.save(output_dir)
+
+ def load_model_args(self, input_dir):
+ args = SiameseTransQuestArgs()
+ args.load(input_dir)
+ return args
+
@property
def tokenizer(self):
"""
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index d9351be..c8d120e 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -33,7 +33,8 @@ class SiameseTransQuestModel:
def __init__(self, model_name: str = None, args=None, device: str = None):
- self.args = self._load_model_args(model_name)
+ self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length)
+ self.args = self.model.load_model_args(model_name)
if isinstance(args, dict):
self.args.update_from_dict(args)
@@ -50,8 +51,6 @@ class SiameseTransQuestModel:
if self.args.n_gpu > 0:
torch.cuda.manual_seed_all(self.args.manual_seed)
- self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length)
-
def predict(self, to_predict, verbose=True):
sentences1 = []
sentences2 = []
@@ -100,13 +99,6 @@ class SiameseTransQuestModel:
max_grad_norm=self.args.max_grad_norm,
output_path=self.args.best_model_dir)
- def save_model_args(self, output_dir):
- os.makedirs(output_dir, exist_ok=True)
- self.args.save(output_dir)
- def _load_model_args(self, input_dir):
- args = SiameseTransQuestArgs()
- args.load(input_dir)
- return args