github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-10-30 07:01:55 +0300
committer John Bauer <horatio@gmail.com>  2022-10-30 09:20:37 +0300
commit    08424f6ca2da07021b6e55cce2460b86c857601b (patch)
tree      b6826436c8813428456778b7e8bbcfb45d6dc559
parent    d881203e04be81c97b96c376d6f350c4e5181886 (diff)
Track how many batches a model gets trained for. This also serves as an indirect ("backdoor") test for the silver trees, since adding a silver treebank makes an epoch take twice as long (twice as many batches)
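
For context, here is a minimal sketch of the bookkeeping pattern this commit introduces: a per-batch counter that is stored in the checkpoint and restored with a backward-compatible default, so checkpoints written before this change still load. TrainerSketch below is a simplified, hypothetical stand-in for stanza's Trainer, not the real class:

    import torch

    class TrainerSketch:
        """Simplified stand-in for the real Trainer; illustration only."""

        def __init__(self, epochs_trained=0, batches_trained=0):
            self.epochs_trained = epochs_trained
            # new counter: incremented once per training batch
            self.batches_trained = batches_trained

        def save(self, filename):
            checkpoint = {
                'epochs_trained': self.epochs_trained,
                'batches_trained': self.batches_trained,
            }
            torch.save(checkpoint, filename)

        @staticmethod
        def load(filename):
            checkpoint = torch.load(filename)
            # dict.get() with a default keeps checkpoints saved before this
            # commit loadable: they simply report 0 batches trained
            return TrainerSketch(epochs_trained=checkpoint['epochs_trained'],
                                 batches_trained=checkpoint.get('batches_trained', 0))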
-rw-r--r--  stanza/models/constituency/trainer.py     | 11 ++++++++---
-rw-r--r--  stanza/tests/constituency/test_trainer.py |  1 +
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 6f1e4818..1bd676e5 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -47,13 +47,14 @@ class Trainer:
 
     Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
     """
-    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, best_f1=0.0, best_epoch=0):
+    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0):
         self.model = model
         self.optimizer = optimizer
         self.scheduler = scheduler
         # keeping track of the epochs trained will be useful
         # for adjusting the learning scheme
         self.epochs_trained = epochs_trained
+        self.batches_trained = batches_trained
         self.best_f1 = best_f1
         self.best_epoch = best_epoch
 
@@ -65,6 +66,7 @@ class Trainer:
         checkpoint = {
             'params': params,
             'epochs_trained': self.epochs_trained,
+            'batches_trained': self.batches_trained,
             'best_f1': self.best_f1,
             'best_epoch': self.best_epoch,
         }
@@ -139,6 +141,7 @@ class Trainer:
             model.cuda()
 
         epochs_trained = checkpoint['epochs_trained']
+        batches_trained = checkpoint.get('batches_trained', 0)
         best_f1 = checkpoint['best_f1']
         best_epoch = checkpoint['best_epoch']
 
@@ -168,7 +171,7 @@ class Trainer:
         for k in model.args.keys():
             logger.debug("  --%s: %s", k, model.args[k])
 
-        return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, best_f1=best_f1, best_epoch=best_epoch)
+        return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch)
 
 
 def load_pretrain_or_wordvec(args):
@@ -708,6 +711,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         if args['multistage'] and trainer.epochs_trained in multistage_splits:
             # we may be loading a save model from an earlier epoch if the scores stopped increasing
             epochs_trained = trainer.epochs_trained
+            batches_trained = trainer.batches_trained
 
             stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]
@@ -737,7 +741,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             optimizer = build_optimizer(temp_args, new_model, False)
             scheduler = build_scheduler(temp_args, optimizer)
 
-            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, trainer.best_f1, trainer.best_epoch)
+            trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch)
             add_grad_clipping(trainer, args['grad_clipping'])
             model = new_model
 
@@ -762,6 +766,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
     for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
         batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
         batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args)
+        trainer.batches_trained += 1
 
         # Early in the training, some trees will be degenerate in a
         # way that results in layers going up the tree amplifying the
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index a9a37f4f..80d191a8 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -197,6 +197,7 @@ class TestTrainer:
             assert os.path.exists(model_name)
             tr = trainer.Trainer.load(model_name, load_optimizer=True)
             assert tr.epochs_trained == i
+            assert tr.batches_trained == (4 * i if use_silver else 2 * i)
 
         return args
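
For reference, the new assertion encodes the claim in the commit message: a silver treebank doubles the number of batches per epoch, so the counter grows twice as fast. A quick sanity check of that arithmetic, assuming (hypothetically; the test's actual data sizes are not visible in this diff) a base training set that yields 2 batches per epoch:

    # Hypothetical batch counts consistent with the assertion above: a base
    # set of 2 batches per epoch doubles to 4 once a same-sized silver
    # treebank is added, so after i epochs the expected totals match.
    for use_silver in (False, True):
        batches_per_epoch = 4 if use_silver else 2
        for i in range(1, 4):
            assert batches_per_epoch * i == (4 * i if use_silver else 2 * i)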