diff options
author | John Bauer <horatio@gmail.com> | 2022-08-31 22:39:49 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-08-31 22:39:49 +0300 |
commit | 1e42295deedb4702e49547541c870355da215be8 (patch) | |
tree | 757f57949eccea8b495b473d394f7d8b0e647872 | |
parent | f257cd4bdb841aa5e69545040bdea8f8a10286c5 (diff) |
Save the best score when training a model so that future training from a checkpoint knows when to save a better model
-rw-r--r-- | stanza/models/classifier.py | 13 | ||||
-rw-r--r-- | stanza/models/classifiers/trainer.py | 9 |
2 files changed, 14 insertions, 8 deletions
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py index 28ab9287..9e229cf2 100644 --- a/stanza/models/classifier.py +++ b/stanza/models/classifier.py @@ -387,11 +387,12 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, train_set_by_len = data.sort_dataset_by_len(train_set) - best_score = 0 if trainer.global_step > 0: # We reloaded the model, so let's report its current dev set score - best_score, _, _ = score_dev_set(model, dev_set, args.dev_eval_scoring) + _ = score_dev_set(model, dev_set, args.dev_eval_scoring) logger.info("Reloaded model for continued training.") + if trainer.best_score is not None: + logger.info("Previous best score: %.5f", trainer.best_score) log_param_sizes(model) @@ -442,8 +443,8 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring) if args.wandb: wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step) - if best_score is None or dev_score > best_score: - best_score = dev_score + if trainer.best_score is None or dev_score > trainer.best_score: + trainer.best_score = dev_score trainer.save(model_file) logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1)) model.train() @@ -461,8 +462,8 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, if args.save_intermediate_models: intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score) trainer.save(intermediate_file) - if best_score is None or dev_score > best_score: - best_score = dev_score + if trainer.best_score is None or dev_score > trainer.best_score: + trainer.best_score = dev_score trainer.save(model_file, trainer) logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1)) diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py index a0ce2848..089ee656 100644 --- a/stanza/models/classifiers/trainer.py +++ b/stanza/models/classifiers/trainer.py @@ -21,13 +21,16 @@ class Trainer: Stores a constituency model and its optimizer """ - def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0): + def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None): self.model = model self.optimizer = optimizer # we keep track of position in the learning so that we can # checkpoint & restart if needed without restarting the epoch count self.epochs_trained = epochs_trained self.global_step = global_step + # save the best dev score so that when reloading a checkpoint + # of a model, we know how far we got + self.best_score = best_score def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True): """ @@ -52,6 +55,7 @@ class Trainer: 'extra_vocab': self.model.extra_vocab, 'epochs_trained': epochs_trained, 'global_step': self.global_step, + 'best_score': self.best_score, } if save_optimizer and self.optimizer is not None: params['optimizer_state_dict'] = self.optimizer.state_dict() @@ -84,6 +88,7 @@ class Trainer: epochs_trained = checkpoint.get('epochs_trained', 0) global_step = checkpoint.get('global_step', 0) + best_score = checkpoint.get('best_score', None) # TODO: the getattr is not needed when all models have this baked into the config model_type = getattr(checkpoint['config'], 'model_type', 'CNNClassifier') @@ -128,7 +133,7 @@ class Trainer: else: logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer") - trainer = Trainer(model, optimizer, epochs_trained, global_step) + trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score) return trainer |