
github.com/TharinduDR/TransQuest.git
author    TharinduDR <rhtdranasinghe@gmail.com>  2021-04-22 15:34:02 +0300
committer TharinduDR <rhtdranasinghe@gmail.com>  2021-04-22 15:34:02 +0300
commit    84fdd4dafa4c82af4ce30b630b02e06991693bc0 (patch)
tree      8f2259e8aa6d6c6bff4887a908e30dbdcfcbc052
parent    aab2810ce71a60c79f165a3b73e27690df2de40c (diff)
057: Code Refactoring - Siamese Architectures
-rwxr-xr-x  examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py    41
-rw-r--r--  transquest/algo/sentence_level/siamesetransquest/run_model.py  17
2 files changed, 37 insertions(+), 21 deletions(-)
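In summary, this commit stops routing predictions through the evaluator's output files (dev_result.txt / test_result.txt) in the example script and instead gives SiameseTransQuestModel a predict() method that scores sentence pairs directly by cosine similarity of their embeddings. A minimal usage sketch, assuming a trained model instance (the constructor argument and model path are illustrative, not taken from this commit):

# Hedged sketch: how the predict() method added in this commit is meant to be called.
# Loading via SiameseTransQuestModel("path/to/model") is an assumption, not shown in the diff.
from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel

model = SiameseTransQuestModel("path/to/model")  # hypothetical path
pairs = [
    ["source sentence", "its machine translation"],
    ["another source", "another translation"],
]
scores = model.predict(pairs)  # one cosine-similarity score per [text_a, text_b] pair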
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index e5bdf04..760570f 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -55,6 +55,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
@@ -125,29 +128,31 @@ if siamesetransquest_config["evaluate_during_training"]:
inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
dev_samples.append(inp_example)
- evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='dev')
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples)
model.evaluate(evaluator,
- output_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt"))
+ output_path=siamesetransquest_config['cache_dir'])
+ dev_preds[:, i] = model.predict(dev_sentence_pairs)
+ test_preds[:, i] = model.predict(test_sentence_pairs)
# test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model)
# test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8)
- for index, row in test.iterrows():
- score = random.uniform(0, 1)
- inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
- test_samples.append(inp_example)
-
- evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='dev')
-
- model.evaluate(evaluator,
- output_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"),
- verbose=False)
-
- with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f:
- dev_preds[:, i] = list(map(float, f.read().splitlines()))
-
- with open(os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt")) as f:
- test_preds[:, i] = list(map(float, f.read().splitlines()))
+ # for index, row in test.iterrows():
+ # score = random.uniform(0, 1)
+ # inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ # test_samples.append(inp_example)
+ #
+ # evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples)
+ #
+ # model.evaluate(evaluator,
+ # output_path=siamesetransquest_config['cache_dir'],
+ # verbose=False)
+
+ # with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f:
+ # dev_preds[:, i] = list(map(float, f.read().splitlines()))
+ #
+ # with open(os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt")) as f:
+ # test_preds[:, i] = list(map(float, f.read().splitlines()))
dev['predictions'] = dev_preds.mean(axis=1)
test['predictions'] = test_preds.mean(axis=1)
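The net effect on the example is that per-fold predictions now go straight into the prediction matrices and are ensembled by column mean, with no file round-trip. A minimal reconstruction of the revised loop, where n_fold and train_one_fold() are assumptions standing in for the surrounding example code (which trains one model per fold inline):

# Hedged sketch of the evaluation flow after this commit.
import numpy as np

n_fold = 3  # hypothetical fold count
dev_preds = np.zeros((len(dev), n_fold))
test_preds = np.zeros((len(test), n_fold))

for i in range(n_fold):
    model = train_one_fold(train)  # hypothetical helper for the per-fold training code
    dev_preds[:, i] = model.predict(dev_sentence_pairs)    # as in the diff above
    test_preds[:, i] = model.predict(test_sentence_pairs)

dev['predictions'] = dev_preds.mean(axis=1)    # ensemble by averaging folds
test['predictions'] = test_preds.mean(axis=1)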
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 33fdb91..26ab4b6 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -10,6 +10,7 @@ import numpy as np
from numpy import ndarray
import transformers
import torch
+from sklearn.metrics.pairwise import paired_cosine_distances
from torch import nn, Tensor, device
from torch.optim.optimizer import Optimizer
@@ -227,7 +228,20 @@ class SiameseTransQuestModel(nn.Sequential):
return all_embeddings
+ def predict(self, to_predict):
+ sentences1 = []
+ sentences2 = []
+ for text_1, text_2 in to_predict:
+ sentences1.append(text_1)
+ sentences2.append(text_2)
+
+ embeddings1 = self.encode(sentences1, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
+ embeddings2 = self.encode(sentences2, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
+
+ cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
+
+ return cosine_scores
def start_multi_process_pool(self, target_devices: List[str] = None):
"""
@@ -259,7 +273,6 @@ class SiameseTransQuestModel(nn.Sequential):
return {'input': input_queue, 'output': output_queue, 'processes': processes}
-
@staticmethod
def stop_multi_process_pool(pool):
"""
@@ -275,7 +288,6 @@ class SiameseTransQuestModel(nn.Sequential):
pool['input'].close()
pool['output'].close()
-
def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
"""
This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
@@ -616,7 +628,6 @@ class SiameseTransQuestModel(nn.Sequential):
if save_best_model:
self.save(output_path)
-
@staticmethod
def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
"""