diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-25 21:36:58 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-25 21:36:58 +0300 |
commit | 1fbabb9b0e3531812b7d7d69cc66837dad5d9039 (patch) | |
tree | 85aefa9f95e792699f4c981cf1d928152d023d24 | |
parent | d24e8435c1166905bb355470823baa69e16f115c (diff) |
057: Code Refactoring - Siamese Architectures
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py | 11 | ||||
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 4 |
2 files changed, 11 insertions, 4 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py index 7834e35..25e1995 100644 --- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py +++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py @@ -27,9 +27,16 @@ logger = logging.getLogger(__name__) class SiameseTransformer(nn.Sequential): - def __init__(self, model_name: str = None, max_seq_length: int = 100, device: str = None): + def __init__(self, model_name: str = None, args=None, device: str = None): - transformer_model = Transformer(model_name, max_seq_length=max_seq_length) + self.args = self.load_model_args(model_name) + + if isinstance(args, dict): + self.args.update_from_dict(args) + elif isinstance(args, SiameseTransQuestArgs): + self.args = args + + transformer_model = Transformer(model_name, max_seq_length=self.args.max_seq_length) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True, pooling_mode_cls_token=False, pooling_mode_max_tokens=False) diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index c8d120e..e465e7d 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -31,9 +31,9 @@ class SiameseTransQuestModel: :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. """ - def __init__(self, model_name: str = None, args=None, device: str = None): + def __init__(self, model_name: str = None, args=None): - self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length) + self.model = SiameseTransformer(model_name, args=args) self.args = self.model.load_model_args(model_name) if isinstance(args, dict): |