diff options
author | John Bauer <horatio@gmail.com> | 2022-08-31 09:15:25 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-08-31 09:18:28 +0300 |
commit | 1e7ab92c8cf0f2f51f78df74472a6196b51f4637 (patch) | |
tree | a32b33738c1167601dbde41440e695209be376b5 | |
parent | e5793c9dd5359f7e8f4fe82bf318a2f8fd190f54 (diff) |
Save checkpoints with epochs_trained+1 at the end of an epoch (otherwise the epoch will not be incremented properly when reloading)
-rw-r--r-- | stanza/models/classifier.py | 2 | ||||
-rw-r--r-- | stanza/models/classifiers/trainer.py | 11 |
2 files changed, 10 insertions, 3 deletions
```diff
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 826065dc..b580e0fc 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -456,7 +456,7 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
         if args.wandb:
             wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1, 'epoch_loss': epoch_loss}, step=trainer.global_step)
         if checkpoint_file:
-            trainer.save(checkpoint_file)
+            trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)
         if args.save_intermediate_models:
             intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
             trainer.save(intermediate_file)
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index c74e7b0a..a0ce2848 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -29,7 +29,14 @@ class Trainer:
         self.epochs_trained = epochs_trained
         self.global_step = global_step
 
-    def save(self, filename, skip_modules=True, save_optimizer=True):
+    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
+        """
+        save the current model, optimizer, and other state to filename
+
+        epochs_trained can be passed as a parameter to handle saving at the end of an epoch
+        """
+        if epochs_trained is None:
+            epochs_trained = self.epochs_trained
         save_dir = os.path.split(filename)[0]
         os.makedirs(save_dir, exist_ok=True)
         model_state = self.model.state_dict()
@@ -43,7 +50,7 @@ class Trainer:
             'config': self.model.config,
             'labels': self.model.labels,
             'extra_vocab': self.model.extra_vocab,
-            'epochs_trained': self.epochs_trained,
+            'epochs_trained': epochs_trained,
             'global_step': self.global_step,
         }
         if save_optimizer and self.optimizer is not None:
```