Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-22 19:56:36 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-22 19:56:36 +0300
commita3fe38c57dd2426f282ef8351e66581a0a96e325 (patch)
treef64b080607331261a6416466d029c6c6cfce99d8
parent89fc006ae1985d89c147aa2d913b0e12bf1bb2d1 (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--examples/sentence_level/wmt_2020/common/util/postprocess.py6
-rwxr-xr-xexamples/sentence_level/wmt_2020/ro_en/siamesetransquest.py7
2 files changed, 1 insertions, 12 deletions
diff --git a/examples/sentence_level/wmt_2020/common/util/postprocess.py b/examples/sentence_level/wmt_2020/common/util/postprocess.py
index 5697909..6a68630 100644
--- a/examples/sentence_level/wmt_2020/common/util/postprocess.py
+++ b/examples/sentence_level/wmt_2020/common/util/postprocess.py
@@ -6,11 +6,7 @@ def format_submission(df, language_pair, method, index, path, index_type=None):
elif index_type == "Auto":
index = range(0, df.shape[0])
- predictions = df['predictions'].tolist()
-
- print(index)
- print(predictions)
-
+ predictions = df['predictions']
with open(path, 'w') as f:
for number, prediction in zip(index, predictions):
text = language_pair + "\t" + method + "\t" + str(number) + "\t" + str(prediction)
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index 861779b..74400fe 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -61,7 +61,6 @@ test_sentence_pairs = list(map(list, zip(test['text_a'].to_list(), test['text_b'
train = fit(train, 'labels')
dev = fit(dev, 'labels')
-
assert (len(test_index) == 1000)
if siamesetransquest_config["evaluate_during_training"]:
if siamesetransquest_config["n_fold"] > 0:
@@ -138,12 +137,6 @@ if siamesetransquest_config["evaluate_during_training"]:
dev['predictions'] = dev_preds.mean(axis=1)
test['predictions'] = test_preds.mean(axis=1)
-# # random_list = random.sample(range(0, 1000), 1000)
-# # newList = list(map(lambda x: x/1000, random_list))
-#
-# dev['predictions'] = newList
-# test['predictions'] = newList
-
dev = un_fit(dev, 'labels')
dev = un_fit(dev, 'predictions')
test = un_fit(test, 'predictions')