Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-08-31 23:39:03 +0300
committerJohn Bauer <horatio@gmail.com>2022-08-31 23:39:03 +0300
commit96c61060fbe04f72d359dd47db3090b56eaf4d55 (patch)
tree47254743c837cbfaf43b7561671bed4c1ae3e697
parentcb413f8459647d329220e68f86ed648944b9acb3 (diff)
Don't save optimizers for the non-checkpoints (and fix a save bug for the end of epoch save)
-rw-r--r--stanza/models/classifier.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index c3c6c83d..9ad26c43 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -477,7 +477,7 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step)
if trainer.best_score is None or dev_score > trainer.best_score:
trainer.best_score = dev_score
- trainer.save(model_file)
+ trainer.save(model_file, save_optimizer=False)
logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1))
model.train()
epoch_loss += running_loss
@@ -493,10 +493,10 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)
if args.save_intermediate_models:
intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
- trainer.save(intermediate_file)
+ trainer.save(intermediate_file, save_optimizer=False)
if trainer.best_score is None or dev_score > trainer.best_score:
trainer.best_score = dev_score
- trainer.save(model_file, trainer)
+ trainer.save(model_file, save_optimizer=False)
logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1))
if args.wandb: