github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-09-09 04:23:30 +0300
committer  John Bauer <horatio@gmail.com>   2022-09-10 05:06:46 +0300
commit     0e6de808eacf14cd64622415eeaeeac2d60faab2 (patch)
tree       b2b1d99fccfea68f254184ea88826d4079cc70c1
parent     ec3e731a35be41886c24bf508727ab494588ebb7 (diff)
Always save checkpoints. Always load from a checkpoint if one exists. (branch: con_checkpoint)

- Build the constituency optimizer using knowledge of how far you are in the training process - multistage part 1 gets Adadelta, for example
- Test that a multistage training process builds the correct optimizers, including when reloading
- When continuing training from a checkpoint, use the existing epochs_trained
- Restart epochs count when doing a finetune
-rw-r--r--  stanza/models/constituency/trainer.py      |  84
-rw-r--r--  stanza/models/constituency/utils.py        |  51
-rw-r--r--  stanza/models/constituency_parser.py       |  26
-rw-r--r--  stanza/tests/constituency/test_trainer.py  |  51
4 files changed, 136 insertions(+), 76 deletions(-)
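For illustration, the optimizer-matching rule this commit adds to Trainer.load can be sketched on its own (a simplified sketch, not the exact stanza code; the saved args and epoch count come from the checkpoint):

def wants_warmup_adadelta(saved_args, epochs_trained):
    # During the first half of a multistage schedule the model is warmed
    # up with a plain Adadelta optimizer, so a checkpoint written before
    # the halfway point must be reloaded with Adadelta rather than the
    # optimizer named by --optim.
    return saved_args['multistage'] and epochs_trained < saved_args['epochs'] // 2

# an 8-epoch multistage run reloaded after 3 completed epochs
saved_args = {'multistage': True, 'epochs': 8}
assert wants_warmup_adadelta(saved_args, epochs_trained=3)
assert not wants_warmup_adadelta(saved_args, epochs_trained=4)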
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index bd776c1e..1bea94a9 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -132,8 +132,12 @@ class Trainer:
if saved_args['cuda']:
model.cuda()
+ epochs_trained = checkpoint.get('epochs_trained', 0)
+
if load_optimizer:
- optimizer = build_optimizer(saved_args, model)
+ # need to match the optimizer we build with the one that was used at training time
+ build_simple_adadelta = checkpoint['args']['multistage'] and epochs_trained < checkpoint['args']['epochs'] // 2
+ optimizer = build_optimizer(saved_args, model, build_simple_adadelta)
if checkpoint.get('optimizer_state_dict', None) is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
@@ -151,8 +155,6 @@ class Trainer:
for k in model.args.keys():
logger.debug(" --%s: %s", k, model.args[k])
- epochs_trained = checkpoint.get('epochs_trained', -1)
-
return Trainer(args=saved_args, model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained)
@@ -348,9 +350,25 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
bert_model, bert_tokenizer = foundation_cache.load_bert(args['bert_model'])
- if args['finetune'] or (args['maybe_finetune'] and os.path.exists(model_load_file)):
- logger.info("Loading model to continue training from %s", model_load_file)
+ trainer = None
+ if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):
+ logger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name'])
+ trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)
+ # grad clipping is not saved with the rest of the model
+ add_grad_clipping(trainer, args['grad_clipping'])
+
+ # TODO: turn finetune, relearn_structure, multistage into an enum
+ # finetune just means continue learning, so checkpoint is sufficient
+ # relearn_structure is essentially a one stage multistage
+ # multistage with a checkpoint will have the proper optimizer for that epoch
+ # and no special learning mode means we are training a new model and should continue
+ return trainer, train_sequences, train_transitions
+
+ if args['finetune']:
+ logger.info("Loading model to finetune: %s", model_load_file)
trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=foundation_cache)
+ # a new finetuning will start with a new epochs_trained count
+ trainer.epochs_trained = 0
elif args['relearn_structure']:
logger.info("Loading model to continue training with new structure from %s", model_load_file)
temp_args = dict(args)
@@ -366,7 +384,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
if args['cuda']:
model.cuda()
model.copy_with_new_structure(trainer.model)
- optimizer = build_optimizer(args, model)
+ optimizer = build_optimizer(args, model, False)
scheduler = build_scheduler(args, optimizer)
trainer = Trainer(args, model, optimizer, scheduler)
elif args['multistage']:
@@ -375,13 +393,6 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
# this works surprisingly well
logger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
temp_args = dict(args)
- temp_args['optim'] = 'adadelta'
- temp_args['learning_rate'] = DEFAULT_LEARNING_RATES['adadelta']
- temp_args['learning_eps'] = DEFAULT_LEARNING_EPS['adadelta']
- temp_args['learning_rho'] = DEFAULT_LEARNING_RHO
- temp_args['weight_decay'] = DEFAULT_WEIGHT_DECAY['adadelta']
- temp_args['epochs'] = args['epochs'] // 2
- temp_args['learning_rate_warmup'] = 0
# remove the attention layers for the temporary model
temp_args['pattn_num_layers'] = 0
temp_args['lattn_d_proj'] = 0
@@ -389,8 +400,8 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
if args['cuda']:
temp_model.cuda()
- temp_optim = build_optimizer(temp_args, temp_model)
- scheduler = build_scheduler(args, temp_optim)
+ temp_optim = build_optimizer(temp_args, temp_model, True)
+ scheduler = build_scheduler(temp_args, temp_optim)
trainer = Trainer(temp_args, temp_model, temp_optim, scheduler)
else:
model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
@@ -398,7 +409,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
model.cuda()
logger.info("Number of words in the training set found in the embedding: {} out of {}".format(model.num_words_known(words), len(words)))
- optimizer = build_optimizer(args, model)
+ optimizer = build_optimizer(args, model, False)
scheduler = build_scheduler(args, optimizer)
trainer = Trainer(args, model, optimizer, scheduler)
@@ -435,7 +446,7 @@ def remove_no_tags(trees):
logger.info("Eliminated %d trees with missing structure", (len(trees) - len(new_trees)))
return new_trees
-def train(args, model_save_file, model_load_file, model_save_latest_file, model_save_each_file, retag_pipeline):
+def train(args, model_load_file, model_save_each_file, retag_pipeline):
"""
Build a model, train it using the requested train & dev files
"""
@@ -479,7 +490,7 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_
foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
- trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)
+ trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_each_file, evaluator)
if args['wandb']:
wandb.finish()
@@ -498,7 +509,7 @@ class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct',
return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used)
-def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_filename, model_latest_filename, model_save_each_filename, evaluator):
+def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
"""
Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
@@ -543,9 +554,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
leftover_training_data = []
best_f1 = 0.0
best_epoch = 0
- for epoch in range(1, args['epochs']+1):
+ # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
+ for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1):
model.train()
- logger.info("Starting epoch %d", epoch)
+ logger.info("Starting epoch %d", trainer.epochs_trained)
if args['log_norms']:
model.log_norms()
epoch_data = leftover_training_data
@@ -556,7 +568,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
epoch_data = epoch_data[:args['epoch_size']]
epoch_data.sort(key=lambda x: len(x[1]))
- epoch_stats = train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args)
+ epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
# print statistics
f1 = run_dev_set(model, dev_trees, args, evaluator)
@@ -566,16 +578,16 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
# very simple model didn't learn anything
logger.info("New best dev score: %.5f > %.5f", f1, best_f1)
best_f1 = f1
- best_epoch = epoch
- trainer.save(model_filename, save_optimizer=True)
- if model_latest_filename:
- trainer.save(model_latest_filename, save_optimizer=True)
+ best_epoch = trainer.epochs_trained
+ trainer.save(args['save_name'], save_optimizer=False)
+ if args['checkpoint'] and args['checkpoint_save_name']:
+ trainer.save(args['checkpoint_save_name'], save_optimizer=True)
if model_save_each_filename:
- trainer.save(model_save_each_filename % epoch, save_optimizer=True)
- logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", epoch, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, epoch, f1, best_epoch, best_f1)
+ trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True)
+ logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, best_epoch, best_f1)
if args['wandb']:
- wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=epoch)
+ wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
if args['wandb_norm_regex']:
watch_regex = re.compile(args['wandb_norm_regex'])
for n, p in model.named_parameters():
@@ -583,22 +595,20 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
wandb.log({n: torch.linalg.norm(p)})
# don't redo the optimizer a second time if we're not changing the structure
- if args['multistage'] and epoch in multistage_splits:
- # TODO: start counting epoch from trainer.epochs_trained for a previously trained model?
-
+ if args['multistage'] and trainer.epochs_trained in multistage_splits:
# we may be loading a save model from an earlier epoch if the scores stopped increasing
epochs_trained = trainer.epochs_trained
- stage_pattn_layers, stage_uses_lattn = multistage_splits[epoch]
+ stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]
# when loading the model, let the saved model determine whether it has pattn or lattn
temp_args = dict(trainer.args)
temp_args.pop('pattn_num_layers', None)
temp_args.pop('lattn_d_proj', None)
# overwriting the old trainer & model will hopefully free memory
- trainer = Trainer.load(model_filename, temp_args, load_optimizer=False, foundation_cache=foundation_cache)
+ trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False, foundation_cache=foundation_cache)
model = trainer.model
- logger.info("Finished stage at epoch %d. Restarting optimizer", epoch)
+ logger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained)
logger.info("Previous best model was at epoch %d", trainer.epochs_trained)
temp_args = dict(args)
@@ -615,7 +625,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
new_model.cuda()
new_model.copy_with_new_structure(model)
- optimizer = build_optimizer(temp_args, new_model)
+ optimizer = build_optimizer(temp_args, new_model, False)
scheduler = build_scheduler(temp_args, optimizer)
trainer = Trainer(temp_args, new_model, optimizer, scheduler, epochs_trained)
add_grad_clipping(trainer, args['grad_clipping'])
@@ -648,8 +658,6 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
if old_lr != new_lr:
logger.info("Updating learning rate from %f to %f", old_lr, new_lr)
- trainer.epochs_trained += 1
-
# TODO: refactor the logging?
total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())
total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())
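The epoch loop above now uses trainer.epochs_trained itself as the loop variable, so any checkpoint saved inside the loop records how many epochs have actually been completed. A minimal standalone illustration of that pattern (ToyTrainer is a stand-in, not stanza's Trainer):

class ToyTrainer:
    def __init__(self, epochs_trained=0):
        self.epochs_trained = epochs_trained

trainer = ToyTrainer(epochs_trained=3)   # as if reloaded from a checkpoint
total_epochs = 8
# the loop target is an attribute, so the counter is already correct
# whenever the trainer is saved at the end of an epoch
for trainer.epochs_trained in range(trainer.epochs_trained + 1, total_epochs + 1):
    pass  # train one epoch, evaluate, possibly save a checkpoint here
assert trainer.epochs_trained == total_epochs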
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index a6a05119..ca642551 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -4,6 +4,7 @@ Collects a few of the conparser utility methods which don't belong elsewhere
from collections import deque
import copy
+import logging
import torch.nn as nn
from torch import optim
@@ -15,6 +16,8 @@ DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "sgd": 0.9 }
+logger = logging.getLogger('stanza')
+
# madgrad experiment for weight decay
# with learning_rate set to 0.0000007 and momentum 0.9
# on en_wsj, with a baseline model trained on adadela for 200,
@@ -116,32 +119,56 @@ def build_nonlinearity(nonlinearity):
return NONLINEARITY[nonlinearity]()
raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
-def build_optimizer(args, model):
+def build_optimizer(args, model, build_simple_adadelta=False):
"""
Build an optimizer based on the arguments given
+
+ If we are "multistage" training and epochs_trained < epochs // 2,
+ we build an AdaDelta optimizer instead of whatever was requested.
+ The build_simple_adadelta parameter controls this.
"""
+ if build_simple_adadelta:
+ optim_type = 'adadelta'
+ learning_eps = DEFAULT_LEARNING_EPS['adadelta']
+ learning_rate = DEFAULT_LEARNING_RATES['adadelta']
+ learning_rho = DEFAULT_LEARNING_RHO
+ weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
+ else:
+ optim_type = args['optim'].lower()
+ learning_beta2 = args['learning_beta2']
+ learning_eps = args['learning_eps']
+ learning_rate = args['learning_rate']
+ learning_rho = args['learning_rho']
+ momentum = args['momentum']
+ weight_decay = args['weight_decay']
+
parameters = [param for name, param in model.named_parameters() if not model.is_unsaved_module(name)]
- if args['optim'].lower() == 'sgd':
- optimizer = optim.SGD(parameters, lr=args['learning_rate'], momentum=args['momentum'], weight_decay=args['weight_decay'])
- elif args['optim'].lower() == 'adadelta':
- optimizer = optim.Adadelta(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], rho=args['learning_rho'])
- elif args['optim'].lower() == 'adamw':
- optimizer = optim.AdamW(parameters, lr=args['learning_rate'], betas=(0.9, args['learning_beta2']), eps=args['learning_eps'], weight_decay=args['weight_decay'])
- elif args['optim'].lower() == 'adabelief':
+ if optim_type == 'sgd':
+ logger.info("Building SGD with lr=%f, momentum=%f, weight_decay=%f", learning_rate, momentum, weight_decay)
+ optimizer = optim.SGD(parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
+ elif optim_type == 'adadelta':
+ logger.info("Building Adadelta with lr=%f, eps=%f, weight_decay=%f, rho=%f", learning_rate, learning_eps, weight_decay, learning_rho)
+ optimizer = optim.Adadelta(parameters, lr=learning_rate, eps=learning_eps, weight_decay=weight_decay, rho=learning_rho)
+ elif optim_type == 'adamw':
+ logger.info("Building AdamW with lr=%f, beta2=%f, eps=%f, weight_decay=%f", learning_rate, learning_beta2, learning_eps, weight_decay)
+ optimizer = optim.AdamW(parameters, lr=learning_rate, betas=(0.9, learning_beta2), eps=learning_eps, weight_decay=weight_decay)
+ elif optim_type == 'adabelief':
try:
from adabelief_pytorch import AdaBelief
except ModuleNotFoundError as e:
raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
+ logger.info("Building AdaBelief with lr=%f, eps=%f, weight_decay=%f", learning_rate, learning_eps, weight_decay)
# TODO: make these args
- optimizer = AdaBelief(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], weight_decouple=False, rectify=False)
- elif args['optim'].lower() == 'madgrad':
+ optimizer = AdaBelief(parameters, lr=learning_rate, eps=learning_eps, weight_decay=weight_decay, weight_decouple=False, rectify=False)
+ elif optim_type == 'madgrad':
try:
import madgrad
except ModuleNotFoundError as e:
raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
- optimizer = madgrad.MADGRAD(parameters, lr=args['learning_rate'], weight_decay=args['weight_decay'], momentum=args['momentum'])
+ logger.info("Building AdaBelief with lr=%f, weight_decay=%f, momentum=%f", learning_rate, weight_decay, momentum)
+ optimizer = madgrad.MADGRAD(parameters, lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
else:
- raise ValueError("Unknown optimizer: %s" % args['optim'])
+ raise ValueError("Unknown optimizer: %s" % optim)
return optimizer
def build_scheduler(args, optimizer):
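When build_simple_adadelta is True, the optimizer constructed above corresponds roughly to the following call, using the defaults visible in this hunk plus assumed values for the two constants not shown here (DEFAULT_LEARNING_RATES['adadelta'] and DEFAULT_WEIGHT_DECAY['adadelta']):

from torch import nn, optim

model = nn.Linear(4, 2)  # stand-in for the constituency LSTMModel
warmup_optimizer = optim.Adadelta(
    model.parameters(),
    lr=1.0,            # assumed DEFAULT_LEARNING_RATES['adadelta']
    eps=1e-6,          # DEFAULT_LEARNING_EPS['adadelta'] (shown above)
    rho=0.9,           # DEFAULT_LEARNING_RHO (shown above)
    weight_decay=0.0,  # assumed DEFAULT_WEIGHT_DECAY['adadelta']
)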
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 91d1d23c..1c588e01 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -242,7 +242,6 @@ def parse_args(args=None):
parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
- parser.add_argument('--save_latest_name', type=str, default=None, help="Save the latest model here regardless of score. Useful for restarting training")
parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
parser.add_argument('--seed', type=int, default=1234)
@@ -376,7 +375,8 @@ def parse_args(args=None):
parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
- parser.add_argument('--maybe_finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path if it exists. Useful for running in situations where a job is frequently being preempted')
+ parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
+ parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
@@ -450,8 +450,13 @@ def parse_args(args=None):
else:
raise ValueError("Unknown retag method {}".format(xpos))
- if args['multistage'] and (args['finetune'] or args['maybe_finetune'] or args['relearn_structure']):
- raise ValueError('Learning multistage from a previously started model is not yet implemented. TODO')
+ model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
+
+ if args['checkpoint']:
+ args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])
+
+ model_save_file = os.path.join(args['save_dir'], model_save_file)
+ args['save_name'] = model_save_file
return args
@@ -468,13 +473,6 @@ def main(args=None):
logger.info("Running constituency parser in %s mode", args['mode'])
logger.debug("Using GPU: %s", args['cuda'])
- model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
- model_save_file = os.path.join(args['save_dir'], model_save_file)
-
- model_save_latest_file = None
- if args['save_latest_name']:
- model_save_latest_file = os.path.join(args['save_dir'], args['save_latest_name'])
-
model_save_each_file = None
if args['save_each_name']:
model_save_each_file = os.path.join(args['save_dir'], args['save_each_name'])
@@ -485,14 +483,12 @@ def main(args=None):
pieces = os.path.splitext(model_save_each_file)
model_save_each_file = pieces[0] + "_%4d" + pieces[1]
- model_load_file = model_save_file
+ model_load_file = args['save_name']
if args['load_name']:
if os.path.exists(args['load_name']):
model_load_file = args['load_name']
else:
model_load_file = os.path.join(args['save_dir'], args['load_name'])
- elif args['mode'] == 'train' and args['save_latest_name']:
- model_load_file = model_save_latest_file
if args['retag_package'] is not None and args['mode'] != 'remove_optimizer':
if '_' in args['retag_package']:
@@ -509,7 +505,7 @@ def main(args=None):
retag_pipeline = None
if args['mode'] == 'train':
- trainer.train(args, model_save_file, model_load_file, model_save_latest_file, model_save_each_file, retag_pipeline)
+ trainer.train(args, model_load_file, model_save_each_file, retag_pipeline)
elif args['mode'] == 'predict':
trainer.evaluate(args, model_load_file, retag_pipeline)
elif args['mode'] == 'remove_optimizer':
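If --checkpoint_save_name is not given, parse_args derives a checkpoint path from the model save name via utils.checkpoint_name. A hypothetical equivalent, purely for illustration (the real helper lives in stanza's common utils and its behavior may differ):

import os

def checkpoint_name(save_dir, save_name, explicit_name=None):
    # reuse an explicit name if one was supplied, otherwise derive
    # <save_name minus extension>_checkpoint.pt inside save_dir
    if explicit_name:
        return explicit_name
    root, ext = os.path.splitext(save_name)
    return os.path.join(save_dir, root + "_checkpoint" + ext)

print(checkpoint_name("saved_models/constituency", "en_wsj_constituency.pt"))
# saved_models/constituency/en_wsj_constituency_checkpoint.pt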
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 1e05f249..910b11fb 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -4,6 +4,7 @@ import tempfile
import pytest
import torch
+from torch import optim
from stanza import Pipeline
@@ -161,12 +162,10 @@ class TestTrainer:
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
- save_name = os.path.join(args['save_dir'], args['save_name'])
- latest_name = os.path.join(args['save_dir'], 'latest.pt')
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
- assert not os.path.exists(save_name)
+ assert not os.path.exists(args['save_name'])
retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
- tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+ tr = trainer.train(args, None, each_name, retag_pipeline)
# check that hooks are in the model if expected
for p in tr.model.parameters():
if p.requires_grad:
@@ -176,8 +175,8 @@ class TestTrainer:
assert p._backward_hooks is None
# check that the model can be loaded back
- assert os.path.exists(save_name)
- tr = trainer.Trainer.load(save_name, load_optimizer=True)
+ assert os.path.exists(args['save_name'])
+ tr = trainer.Trainer.load(args['save_name'], load_optimizer=True)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained >= 1
@@ -185,7 +184,7 @@ class TestTrainer:
if p.requires_grad:
assert p._backward_hooks is None
- tr = trainer.Trainer.load(latest_name, load_optimizer=True)
+ tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained == num_epochs
@@ -205,12 +204,13 @@ class TestTrainer:
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_train_test(wordvec_pretrain_file, tmpdirname)
- def run_multistage_tests(self, wordvec_pretrain_file, use_lattn):
- with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
args = ['--multistage', '--pattn_num_layers', '1']
if use_lattn:
args += ['--lattn_d_proj', '16']
+ if extra_args:
+ args += extra_args
args = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -241,7 +241,8 @@ class TestTrainer:
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
- self.run_multistage_tests(wordvec_pretrain_file, use_lattn=True)
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
def test_multistage_no_lattn(self, wordvec_pretrain_file):
"""
@@ -249,7 +250,35 @@ class TestTrainer:
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
- self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
+
+ def test_multistage_optimizer(self, wordvec_pretrain_file):
+ """
+ Test that the correct optimizers are built for a multistage training process
+ """
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ extra_args = ['--optim', 'adamw']
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
+
+ # check that the optimizers which get rebuilt when loading
+ # the models are adadelta for the first half of the
+ # multistage, then adamw
+ each_name = os.path.join(tmpdirname, 'each_%02d.pt')
+ for i in range(1, 3):
+ model_name = each_name % i
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
+ assert tr.epochs_trained == i
+ assert isinstance(tr.optimizer, optim.Adadelta)
+ # double check that this is actually a valid test
+ assert not isinstance(tr.optimizer, optim.AdamW)
+
+ for i in range(4, 8):
+ model_name = each_name % i
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
+ assert tr.epochs_trained == i
+ assert isinstance(tr.optimizer, optim.AdamW)
+
def test_hooks(self, wordvec_pretrain_file):
"""