
github.com/TharinduDR/TransQuest.git
author     TharinduDR <rhtdranasinghe@gmail.com>  2021-04-23 19:01:34 +0300
committer  TharinduDR <rhtdranasinghe@gmail.com>  2021-04-23 19:01:34 +0300
commit     a1607b64499a64f93b4c3af151d8a268c08e959f (patch)
tree       32ddd74473482b563f6f6ef29c52ef0e826a9e61
parent     5f1190a4dbd3cef29d52a318169b5c765c77bf23 (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--  examples/sentence_level/wmt_2020/en_de/siamesetransquest.py  24
-rw-r--r--  examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py  16
-rw-r--r--  examples/sentence_level/wmt_2020/et_en/siamesetransquest.py  13
-rw-r--r--  examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py  18
-rw-r--r--  examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py  20
5 files changed, 38 insertions, 53 deletions
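
For orientation, a minimal usage sketch of the refactored siamese interface that this commit moves the example scripts onto. The run_model.SiameseTransQuestModel import and the [source, translation] pair format are taken from the diff below; the checkpoint name and the predict() call follow the library's documented inference pattern and are assumptions, not part of this commit.

from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel

# [source, translation] pairs, built the same way the updated scripts build
# dev_sentence_pairs / test_sentence_pairs from their dataframes.
sentence_pairs = [
    ["This is a small test sentence.", "Dies ist ein kleiner Testsatz."],
]

# Illustrative checkpoint name; the example scripts instead pass MODEL_NAME
# from their siamesetransquest_config.
model = SiameseTransQuestModel("TransQuest/siamesetransquest-da-en_de-wiki")
scores = model.predict(sentence_pairs)  # one predicted quality score per pair
print(scores)
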
diff --git a/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py b/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py
index d38787c..de62b47 100644
--- a/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py
@@ -1,26 +1,21 @@
-import csv
+
import logging
-import math
+
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
+
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
-from examples.sentence_level.wmt_2020.en_de.siamesetransquest_config import TEMP_DIRECTORY, DRIVE_FILE_ID, MODEL_NAME, \
-    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE, GOOGLE_DRIVE
-
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-    SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from examples.sentence_level.wmt_2020.en_de.siamesetransquest_config import TEMP_DIRECTORY, MODEL_NAME, \
+    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
@@ -30,8 +25,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
-    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/en_de/data/en-de/train.ende.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/en_de/data/en-de/dev.ende.df.short.tsv"
@@ -50,6 +43,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
diff --git a/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py b/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py
index cde2d17..56e410c 100644
--- a/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py
@@ -6,20 +6,16 @@ import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
+
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.en_zh.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-    SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +25,7 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
-    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
+
TRAIN_FILE = "examples/sentence_level/wmt_2020/en_zh/data/en-zh/train.enzh.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/en_zh/data/en-zh/dev.enzh.df.short.tsv"
@@ -49,6 +44,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
diff --git a/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py
index 129cb74..2904241 100644
--- a/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py
@@ -8,18 +8,14 @@ import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.et_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-    SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +25,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
-    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/train.eten.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/dev.eten.df.short.tsv"
@@ -49,6 +43,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
diff --git a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
index cd5a981..57ad114 100644
--- a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
@@ -1,25 +1,20 @@
-import csv
+
import logging
-import math
+
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.ne_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-    SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +24,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
-    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv"
@@ -49,6 +42,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')
diff --git a/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py
index 1636db8..184afec 100644
--- a/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py
@@ -1,25 +1,22 @@
-import csv
+
import logging
-import math
+
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
+
+
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.ru_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
    siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
-    SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +26,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
-    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/train.ruen.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/dev.ruen.df.short.tsv"
@@ -50,6 +45,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')