Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-21 17:45:34 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-21 17:45:34 +0300
commit2c6864150cfda49f70125892625538568e897ce9 (patch)
tree5e5381014214650ce0311c5da7642aaaf88ff7c5 /examples
parent8b5f6ae09de4e20e24ec8fce7208fbedba5e4955 (diff)
057: Code Refactoring - Siamese Architectures
Diffstat (limited to 'examples')
-rwxr-xr-xexamples/sentence_level/wmt_2020/ro_en/siamesetransquest.py82
1 file changed, 37 insertions, 45 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index a975c1f..5799286 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -2,6 +2,7 @@ import csv
import logging
import math
import os
+import random
import shutil
import time
@@ -16,20 +17,14 @@ from examples.sentence_level.wmt_2020.common.util.postprocess import format_subm
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.ro_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-# from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-# SiameseTransQuestModel
-# from transquest.algo.sentence_level.siamesetransquest import models, losses
-# from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-# from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
-from transquest.algo.sentence_level.siamesetransquest import models
-from transquest.algo.sentence_level.siamesetransquest.datasets.sentences_dataset import SentencesDataset
+
from transquest.algo.sentence_level.siamesetransquest.evaluation.embedding_similarity_evaluator import \
EmbeddingSimilarityEvaluator
from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
from transquest.algo.sentence_level.siamesetransquest.losses.cosine_similarity_loss import CosineSimilarityLoss
from transquest.algo.sentence_level.siamesetransquest.models.pooling import Pooling
from transquest.algo.sentence_level.siamesetransquest.models.transformer import Transformer
-from transquest.algo.sentence_level.siamesetransquest.readers.qe_data_reader import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample
from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
@@ -81,18 +76,7 @@ if siamesetransquest_config["evaluate_during_training"]:
os.makedirs(siamesetransquest_config['cache_dir'])
train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
- train_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "train.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- eval_df.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "eval_df.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- dev.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "dev.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
- test.to_csv(os.path.join(siamesetransquest_config['cache_dir'], "test.tsv"), header=True, sep='\t',
- index=False, quoting=csv.QUOTE_NONE)
-
- sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1,
- score_col_idx=2,
- normalize_scores=False, min_score=0, max_score=1, header=True)
+
word_embedding_model = Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
'max_seq_length'])
@@ -103,21 +87,28 @@ if siamesetransquest_config["evaluate_during_training"]:
pooling_mode_max_tokens=False)
model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model])
- train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model)
- train_dataloader = DataLoader(train_data, shuffle=True,
- batch_size=siamesetransquest_config['train_batch_size'])
- train_loss = CosineSimilarityLoss(model=model)
- eval_data = SentencesDataset(examples=sts_reader.get_examples('eval_df.tsv'), model=model)
- eval_dataloader = DataLoader(eval_data, shuffle=False,
- batch_size=siamesetransquest_config['train_batch_size'])
- evaluator = EmbeddingSimilarityEvaluator(eval_dataloader)
+ train_samples = []
+ eval_samples = []
+ dev_samples = []
+ test_samples = []
- warmup_steps = math.ceil(
- len(train_data) * siamesetransquest_config["num_train_epochs"] / siamesetransquest_config[
- 'train_batch_size'] * 0.1)
+ for index, row in train_df.iterrows():
+ score = float(row["labels"])
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ train_samples.append(inp_example)
+
+ for index, row in eval_df.iterrows():
+ score = float(row["labels"])
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ eval_samples.append(inp_example)
+
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=siamesetransquest_config['train_batch_size'])
+ train_loss = CosineSimilarityLoss(model=model)
+
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_samples, name='eval')
+ warmup_steps = math.ceil(len(train_dataloader) * siamesetransquest_config["num_train_epochs"] * 0.1)
- start = time.time()
model.fit(train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=siamesetransquest_config['num_train_epochs'],
@@ -127,26 +118,27 @@ if siamesetransquest_config["evaluate_during_training"]:
'correct_bias': False},
warmup_steps=warmup_steps,
output_path=siamesetransquest_config['best_model_dir'])
- end = time.time()
- print("Training time")
- print(end - start)
model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir'])
- dev_data = SentencesDataset(examples=sts_reader.get_examples("dev.tsv"), model=model)
- dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=8)
- evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
- start = time.time()
+ for index, row in dev.iterrows():
+ score = float(row["labels"])
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ dev_samples.append(inp_example)
+
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='dev')
model.evaluate(evaluator,
output_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt"))
- end = time.time()
- print("Testing time")
- print(end - start)
+ # test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model)
+ # test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8)
+
+ for index, row in test.iterrows():
+ score = random.uniform(0, 1)
+ inp_example = InputExample(texts=[row['text_a'], row['text_b']], label=score)
+ test_samples.append(inp_example)
- test_data = SentencesDataset(examples=sts_reader.get_examples("test.tsv", test_file=True), model=model)
- test_dataloader = DataLoader(test_data, shuffle=False, batch_size=8)
- evaluator = EmbeddingSimilarityEvaluator(test_dataloader)
+ evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='dev')
model.evaluate(evaluator,
output_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"),