diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-24 02:04:26 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-24 02:04:26 +0300 |
commit | b583f20acc765c48ec0e08363f6707c5f170efaf (patch) | |
tree | 8ce2924b4e9f7c8f1e9e587a14b53161bb1e30f5 | |
parent | ac8e402339606301bc8bd775044360ee7ff4a9e2 (diff) |
057: Code Refactoring - Siamese Architectures
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py | 4 | ||||
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 2 |
2 files changed, 3 insertions, 3 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py index 02c1c73..f82e337 100644 --- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py +++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py @@ -26,9 +26,9 @@ logger = logging.getLogger(__name__) class SiameseTransformer(nn.Sequential): - def __init__(self, model_name: str = None, args=None, device: str = None): + def __init__(self, model_name: str = None, max_seq_length: int = 100, device: str = None): - transformer_model = Transformer(model_name, max_seq_length=args.max_seq_length) + transformer_model = Transformer(model_name, max_seq_length=max_seq_length) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True, pooling_mode_cls_token=False, pooling_mode_max_tokens=False) diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index 5d2024e..6c62242 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -50,7 +50,7 @@ class SiameseTransQuestModel: if self.args.n_gpu > 0: torch.cuda.manual_seed_all(self.args.manual_seed) - self.model = SiameseTransformer(model_name, args=args) + self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length) def predict(self, to_predict, verbose=True): sentences1 = [] |