diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-20 19:59:06 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-20 19:59:06 +0300 |
commit | f8786c1da3a1c0ee2f754b9c2d409903bc93fc59 (patch) | |
tree | 682c5c517c9f2e7e1ca550aa956f5aefabec909a /examples | |
parent | 3a04a785172a5d74833a32f9cd3fee64947898b6 (diff) |
057: Code Refactoring - Siamese Architectures
Diffstat (limited to 'examples')
-rwxr-xr-x | examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py index d36bb65..763b35a 100755 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py @@ -92,15 +92,15 @@ if siamesetransquest_config["evaluate_during_training"]: sts_reader = QEDataReader(siamesetransquest_config['cache_dir'], s1_col_idx=0, s2_col_idx=1, score_col_idx=2, - normalize_scores=False, min_score=0, max_score=1, header=True) + normalize_scores=False, min_score=0, max_score=1) word_embedding_model = Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ 'max_seq_length']) pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False) + pooling_mode_mean_tokens=True, + pooling_mode_cls_token=False, + pooling_mode_max_tokens=False) model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model]) train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model) @@ -138,7 +138,7 @@ if siamesetransquest_config["evaluate_during_training"]: evaluator = EmbeddingSimilarityEvaluator(dev_dataloader) start = time.time() model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) + output_path=os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) end = time.time() print("Testing time") @@ -149,7 +149,7 @@ if siamesetransquest_config["evaluate_during_training"]: evaluator = EmbeddingSimilarityEvaluator(test_dataloader) model.evaluate(evaluator, - result_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"), + output_path=os.path.join(siamesetransquest_config['cache_dir'], "test_result.txt"), verbose=False) with open(os.path.join(siamesetransquest_config['cache_dir'], "dev_result.txt")) as f: |