Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-24 02:12:44 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-24 02:12:44 +0300
commitfc1830aedeb9849cbe086fcaa707a67cebc8850d (patch)
treeabd7ee9389fd04b577ba5c754337bbd2ab127c6d
parentb583f20acc765c48ec0e08363f6707c5f170efaf (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 6c62242..d9351be 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -83,7 +83,7 @@ class SiameseTransQuestModel:
train_dataloader = DataLoader(train_samples, shuffle=True,
batch_size=self.args.train_batch_size)
- train_loss = CosineSimilarityLoss(model=self)
+ train_loss = CosineSimilarityLoss(model=self.model)
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval')
warmup_steps = math.ceil(len(train_dataloader) * self.args.num_train_epochs * 0.1)