diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 19:01:34 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 19:01:34 +0300 |
commit | a1607b64499a64f93b4c3af151d8a268c08e959f (patch) | |
tree | 32ddd74473482b563f6f6ef29c52ef0e826a9e61 | |
parent | 5f1190a4dbd3cef29d52a318169b5c765c77bf23 (diff) |
057: Code Refactoring - Siamese Architectures
5 files changed, 38 insertions, 53 deletions
diff --git a/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py b/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py index d38787c..de62b47 100644 --- a/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/en_de/siamesetransquest.py @@ -1,26 +1,21 @@ -import csv + import logging -import math + import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file -from examples.sentence_level.wmt_2020.en_de.siamesetransquest_config import TEMP_DIRECTORY, DRIVE_FILE_ID, MODEL_NAME, \ - siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE, GOOGLE_DRIVE - -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from examples.sentence_level.wmt_2020.en_de.siamesetransquest_config import TEMP_DIRECTORY, MODEL_NAME, \ + siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -30,8 +25,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/en_de/data/en-de/train.ende.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/en_de/data/en-de/dev.ende.df.short.tsv" @@ -50,6 +43,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') diff --git a/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py b/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py index cde2d17..56e410c 100644 --- a/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/en_zh/siamesetransquest.py @@ -6,20 +6,16 @@ import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive + from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.en_zh.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +25,7 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) + TRAIN_FILE = "examples/sentence_level/wmt_2020/en_zh/data/en-zh/train.enzh.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/en_zh/data/en-zh/dev.enzh.df.short.tsv" @@ -49,6 +44,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') diff --git a/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py index 129cb74..2904241 100644 --- a/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/et_en/siamesetransquest.py @@ -8,18 +8,14 @@ import numpy as np from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.et_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +25,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/train.eten.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/et_en/data/et-en/dev.eten.df.short.tsv" @@ -49,6 +43,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') diff --git a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py index cd5a981..57ad114 100644 --- a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py @@ -1,25 +1,20 @@ -import csv + import logging -import math + import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.ne_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +24,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv" @@ -49,6 +42,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') diff --git a/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py index 1636db8..184afec 100644 --- a/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ru_en/siamesetransquest.py @@ -1,25 +1,22 @@ -import csv + import logging -import math + import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive + + from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.ru_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +26,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/train.ruen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ru_en/data/ru-en/dev.ruen.df.short.tsv" @@ -50,6 +45,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') |