diff options
author | John Bauer <horatio@gmail.com> | 2022-08-31 23:39:03 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-08-31 23:39:03 +0300 |
commit | 96c61060fbe04f72d359dd47db3090b56eaf4d55 (patch) | |
tree | 47254743c837cbfaf43b7561671bed4c1ae3e697 | |
parent | cb413f8459647d329220e68f86ed648944b9acb3 (diff) |
Don't save optimizers for the non-checkpoints (and fix a save bug for the end of epoch save)
-rw-r--r-- | stanza/models/classifier.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py index c3c6c83d..9ad26c43 100644 --- a/stanza/models/classifier.py +++ b/stanza/models/classifier.py @@ -477,7 +477,7 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step) if trainer.best_score is None or dev_score > trainer.best_score: trainer.best_score = dev_score - trainer.save(model_file) + trainer.save(model_file, save_optimizer=False) logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1)) model.train() epoch_loss += running_loss @@ -493,10 +493,10 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1) if args.save_intermediate_models: intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score) - trainer.save(intermediate_file) + trainer.save(intermediate_file, save_optimizer=False) if trainer.best_score is None or dev_score > trainer.best_score: trainer.best_score = dev_score - trainer.save(model_file, trainer) + trainer.save(model_file, save_optimizer=False) logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1)) if args.wandb: |