diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 21:54:02 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 21:54:02 +0300 |
commit | 35d35371baa27ecbf6e4b04fe36afe4a4a5eac7c (patch) | |
tree | 74230bca682a612b15e49b1c2ef457bbb2b0e035 | |
parent | d877045d976ca86e3640686aaf72fd0b3eda4eb8 (diff) |
057: Code cleaning
-rwxr-xr-x | examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py | 2 | ||||
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 4 |
2 files changed, 4 insertions, 2 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py index fbdbb95..4c5f5dc 100755 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py @@ -61,7 +61,7 @@ if siamesetransquest_config["evaluate_during_training"]: shutil.rmtree(siamesetransquest_config['cache_dir']) train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - model = SiameseTransQuestModel(MODEL_NAME) + model = SiameseTransQuestModel(MODEL_NAME, args=siamesetransquest_config) model.train_model(train_df, eval_df) model = SiameseTransQuestModel(siamesetransquest_config['best_model_dir']) diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index 847afc2..511a3b6 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -57,7 +57,7 @@ class SiameseTransQuestModel(nn.Sequential): if self.args.n_gpu > 0: torch.cuda.manual_seed_all(self.args.manual_seed) - transformer_model = Transformer(model_name, max_seq_length=80) + transformer_model = Transformer(model_name, max_seq_length=args.max_seq_length) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True, pooling_mode_cls_token=False, pooling_mode_max_tokens=False) @@ -333,6 +333,8 @@ class SiameseTransQuestModel(nn.Sequential): with open(os.path.join(path, 'modules.json'), 'w') as fOut: json.dump(contained_modules, fOut, indent=2) + self.save_model_args(path) + def smart_batching_collate(self, batch): """ Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model |