Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-20 19:59:06 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-20 19:59:06 +0300
commitf8786c1da3a1c0ee2f754b9c2d409903bc93fc59 (patch)
tree682c5c517c9f2e7e1ca550aa956f5aefabec909a /examples
parent3a04a785172a5d74833a32f9cd3fee64947898b6 (diff)
057: Code Refactoring - Siamese Architectures
Diffstat (limited to 'examples')
-rwxr-xr-xexamples/sentence_level/wmt_2020/ro_en/siamesetransquest.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index d36bb65..763b35a 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -92,15 +92,15 @@ if siamesetransquest_config["evaluate_during_training"]:
sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1,
score_col_idx=2,
- normalize_scores=False, min_score=0, max_score=1, header=True)
+ normalize_scores=False, min_score=0, max_score=1)
word_embedding_model = Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
'max_seq_length'])
pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension(),
- pooling_mode_mean_tokens=True,
- pooling_mode_cls_token=False,
- pooling_mode_max_tokens=False)
+ pooling_mode_mean_tokens=True,
+ pooling_mode_cls_token=False,
+ pooling_mode_max_tokens=False)
model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model])
train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model)
@@ -138,7 +138,7 @@ if siamesetransquest_config["evaluate_during_training"]:
evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
start = time.time()
model.evaluate(evaluator,
- result_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt"))
+ output_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt"))
end = time.time()
print("Testing time")
@@ -149,7 +149,7 @@ if siamesetransquest_config["evaluate_during_training"]:
evaluator = EmbeddingSimilarityEvaluator(test_dataloader)
model.evaluate(evaluator,
- result_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"),
+ output_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"),
verbose=False)
with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f: