Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: TharinduDR <rhtdranasinghe@gmail.com> 2021-04-23 21:54:02 +0300
committer: TharinduDR <rhtdranasinghe@gmail.com> 2021-04-23 21:54:02 +0300
commit: 35d35371baa27ecbf6e4b04fe36afe4a4a5eac7c (patch)
tree: 74230bca682a612b15e49b1c2ef457bbb2b0e035
parent: d877045d976ca86e3640686aaf72fd0b3eda4eb8 (diff)
057: Code cleaning
-rwxr-xr-x examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py | 2
-rw-r--r-- transquest/algo/sentence_level/siamesetransquest/run_model.py | 4
2 files changed, 4 insertions, 2 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index fbdbb95..4c5f5dc 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -61,7 +61,7 @@ if siamesetransquest_config["evaluate_during_training"]:
shutil.rmtree(siamesetransquest_config['cache_dir'])
train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
- model = SiameseTransQuestModel(MODEL_NAME)
+ model = SiameseTransQuestModel(MODEL_NAME, args=siamesetransquest_config)
model.train_model(train_df, eval_df)
model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir'])
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 847afc2..511a3b6 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -57,7 +57,7 @@ class SiameseTransQuestModel(nn.Sequential):
if self.args.n_gpu > 0:
torch.cuda.manual_seed_all(self.args.manual_seed)
- transformer_model = Transformer(model_name, max_seq_length=80)
+ transformer_model = Transformer(model_name, max_seq_length=args.max_seq_length)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
@@ -333,6 +333,8 @@ class SiameseTransQuestModel(nn.Sequential):
with open(os.path.join(path, 'modules.json'), 'w') as fOut:
json.dump(contained_modules, fOut, indent=2)
+ self.save_model_args(path)
+
def smart_batching_collate(self, batch):
"""
Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model