Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py')
-rw-r--r--examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py18
1 files changed, 7 insertions, 11 deletions
diff --git a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
index cd5a981..57ad114 100644
--- a/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ne_en/siamesetransquest.py
@@ -1,25 +1,20 @@
-import csv
+
import logging
-import math
+
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from examples.sentence_level.wmt_2020.common.util.download import download_from_google_drive
from examples.sentence_level.wmt_2020.common.util.draw import draw_scatterplot, print_stat
from examples.sentence_level.wmt_2020.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2020.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2020.common.util.reader import read_annotated_file, read_test_file
from examples.sentence_level.wmt_2020.ne_en.siamesetransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
siamesetransquest_config, SEED, RESULT_FILE, RESULT_IMAGE, SUBMISSION_FILE
-from transquest.algo.sentence_level.siamesetransquest import LoggingHandler, SentencesDataset, \
- SiameseTransQuestModel
-from transquest.algo.sentence_level.siamesetransquest import models, losses
-from transquest.algo.sentence_level.siamesetransquest.evaluation import EmbeddingSimilarityEvaluator
-from transquest.algo.sentence_level.siamesetransquest.readers import QEDataReader
+from transquest.algo.sentence_level.siamesetransquest.logging_handler import LoggingHandler
+from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
@@ -29,8 +24,6 @@ logging.basicConfig(format='%(asctime)s - %(message)s',
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-if GOOGLE_DRIVE:
- download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)
TRAIN_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/train.neen.df.short.tsv"
DEV_FILE = "examples/sentence_level/wmt_2020/ne_en/data/ne-en/dev.neen.df.short.tsv"
@@ -49,6 +42,9 @@ train = train.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_
dev = dev.rename(columns={'original': 'text_a', 'translation': 'text_b', 'z_mean': 'labels'}).dropna()
test = test.rename(columns={'original': 'text_a', 'translation': 'text_b'}).dropna()
+dev_sentence_pairs = list(map(list, zip(dev['text_a'].to_list(), dev['text_b'].to_list())))
+test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'].to_list())))
+
train = fit(train, 'labels')
dev = fit(dev, 'labels')