diff options
Diffstat (limited to 'examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py')
-rw-r--r-- | examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py index cd5a981..57ad114 100644 --- a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py @@ -1,25 +1,20 @@ -import csv + import logging -import math + import os import shutil import numpy as np from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader -from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file from examples.sentence_level.wmt_2020.ne_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \ siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE -from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \ - SiameseTransQuestModel -from transquest.algo.sentence_level.siamesetransquest import models, losses -from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator -from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader +from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -29,8 +24,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s', if not os.path.exists(TEMP_DIRECTORY): os.makedirs(TEMP_DIRECTORY) -if GOOGLE_DRIVE: - download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME) TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv" DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv" @@ -49,6 +42,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_ dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna() test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna() +dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list()))) +test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list()))) + train = fit(train, 'labels') dev = fit(dev, 'labels') |