Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-24 02:04:26 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-24 02:04:26 +0300
commitb583f20acc765c48ec0e08363f6707c5f170efaf (patch)
tree8ce2924b4e9f7c8f1e9e587a14b53161bb1e30f5
parentac8e402339606301bc8bd775044360ee7ff4a9e2 (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py4
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py2
2 files changed, 3 insertions, 3 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
index 02c1c73..f82e337 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models/siamese_transformer.py
@@ -26,9 +26,9 @@ logger = logging.getLogger(__name__)
class SiameseTransformer(nn.Sequential):
- def __init__(self, model_name: str = None, args=None, device: str = None):
+ def __init__(self, model_name: str = None, max_seq_length: int = 100, device: str = None):
- transformer_model = Transformer(model_name, max_seq_length=args.max_seq_length)
+ transformer_model = Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 5d2024e..6c62242 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -50,7 +50,7 @@ class SiameseTransQuestModel:
if self.args.n_gpu > 0:
torch.cuda.manual_seed_all(self.args.manual_seed)
- self.model = SiameseTransformer(model_name, args=args)
+ self.model = SiameseTransformer(model_name, max_seq_length=self.args.max_seq_length)
def predict(self, to_predict, verbose=True):
sentences1 = []