
github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>  2022-08-31 09:15:25 +0300
committer  John Bauer <horatio@gmail.com>  2022-08-31 09:18:28 +0300
commit     1e7ab92c8cf0f2f51f78df74472a6196b51f4637 (patch)
tree       a32b33738c1167601dbde41440e695209be376b5
parent     e5793c9dd5359f7e8f4fe82bf318a2f8fd190f54 (diff)
Save checkpoints with epochs_trained+1 at the end of an epoch (otherwise the epoch will not be incremented properly when reloading)
 stanza/models/classifier.py          |  2 +-
 stanza/models/classifiers/trainer.py | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)
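The reasoning in the commit message is easiest to see from the resume path. The following is a minimal, self-contained sketch of the off-by-one being fixed; the checkpoint is simulated with a plain dict and the helper names are illustrative, not stanza code. A checkpoint written at the end of epoch N has to record epochs_trained = N + 1, otherwise a reloaded run repeats the epoch it had just finished.

    # Illustrative sketch only; save_checkpoint/resume_epochs are not stanza APIs.
    def save_checkpoint(epochs_trained):
        # What gets persisted at the end of an epoch.
        return {'epochs_trained': epochs_trained}

    def resume_epochs(checkpoint, max_epochs):
        # A resumed run starts at whatever epoch count the checkpoint recorded.
        return list(range(checkpoint['epochs_trained'], max_epochs))

    current_epoch = 3          # the epoch that just finished (0-based)
    max_epochs = 5

    # Before the fix: the checkpoint stores the not-yet-incremented counter,
    # so epoch 3 would be trained again after a reload.
    print(resume_epochs(save_checkpoint(current_epoch), max_epochs))      # [3, 4]

    # After the fix: the checkpoint stores epochs_trained + 1, so the
    # reloaded run continues with epoch 4.
    print(resume_epochs(save_checkpoint(current_epoch + 1), max_epochs))  # [4]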
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 826065dc..b580e0fc 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -456,7 +456,7 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
         if args.wandb:
             wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1, 'epoch_loss': epoch_loss}, step=trainer.global_step)
         if checkpoint_file:
-            trainer.save(checkpoint_file)
+            trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)
         if args.save_intermediate_models:
             intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
             trainer.save(intermediate_file)
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index c74e7b0a..a0ce2848 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -29,7 +29,14 @@ class Trainer:
         self.epochs_trained = epochs_trained
         self.global_step = global_step
 
-    def save(self, filename, skip_modules=True, save_optimizer=True):
+    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
+        """
+        save the current model, optimizer, and other state to filename
+
+        epochs_trained can be passed as a parameter to handle saving at the end of an epoch
+        """
+        if epochs_trained is None:
+            epochs_trained = self.epochs_trained
         save_dir = os.path.split(filename)[0]
         os.makedirs(save_dir, exist_ok=True)
         model_state = self.model.state_dict()
@@ -43,7 +50,7 @@ class Trainer:
             'config': self.model.config,
             'labels': self.model.labels,
             'extra_vocab': self.model.extra_vocab,
-            'epochs_trained': self.epochs_trained,
+            'epochs_trained': epochs_trained,
             'global_step': self.global_step,
         }
         if save_optimizer and self.optimizer is not None:
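With the new keyword argument, call sites can distinguish end-of-epoch checkpoints from other saves. The sketch below mirrors the default-argument logic added to Trainer.save(); MockTrainer is a stripped-down stand-in for illustration, not the stanza class, and the real save() also writes model weights, config, and optimizer state.

    # Self-contained sketch of the new default-argument behavior; MockTrainer
    # is hypothetical and only models the epochs_trained bookkeeping.
    class MockTrainer:
        def __init__(self, epochs_trained=0):
            self.epochs_trained = epochs_trained

        def save(self, filename, epochs_trained=None):
            if epochs_trained is None:
                epochs_trained = self.epochs_trained
            return {'filename': filename, 'epochs_trained': epochs_trained}

    trainer = MockTrainer(epochs_trained=3)

    # End-of-epoch checkpoint: record the epoch as already finished.
    print(trainer.save("checkpoint.pt", epochs_trained=trainer.epochs_trained + 1))
    # {'filename': 'checkpoint.pt', 'epochs_trained': 4}

    # Any other save keeps the default and records the current counter.
    print(trainer.save("intermediate.pt"))
    # {'filename': 'intermediate.pt', 'epochs_trained': 3}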