github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-08-31 22:39:49 +0300
committer John Bauer <horatio@gmail.com>  2022-08-31 22:39:49 +0300
commit    1e42295deedb4702e49547541c870355da215be8 (patch)
tree      757f57949eccea8b495b473d394f7d8b0e647872
parent    f257cd4bdb841aa5e69545040bdea8f8a10286c5 (diff)
Save the best score when training a model so that future training from a checkpoint knows when to save a better model
 stanza/models/classifier.py          | 13 +++++++------
 stanza/models/classifiers/trainer.py |  9 +++++++--
 2 files changed, 14 insertions(+), 8 deletions(-)
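
In outline, the commit promotes best_score from a local variable in classifier.py's training loop to a field on the Trainer, so it is serialized into every checkpoint alongside epochs_trained and global_step. A minimal self-contained sketch of that pattern follows (illustrative names and a stripped-down save(), not Stanza's actual API):

import torch

class Trainer:
    """Bundles a model with its training position and its best dev score."""
    def __init__(self, model, optimizer=None, epochs_trained=0,
                 global_step=0, best_score=None):
        self.model = model
        self.optimizer = optimizer
        self.epochs_trained = epochs_trained
        self.global_step = global_step
        # None rather than 0, so even a dev score of 0.0 counts as a first best
        self.best_score = best_score

    def save(self, filename):
        params = {
            'model_state_dict': self.model.state_dict(),
            'epochs_trained': self.epochs_trained,
            'global_step': self.global_step,
            'best_score': self.best_score,   # travels with the checkpoint
        }
        if self.optimizer is not None:
            params['optimizer_state_dict'] = self.optimizer.state_dict()
        torch.save(params, filename)
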
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 28ab9287..9e229cf2 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -387,11 +387,12 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
     train_set_by_len = data.sort_dataset_by_len(train_set)
 
-    best_score = 0
     if trainer.global_step > 0:
         # We reloaded the model, so let's report its current dev set score
-        best_score, _, _ = score_dev_set(model, dev_set, args.dev_eval_scoring)
+        _ = score_dev_set(model, dev_set, args.dev_eval_scoring)
         logger.info("Reloaded model for continued training.")
+        if trainer.best_score is not None:
+            logger.info("Previous best score: %.5f", trainer.best_score)
 
     log_param_sizes(model)
@@ -442,8 +443,8 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
                 dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
                 if args.wandb:
                     wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step)
-                if best_score is None or dev_score > best_score:
-                    best_score = dev_score
+                if trainer.best_score is None or dev_score > trainer.best_score:
+                    trainer.best_score = dev_score
                     trainer.save(model_file)
                     logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1))
                 model.train()
@@ -461,8 +462,8 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
         if args.save_intermediate_models:
             intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
             trainer.save(intermediate_file)
-        if best_score is None or dev_score > best_score:
-            best_score = dev_score
+        if trainer.best_score is None or dev_score > trainer.best_score:
+            trainer.best_score = dev_score
             trainer.save(model_file, trainer)
             logger.info("Saved new best score model! Accuracy %.5f Macro F1 %.5f Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1))
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index a0ce2848..089ee656 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -21,13 +21,16 @@ class Trainer:
     Stores a constituency model and its optimizer
     """
 
-    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0):
+    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
         self.model = model
         self.optimizer = optimizer
         # we keep track of position in the learning so that we can
         # checkpoint & restart if needed without restarting the epoch count
         self.epochs_trained = epochs_trained
         self.global_step = global_step
+        # save the best dev score so that when reloading a checkpoint
+        # of a model, we know how far we got
+        self.best_score = best_score
 
     def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
         """
@@ -52,6 +55,7 @@ class Trainer:
             'extra_vocab': self.model.extra_vocab,
             'epochs_trained': epochs_trained,
             'global_step': self.global_step,
+            'best_score': self.best_score,
         }
         if save_optimizer and self.optimizer is not None:
             params['optimizer_state_dict'] = self.optimizer.state_dict()
@@ -84,6 +88,7 @@ class Trainer:
         epochs_trained = checkpoint.get('epochs_trained', 0)
         global_step = checkpoint.get('global_step', 0)
+        best_score = checkpoint.get('best_score', None)
 
         # TODO: the getattr is not needed when all models have this baked into the config
         model_type = getattr(checkpoint['config'], 'model_type', 'CNNClassifier')
@@ -128,7 +133,7 @@ class Trainer:
         else:
             logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
 
-        trainer = Trainer(model, optimizer, epochs_trained, global_step)
+        trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)
         return trainer
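
Note that the load path reads the new field with checkpoint.get('best_score', None) rather than indexing: checkpoints written before this commit have no 'best_score' key, and the None default reproduces the "no best model yet" state instead of raising KeyError. A sketch of the corresponding load function, assuming the illustrative Trainer from the earlier sketch:

def load(filename, model, optimizer=None):
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # .get() keeps pre-existing checkpoints loadable: missing keys fall
    # back to fresh-training defaults instead of raising KeyError
    return Trainer(model, optimizer,
                   epochs_trained=checkpoint.get('epochs_trained', 0),
                   global_step=checkpoint.get('global_step', 0),
                   best_score=checkpoint.get('best_score', None))
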