author     John Bauer <horatio@gmail.com>    2022-10-30 07:01:55 +0300
committer  John Bauer <horatio@gmail.com>    2022-10-30 09:20:37 +0300
commit     08424f6ca2da07021b6e55cce2460b86c857601b (patch)
tree       b6826436c8813428456778b7e8bbcfb45d6dc559
parent     d881203e04be81c97b96c376d6f350c4e5181886 (diff)
Track how many batches a model gets trained for. This doubles as a backdoor test for the silver trees, since adding a silver treebank makes an epoch take twice as long
-rw-r--r--  stanza/models/constituency/trainer.py      11
-rw-r--r--  stanza/tests/constituency/test_trainer.py   1
2 files changed, 9 insertions, 3 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 6f1e4818..1bd676e5 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -47,13 +47,14 @@ class Trainer:
     Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
     """
-    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, best_f1=0.0, best_epoch=0):
+    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0):
         self.model = model
         self.optimizer = optimizer
         self.scheduler = scheduler
         # keeping track of the epochs trained will be useful
         # for adjusting the learning scheme
         self.epochs_trained = epochs_trained
+        self.batches_trained = batches_trained
         self.best_f1 = best_f1
         self.best_epoch = best_epoch
@@ -65,6 +66,7 @@ class Trainer:
         checkpoint = {
             'params': params,
             'epochs_trained': self.epochs_trained,
+            'batches_trained': self.batches_trained,
             'best_f1': self.best_f1,
             'best_epoch': self.best_epoch,
         }
@@ -139,6 +141,7 @@ class Trainer:
             model.cuda()
         epochs_trained = checkpoint['epochs_trained']
+        batches_trained = checkpoint.get('batches_trained', 0)
         best_f1 = checkpoint['best_f1']
         best_epoch = checkpoint['best_epoch']
@@ -168,7 +171,7 @@ class Trainer:
         for k in model.args.keys():
             logger.debug(" --%s: %s", k, model.args[k])
-        return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, best_f1=best_f1, best_epoch=best_epoch)
+        return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch)
 def load_pretrain_or_wordvec(args):
@@ -708,6 +711,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         if args['multistage'] and trainer.epochs_trained in multistage_splits:
             # we may be loading a save model from an earlier epoch if the scores stopped increasing
             epochs_trained = trainer.epochs_trained
+            batches_trained = trainer.batches_trained
             stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]
@@ -737,7 +741,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             optimizer = build_optimizer(temp_args, new_model, False)
             scheduler = build_scheduler(temp_args, optimizer)
-            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, trainer.best_f1, trainer.best_epoch)
+            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch)
             add_grad_clipping(trainer, args['grad_clipping'])
             model = new_model
@@ -762,6 +766,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
     for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
         batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
         batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args)
+        trainer.batches_trained += 1
         # Early in the training, some trees will be degenerate in a
         # way that results in layers going up the tree amplifying the
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index a9a37f4f..80d191a8 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -197,6 +197,7 @@ class TestTrainer:
         assert os.path.exists(model_name)
         tr = trainer.Trainer.load(model_name, load_optimizer=True)
         assert tr.epochs_trained == i
+        assert tr.batches_trained == (4 * i if use_silver else 2 * i)
         return args
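
One detail worth noting in the load path: the new counter is read with checkpoint.get('batches_trained', 0), so checkpoints saved before this change still load and simply report zero batches trained. A minimal sketch of that pattern, using hypothetical checkpoint dicts (real Stanza checkpoints also carry model params and optimizer/scheduler state):

    # Checkpoints written before this commit lack the 'batches_trained' key.
    old_checkpoint = {'epochs_trained': 5, 'best_f1': 0.91, 'best_epoch': 4}
    new_checkpoint = {'epochs_trained': 5, 'batches_trained': 20,
                      'best_f1': 0.91, 'best_epoch': 4}

    # dict.get with a default keeps old checkpoints loadable
    # instead of raising a KeyError.
    assert old_checkpoint.get('batches_trained', 0) == 0
    assert new_checkpoint.get('batches_trained', 0) == 20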
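
The updated test assertion encodes the "epoch takes twice as long" observation from the commit message. A rough sketch of the arithmetic, assuming the gold fixture yields 2 batches per epoch and the silver treebank contributes 2 more (these constants are assumptions for illustration, not values taken from the test fixtures):

    GOLD_BATCHES_PER_EPOCH = 2    # assumed batch count of the gold test treebank
    SILVER_BATCHES_PER_EPOCH = 2  # assumed extra batches added by silver trees

    def expected_batches(epochs, use_silver):
        # each epoch covers the gold batches, plus the silver ones if enabled
        per_epoch = GOLD_BATCHES_PER_EPOCH
        if use_silver:
            per_epoch += SILVER_BATCHES_PER_EPOCH
        return per_epoch * epochs

    for i in range(1, 4):
        assert expected_batches(i, use_silver=True) == 4 * i
        assert expected_batches(i, use_silver=False) == 2 * i

Since epochs_trained alone is identical in both configurations, asserting on batches_trained is what makes the silver-tree path observable to the test.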