Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py')
-rw-r--r-- examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py | 19
1 files changed, 11 insertions, 8 deletions
diff --git a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
index 37de783..a3390d2 100644
--- a/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
+++ b/examples/sentence_level/wmt_2018/en_de/smt/monotransquest.py
@@ -10,7 +10,7 @@ from examples.sentence_level.wmt_2018.common.util.draw import draw_scatterplot,
from examples.sentence_level.wmt_2018.common.util.normalizer import fit, un_fit
from examples.sentence_level.wmt_2018.common.util.postprocess import format_submission
from examples.sentence_level.wmt_2018.common.util.reader import read_annotated_file, read_test_file
-from examples.sentence_level.wmt_2018.en_de.smt.monotransquest_config import TEMP_DIRECTORY, GOOGLE_DRIVE, DRIVE_FILE_ID, MODEL_NAME, \
+from examples.sentence_level.wmt_2018.en_de.smt.monotransquest_config import TEMP_DIRECTORY, MODEL_NAME, \
monotransquest_config, MODEL_TYPE, SEED, RESULT_FILE, SUBMISSION_FILE, RESULT_IMAGE
from transquest.algo.sentence_level.monotransquest.evaluation import pearson_corr, spearman_corr
from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
@@ -18,13 +18,14 @@ from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQue
if not os.path.exists(TEMP_DIRECTORY):
os.makedirs(TEMP_DIRECTORY)
-
TRAIN_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
DEV_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
TEST_FOLDER = "examples/sentence_level/wmt_2018/en_de/data/en_de/"
-train = read_annotated_file(path=TRAIN_FOLDER, original_file="train.smt.src", translation_file="train.smt.mt", hter_file="train.smt.hter")
-dev = read_annotated_file(path=DEV_FOLDER, original_file="dev.smt.src", translation_file="dev.smt.mt", hter_file="dev.smt.hter")
+train = read_annotated_file(path=TRAIN_FOLDER, original_file="train.smt.src", translation_file="train.smt.mt",
+ hter_file="train.smt.hter")
+dev = read_annotated_file(path=DEV_FOLDER, original_file="dev.smt.src", translation_file="dev.smt.mt",
+ hter_file="dev.smt.hter")
test = read_test_file(path=TEST_FOLDER, original_file="test.smt.src", translation_file="test.smt.mt")
train = train[['original', 'translation', 'hter']]
@@ -47,15 +48,17 @@ if monotransquest_config["evaluate_during_training"]:
test_preds = np.zeros((len(test), monotransquest_config["n_fold"]))
for i in range(monotransquest_config["n_fold"]):
- if os.path.exists(monotransquest_config['output_dir']) and os.path.isdir(monotransquest_config['output_dir']):
+ if os.path.exists(monotransquest_config['output_dir']) and os.path.isdir(
+ monotransquest_config['output_dir']):
shutil.rmtree(monotransquest_config['output_dir'])
model = MonoTransQuestModel(MODEL_TYPE, MODEL_NAME, num_labels=1, use_cuda=torch.cuda.is_available(),
args=monotransquest_config)
- train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED*i)
+ train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
model.train_model(train_df, eval_df=eval_df, pearson_corr=pearson_corr, spearman_corr=spearman_corr,
mae=mean_absolute_error)
- model = MonoTransQuestModel(MODEL_TYPE, monotransquest_config["best_model_dir"], num_labels=1, use_cuda=torch.cuda.is_available(), args=monotransquest_config)
+ model = MonoTransQuestModel(MODEL_TYPE, monotransquest_config["best_model_dir"], num_labels=1,
+ use_cuda=torch.cuda.is_available(), args=monotransquest_config)
result, model_outputs, wrong_predictions = model.eval_model(dev, pearson_corr=pearson_corr,
spearman_corr=spearman_corr,
mae=mean_absolute_error)
@@ -97,4 +100,4 @@ test = un_fit(test, 'predictions')
dev.to_csv(os.path.join(TEMP_DIRECTORY, RESULT_FILE), header=True, sep='\t', index=False, encoding='utf-8')
draw_scatterplot(dev, 'labels', 'predictions', os.path.join(TEMP_DIRECTORY, RESULT_IMAGE), "English-German-SMT")
print_stat(dev, 'labels', 'predictions')
-format_submission(df=test, index=index, method="TransQuest", path=os.path.join(TEMP_DIRECTORY, SUBMISSION_FILE))
\ No newline at end of file
+format_submission(df=test, index=index, method="TransQuest", path=os.path.join(TEMP_DIRECTORY, SUBMISSION_FILE))