Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-26 15:26:40 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-26 15:26:40 +0300
commit312eb1b155bb9fb3361c01609a3f92463f7dd2ed (patch)
treee796787e7eebdc412874aa4039646c66c3c53922
parentc132e32b7d423e660d21cd8b4a93561532fc9a90 (diff)
057: Code cleaning
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py2
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/models.py6
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py31
3 files changed, 13 insertions, 26 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py b/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py
index b9b6657..74985b8 100644
--- a/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py
+++ b/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py
@@ -51,7 +51,7 @@ class EmbeddingSimilarityEvaluator(SentenceEvaluator):
self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
+ logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
self.show_progress_bar = show_progress_bar
self.csv_file = "similarity_evaluation" + ("_" + name if name else '') + "_results.csv"
diff --git a/transquest/algo/sentence_level/siamesetransquest/models.py b/transquest/algo/sentence_level/siamesetransquest/models.py
index ab70622..a6a3e6c 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models.py
@@ -1,4 +1,3 @@
-from transformers import AutoModel, AutoTokenizer, AutoConfig
import json
import logging
import math
@@ -17,10 +16,11 @@ from torch import nn, Tensor, device
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
+from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator
from transquest.algo.sentence_level.siamesetransquest.model_args import SiameseTransQuestArgs
from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device
-from transquest.algo.sentence_level.siamesetransquest.evaluation.sentence_evaluator import SentenceEvaluator
logger = logging.getLogger(__name__)
@@ -557,7 +557,6 @@ class SiameseTransformer(nn.Sequential):
else:
return sum([len(t) for t in text]) # Sum of length of individual strings
-
def fit(self,
train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
evaluator: SentenceEvaluator = None,
@@ -818,4 +817,3 @@ class SiameseTransformer(nn.Sequential):
Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
"""
self._first_module().max_seq_length = value
-
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 91fccb7..246bf4b 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -1,26 +1,19 @@
import logging
import math
-import os
import random
-
import numpy as np
import torch
from sklearn.metrics.pairwise import paired_cosine_distances
-
-
from torch.utils.data import DataLoader
-
from transquest.algo.sentence_level.siamesetransquest.evaluation.embedding_similarity_evaluator import \
EmbeddingSimilarityEvaluator
from transquest.algo.sentence_level.siamesetransquest.losses.cosine_similarity_loss import CosineSimilarityLoss
from transquest.algo.sentence_level.siamesetransquest.model_args import SiameseTransQuestArgs
from transquest.algo.sentence_level.siamesetransquest.models import SiameseTransformer
-
from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample
-
logger = logging.getLogger(__name__)
@@ -89,17 +82,13 @@ class SiameseTransQuestModel:
warmup_steps = math.ceil(len(train_dataloader) * self.args.num_train_epochs * 0.1)
self.model.fit(train_objectives=[(train_dataloader, train_loss)],
- evaluator=evaluator,
- epochs=self.args.num_train_epochs,
- evaluation_steps=self.args.evaluate_during_training_steps,
- optimizer_params={'lr': self.args.learning_rate,
- 'eps': self.args.adam_epsilon,
- 'correct_bias': False},
- warmup_steps=warmup_steps,
- weight_decay=self.args.weight_decay,
- max_grad_norm=self.args.max_grad_norm,
- output_path=self.args.best_model_dir)
-
-
-
-
+ evaluator=evaluator,
+ epochs=self.args.num_train_epochs,
+ evaluation_steps=self.args.evaluate_during_training_steps,
+ optimizer_params={'lr': self.args.learning_rate,
+ 'eps': self.args.adam_epsilon,
+ 'correct_bias': False},
+ warmup_steps=warmup_steps,
+ weight_decay=self.args.weight_decay,
+ max_grad_norm=self.args.max_grad_norm,
+ output_path=self.args.best_model_dir)