diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-25 21:23:05 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-25 21:23:05 +0300 |
commit | d24e8435c1166905bb355470823baa69e16f115c (patch) | |
tree | e468bb971b0db3427764cdc49caad170ba4201fd | |
parent | fc1830aedeb9849cbe086fcaa707a67cebc8850d (diff) |
057: Code Refactoring - Siamese Architectures
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py | 10 | ||||
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 12 |
2 files changed, 12 insertions, 10 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py index f82e337..7834e35 100644 --- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py +++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py @@ -18,6 +18,7 @@ from torch.utils.data import DataLoader from tqdm.autonotebook import trange from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator +from transquest.algo.sentence_level.siamesetransquest.model_args import SiameseTransQuestArgs from transquest.algo.sentence_level.siamesetransquest.models import Transformer, Pooling from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device @@ -577,6 +578,15 @@ class SiameseTransformer(nn.Sequential): first_tuple = next(gen) return first_tuple[1].device + def save_model_args(self, output_dir): + os.makedirs(output_dir, exist_ok=True) + self.args.save(output_dir) + + def load_model_args(self, input_dir): + args = SiameseTransQuestArgs() + args.load(input_dir) + return args + @property def tokenizer(self): """ diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index d9351be..c8d120e 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -33,7 +33,8 @@ class SiameseTransQuestModel: def __init__(self, model_name: str = None, args=None, device: str = None): - self.args = self._load_model_args(model_name) + self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length) + self.args = self.model.load_model_args(model_name) if isinstance(args, dict): self.args.update_from_dict(args) @@ -50,8 +51,6 @@ class SiameseTransQuestModel: if self.args.n_gpu > 0: torch.cuda.manual_seed_all(self.args.manual_seed) - self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length) - def predict(self, to_predict, verbose=True): sentences1 = [] sentences2 = [] @@ -100,13 +99,6 @@ class SiameseTransQuestModel: max_grad_norm=self.args.max_grad_norm, output_path=self.args.best_model_dir) - def save_model_args(self, output_dir): - os.makedirs(output_dir, exist_ok=True) - self.args.save(output_dir) - def _load_model_args(self, input_dir): - args = SiameseTransQuestArgs() - args.load(input_dir) - return args |