Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-25 21:36:58 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-25 21:36:58 +0300
commit1fbabb9b0e3531812b7d7d69cc66837dad5d9039 (patch)
tree85aefa9f95e792699f4c981cf1d928152d023d24
parentd24e8435c1166905bb355470823baa69e16f115c (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py11
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py4
2 files changed, 11 insertions, 4 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
index 7834e35..25e1995 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
@@ -27,9 +27,16 @@ logger = logging.getLogger(__name__)
class SiameseTransformer(nn.Sequential):
- def __init__(self, model_name: str = None, max_seq_length: int = 100, device: str = None):
+ def __init__(self, model_name: str = None, args=None, device: str = None):
- transformer_model = Transformer(model_name, max_seq_length=max_seq_length)
+ self.args = self.load_model_args(model_name)
+
+ if isinstance(args, dict):
+ self.args.update_from_dict(args)
+ elif isinstance(args, SiameseTransQuestArgs):
+ self.args = args
+
+ transformer_model = Transformer(model_name, max_seq_length=self.args.max_seq_length)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index c8d120e..e465e7d 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -31,9 +31,9 @@ class SiameseTransQuestModel:
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
"""
- def __init__(self, model_name: str = None, args=None, device: str = None):
+ def __init__(self, model_name: str = None, args=None):
- self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length)
+ self.model = SiameseTransformer(model_name, args=args)
self.args = self.model.load_model_args(model_name)
if isinstance(args, dict):