From 0e6de808eacf14cd64622415eeaeeac2d60faab2 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Thu, 8 Sep 2022 18:23:30 -0700
Subject: Always save checkpoints. Always load from a checkpoint if one exists.

Build the constituency optimizer using knowledge of how far you are in
the training process - multistage part 1 gets Adadelta, for example

Test that a multistage training process builds the correct optimizers,
including when reloading

When continuing training from a checkpoint, use the existing epochs_trained

Restart epochs count when doing a finetune
---
 stanza/models/constituency/trainer.py     | 84 +++++++++++++++++--------------
 stanza/models/constituency/utils.py       | 51 ++++++++++++++-----
 stanza/models/constituency_parser.py      | 26 ++++------
 stanza/tests/constituency/test_trainer.py | 51 +++++++++++++++----
 4 files changed, 136 insertions(+), 76 deletions(-)

diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index bd776c1e..1bea94a9 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -132,8 +132,12 @@ class Trainer:
         if saved_args['cuda']:
             model.cuda()
 
+        epochs_trained = checkpoint.get('epochs_trained', 0)
+
         if load_optimizer:
-            optimizer = build_optimizer(saved_args, model)
+            # need to match the optimizer we build with the one that was used at training time
+            build_simple_adadelta = checkpoint['args']['multistage'] and epochs_trained < checkpoint['args']['epochs'] // 2
+            optimizer = build_optimizer(saved_args, model, build_simple_adadelta)
 
             if checkpoint.get('optimizer_state_dict', None) is not None:
                 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
@@ -151,8 +155,6 @@ class Trainer:
         for k in model.args.keys():
             logger.debug(" --%s: %s", k, model.args[k])
 
-        epochs_trained = checkpoint.get('epochs_trained', -1)
-
         return Trainer(args=saved_args, model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained)
 
@@ -348,9 +350,25 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
     backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
     bert_model, bert_tokenizer = foundation_cache.load_bert(args['bert_model'])
 
-    if args['finetune'] or (args['maybe_finetune'] and os.path.exists(model_load_file)):
-        logger.info("Loading model to continue training from %s", model_load_file)
+    trainer = None
+    if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):
+        logger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name'])
+        trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)
+        # grad clipping is not saved with the rest of the model
+        add_grad_clipping(trainer, args['grad_clipping'])
+
+        # TODO: turn finetune, relearn_structure, multistage into an enum
+        # finetune just means continue learning, so checkpoint is sufficient
+        # relearn_structure is essentially a one stage multistage
+        # multistage with a checkpoint will have the proper optimizer for that epoch
+        # and no special learning mode means we are training a new model and should continue
+        return trainer, train_sequences, train_transitions
+
+    if args['finetune']:
+        logger.info("Loading model to finetune: %s", model_load_file)
         trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=foundation_cache)
+        # a new finetuning will start with a new epochs_trained count
+        trainer.epochs_trained = 0
     elif args['relearn_structure']:
logger.info("Loading model to continue training with new structure from %s", model_load_file) temp_args = dict(args) @@ -366,7 +384,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil if args['cuda']: model.cuda() model.copy_with_new_structure(trainer.model) - optimizer = build_optimizer(args, model) + optimizer = build_optimizer(args, model, False) scheduler = build_scheduler(args, optimizer) trainer = Trainer(args, model, optimizer, scheduler) elif args['multistage']: @@ -375,13 +393,6 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil # this works surprisingly well logger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2) temp_args = dict(args) - temp_args['optim'] = 'adadelta' - temp_args['learning_rate'] = DEFAULT_LEARNING_RATES['adadelta'] - temp_args['learning_eps'] = DEFAULT_LEARNING_EPS['adadelta'] - temp_args['learning_rho'] = DEFAULT_LEARNING_RHO - temp_args['weight_decay'] = DEFAULT_WEIGHT_DECAY['adadelta'] - temp_args['epochs'] = args['epochs'] // 2 - temp_args['learning_rate_warmup'] = 0 # remove the attention layers for the temporary model temp_args['pattn_num_layers'] = 0 temp_args['lattn_d_proj'] = 0 @@ -389,8 +400,8 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args) if args['cuda']: temp_model.cuda() - temp_optim = build_optimizer(temp_args, temp_model) - scheduler = build_scheduler(args, temp_optim) + temp_optim = build_optimizer(temp_args, temp_model, True) + scheduler = build_scheduler(temp_args, temp_optim) trainer = Trainer(temp_args, temp_model, temp_optim, scheduler) else: model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args) @@ -398,7 +409,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil model.cuda() logger.info("Number of words in the training set found in the embedding: {} out of {}".format(model.num_words_known(words), len(words))) - optimizer = build_optimizer(args, model) + optimizer = build_optimizer(args, model, False) scheduler = build_scheduler(args, optimizer) trainer = Trainer(args, model, optimizer, scheduler) @@ -435,7 +446,7 @@ def remove_no_tags(trees): logger.info("Eliminated %d trees with missing structure", (len(trees) - len(new_trees))) return new_trees -def train(args, model_save_file, model_load_file, model_save_latest_file, model_save_each_file, retag_pipeline): +def train(args, model_load_file, model_save_each_file, retag_pipeline): """ Build a model, train it using the requested train & dev files """ @@ -479,7 +490,7 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_ foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file) - trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator) + trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, 
 
     if args['wandb']:
         wandb.finish()
@@ -498,7 +509,7 @@ class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct',
         return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used)
 
-def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_filename, model_latest_filename, model_save_each_filename, evaluator):
+def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
     """
     Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
@@ -543,9 +554,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
     leftover_training_data = []
     best_f1 = 0.0
     best_epoch = 0
-    for epoch in range(1, args['epochs']+1):
+    # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
+    for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1):
         model.train()
-        logger.info("Starting epoch %d", epoch)
+        logger.info("Starting epoch %d", trainer.epochs_trained)
         if args['log_norms']:
             model.log_norms()
         epoch_data = leftover_training_data
@@ -556,7 +568,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         epoch_data = epoch_data[:args['epoch_size']]
         epoch_data.sort(key=lambda x: len(x[1]))
 
-        epoch_stats = train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args)
+        epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
 
         # print statistics
         f1 = run_dev_set(model, dev_trees, args, evaluator)
@@ -566,16 +578,16 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             # very simple model didn't learn anything
             logger.info("New best dev score: %.5f > %.5f", f1, best_f1)
             best_f1 = f1
-            best_epoch = epoch
-            trainer.save(model_filename, save_optimizer=True)
-        if model_latest_filename:
-            trainer.save(model_latest_filename, save_optimizer=True)
+            best_epoch = trainer.epochs_trained
+            trainer.save(args['save_name'], save_optimizer=False)
+        if args['checkpoint'] and args['checkpoint_save_name']:
+            trainer.save(args['checkpoint_save_name'], save_optimizer=True)
         if model_save_each_filename:
-            trainer.save(model_save_each_filename % epoch, save_optimizer=True)
-        logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", epoch, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, epoch, f1, best_epoch, best_f1)
+            trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True)
+        logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, best_epoch, best_f1)
 
         if args['wandb']:
-            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=epoch)
+            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
             if args['wandb_norm_regex']:
                 watch_regex = re.compile(args['wandb_norm_regex'])
                 for n, p in model.named_parameters():
@@ -583,22 +595,20 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
                         wandb.log({n: torch.linalg.norm(p)})
 
         # don't redo the optimizer a second time if we're not changing the structure
-        if args['multistage'] and epoch in multistage_splits:
-            # TODO: start counting epoch from trainer.epochs_trained for a previously trained model?
-
+        if args['multistage'] and trainer.epochs_trained in multistage_splits:
             # we may be loading a save model from an earlier epoch if the scores stopped increasing
             epochs_trained = trainer.epochs_trained
 
-            stage_pattn_layers, stage_uses_lattn = multistage_splits[epoch]
+            stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]
 
             # when loading the model, let the saved model determine whether it has pattn or lattn
            temp_args = dict(trainer.args)
            temp_args.pop('pattn_num_layers', None)
            temp_args.pop('lattn_d_proj', None)
            # overwriting the old trainer & model will hopefully free memory
-            trainer = Trainer.load(model_filename, temp_args, load_optimizer=False, foundation_cache=foundation_cache)
+            trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False, foundation_cache=foundation_cache)
             model = trainer.model
-            logger.info("Finished stage at epoch %d. Restarting optimizer", epoch)
+            logger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained)
             logger.info("Previous best model was at epoch %d", trainer.epochs_trained)
 
             temp_args = dict(args)
@@ -615,7 +625,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
                 new_model.cuda()
             new_model.copy_with_new_structure(model)
 
-            optimizer = build_optimizer(temp_args, new_model)
+            optimizer = build_optimizer(temp_args, new_model, False)
             scheduler = build_scheduler(temp_args, optimizer)
             trainer = Trainer(temp_args, new_model, optimizer, scheduler, epochs_trained)
             add_grad_clipping(trainer, args['grad_clipping'])
@@ -648,8 +658,6 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
     if old_lr != new_lr:
         logger.info("Updating learning rate from %f to %f", old_lr, new_lr)
 
-    trainer.epochs_trained += 1
-
     # TODO: refactor the logging?
     total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())
     total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index a6a05119..ca642551 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -4,6 +4,7 @@ Collects a few of the conparser utility methods which don't belong elsewhere
 
 from collections import deque
 import copy
+import logging
 
 import torch.nn as nn
 from torch import optim
@@ -15,6 +16,8 @@ DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
 DEFAULT_LEARNING_RHO = 0.9
 DEFAULT_MOMENTUM = { "madgrad": 0.9, "sgd": 0.9 }
 
+logger = logging.getLogger('stanza')
+
 # madgrad experiment for weight decay
 # with learning_rate set to 0.0000007 and momentum 0.9
 # on en_wsj, with a baseline model trained on adadela for 200,
@@ -116,32 +119,56 @@ def build_nonlinearity(nonlinearity):
         return NONLINEARITY[nonlinearity]()
     raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
 
-def build_optimizer(args, model):
+def build_optimizer(args, model, build_simple_adadelta=False):
     """
     Build an optimizer based on the arguments given
+
+    If we are "multistage" training and epochs_trained < epochs // 2,
+    we build an AdaDelta optimizer instead of whatever was requested
+    The build_simple_adadelta parameter controls this
     """
+    if build_simple_adadelta:
+        optim_type = 'adadelta'
+        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
+        learning_rate = DEFAULT_LEARNING_RATES['adadelta']
+        learning_rho = DEFAULT_LEARNING_RHO
+        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
+    else:
+        optim_type = args['optim'].lower()
+        learning_beta2 = args['learning_beta2']
+        learning_eps = args['learning_eps']
+        learning_rate = args['learning_rate']
+        learning_rho = args['learning_rho']
+        momentum = args['momentum']
+        weight_decay = args['weight_decay']
+
     parameters = [param for name, param in model.named_parameters() if not model.is_unsaved_module(name)]
-    if args['optim'].lower() == 'sgd':
-        optimizer = optim.SGD(parameters, lr=args['learning_rate'], momentum=args['momentum'], weight_decay=args['weight_decay'])
-    elif args['optim'].lower() == 'adadelta':
-        optimizer = optim.Adadelta(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], rho=args['learning_rho'])
-    elif args['optim'].lower() == 'adamw':
-        optimizer = optim.AdamW(parameters, lr=args['learning_rate'], betas=(0.9, args['learning_beta2']), eps=args['learning_eps'], weight_decay=args['weight_decay'])
-    elif args['optim'].lower() == 'adabelief':
+    if optim_type == 'sgd':
+        logger.info("Building SGD with lr=%f, momentum=%f, weight_decay=%f", learning_rate, momentum, weight_decay)
+        optimizer = optim.SGD(parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
+    elif optim_type == 'adadelta':
+        logger.info("Building Adadelta with lr=%f, eps=%f, weight_decay=%f, rho=%f", learning_rate, learning_eps, weight_decay, learning_rho)
+        optimizer = optim.Adadelta(parameters, lr=learning_rate, eps=learning_eps, weight_decay=weight_decay, rho=learning_rho)
+    elif optim_type == 'adamw':
+        logger.info("Building AdamW with lr=%f, beta2=%f, eps=%f, weight_decay=%f", learning_rate, learning_beta2, learning_eps, weight_decay)
+        optimizer = optim.AdamW(parameters, lr=learning_rate, betas=(0.9, learning_beta2), eps=learning_eps, weight_decay=weight_decay)
+    elif optim_type == 'adabelief':
         try:
             from adabelief_pytorch import AdaBelief
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
+        logger.info("Building AdaBelief with lr=%f, eps=%f, weight_decay=%f", learning_rate, learning_eps, weight_decay)
         # TODO: make these args
-        optimizer = AdaBelief(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], weight_decouple=False, rectify=False)
-    elif args['optim'].lower() == 'madgrad':
+        optimizer = AdaBelief(parameters, lr=learning_rate, eps=learning_eps, weight_decay=weight_decay, weight_decouple=False, rectify=False)
+    elif optim_type == 'madgrad':
         try:
             import madgrad
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
-        optimizer = madgrad.MADGRAD(parameters, lr=args['learning_rate'], weight_decay=args['weight_decay'], momentum=args['momentum'])
+        logger.info("Building MADGRAD with lr=%f, weight_decay=%f, momentum=%f", learning_rate, weight_decay, momentum)
+        optimizer = madgrad.MADGRAD(parameters, lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
     else:
-        raise ValueError("Unknown optimizer: %s" % args['optim'])
+        raise ValueError("Unknown optimizer: %s" % optim_type)
     return optimizer
 
 def build_scheduler(args, optimizer):
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 91d1d23c..1c588e01 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -242,7 +242,6 @@ def parse_args(args=None):
     parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
     parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
-    parser.add_argument('--save_latest_name', type=str, default=None, help="Save the latest model here regardless of score. Useful for restarting training")
     parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
 
     parser.add_argument('--seed', type=int, default=1234)
@@ -376,7 +375,8 @@ def parse_args(args=None):
     parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
     parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
-    parser.add_argument('--maybe_finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path if it exists. Useful for running in situations where a job is frequently being preempted')
+    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
+    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
     parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
 
     parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
@@ -450,8 +450,13 @@ def parse_args(args=None):
         else:
             raise ValueError("Unknown retag method {}".format(xpos))
 
-    if args['multistage'] and (args['finetune'] or args['maybe_finetune'] or args['relearn_structure']):
-        raise ValueError('Learning multistage from a previously started model is not yet implemented. TODO')
+    model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
+
+    if args['checkpoint']:
+        args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])
+
+    model_save_file = os.path.join(args['save_dir'], model_save_file)
+    args['save_name'] = model_save_file
 
     return args
@@ -468,13 +473,6 @@ def main(args=None):
     logger.info("Running constituency parser in %s mode", args['mode'])
     logger.debug("Using GPU: %s", args['cuda'])
 
-    model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
-    model_save_file = os.path.join(args['save_dir'], model_save_file)
-
-    model_save_latest_file = None
-    if args['save_latest_name']:
-        model_save_latest_file = os.path.join(args['save_dir'], args['save_latest_name'])
-
     model_save_each_file = None
     if args['save_each_name']:
         model_save_each_file = os.path.join(args['save_dir'], args['save_each_name'])
@@ -485,14 +483,12 @@ def main(args=None):
         pieces = os.path.splitext(model_save_each_file)
         model_save_each_file = pieces[0] + "_%4d" + pieces[1]
 
-    model_load_file = model_save_file
+    model_load_file = args['save_name']
     if args['load_name']:
         if os.path.exists(args['load_name']):
             model_load_file = args['load_name']
         else:
             model_load_file = os.path.join(args['save_dir'], args['load_name'])
-    elif args['mode'] == 'train' and args['save_latest_name']:
-        model_load_file = model_save_latest_file
 
     if args['retag_package'] is not None and args['mode'] != 'remove_optimizer':
         if '_' in args['retag_package']:
@@ -509,7 +505,7 @@ def main(args=None):
         retag_pipeline = None
 
     if args['mode'] == 'train':
-        trainer.train(args, model_save_file, model_load_file, model_save_latest_file, model_save_each_file, retag_pipeline)
+        trainer.train(args, model_load_file, model_save_each_file, retag_pipeline)
     elif args['mode'] == 'predict':
         trainer.evaluate(args, model_load_file, retag_pipeline)
     elif args['mode'] == 'remove_optimizer':
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 1e05f249..910b11fb 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -4,6 +4,7 @@ import tempfile
 
 import pytest
 import torch
+from torch import optim
 
 from stanza import Pipeline
@@ -161,12 +162,10 @@ class TestTrainer:
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
-        save_name = os.path.join(args['save_dir'], args['save_name'])
-        latest_name = os.path.join(args['save_dir'], 'latest.pt')
         each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
-        assert not os.path.exists(save_name)
+        assert not os.path.exists(args['save_name'])
         retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
-        tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        tr = trainer.train(args, None, each_name, retag_pipeline)
 
         # check that hooks are in the model if expected
         for p in tr.model.parameters():
             if p.requires_grad:
@@ -176,8 +175,8 @@ class TestTrainer:
                 assert p._backward_hooks is None
 
         # check that the model can be loaded back
-        assert os.path.exists(save_name)
-        tr = trainer.Trainer.load(save_name, load_optimizer=True)
+        assert os.path.exists(args['save_name'])
+        tr = trainer.Trainer.load(args['save_name'], load_optimizer=True)
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained >= 1
@@ -185,7 +184,7 @@ class TestTrainer:
             if p.requires_grad:
                 assert p._backward_hooks is None
 
-        tr = trainer.Trainer.load(latest_name, load_optimizer=True)
+        tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True)
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained == num_epochs
@@ -205,12 +204,13 @@ class TestTrainer:
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname)
 
-    def run_multistage_tests(self, wordvec_pretrain_file, use_lattn):
-        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+    def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
             train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
             args = ['--multistage', '--pattn_num_layers', '1']
             if use_lattn:
                 args += ['--lattn_d_proj', '16']
+            if extra_args:
+                args += extra_args
             args = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
 
             each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -241,7 +241,8 @@ class TestTrainer:
 
         This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
         """
-        self.run_multistage_tests(wordvec_pretrain_file, use_lattn=True)
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
 
     def test_multistage_no_lattn(self, wordvec_pretrain_file):
         """
@@ -249,7 +250,35 @@ class TestTrainer:
 
         This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
         """
-        self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
+
+    def test_multistage_optimizer(self, wordvec_pretrain_file):
+        """
+        Test that the correct optimizers are built for a multistage training process
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            extra_args = ['--optim', 'adamw']
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
+
+            # check that the optimizers which get rebuilt when loading
+            # the models are adadelta for the first half of the
+            # multistage, then adamw
+            each_name = os.path.join(tmpdirname, 'each_%02d.pt')
+            for i in range(1, 3):
+                model_name = each_name % i
+                tr = trainer.Trainer.load(model_name, load_optimizer=True)
+                assert tr.epochs_trained == i
+                assert isinstance(tr.optimizer, optim.Adadelta)
+                # double check that this is actually a valid test
+                assert not isinstance(tr.optimizer, optim.AdamW)
+
+            for i in range(4, 8):
+                model_name = each_name % i
+                tr = trainer.Trainer.load(model_name, load_optimizer=True)
+                assert tr.epochs_trained == i
+                assert isinstance(tr.optimizer, optim.AdamW)
+
     def test_hooks(self, wordvec_pretrain_file):
         """
-- 
cgit v1.2.3
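For reviewers, a minimal standalone sketch (not part of the patch) of the optimizer-selection rule the patch introduces: when multistage training is enabled and fewer than half of the epochs have been trained, a plain Adadelta optimizer is rebuilt regardless of the requested `--optim`; otherwise the requested optimizer is used. The toy model and the `choose_optimizer` helper below are hypothetical; only the `build_simple_adadelta` condition mirrors the patch.

```python
import torch.nn as nn
from torch import optim

def choose_optimizer(args, model, epochs_trained):
    # mirrors the condition used in Trainer.load / build_optimizer in the patch:
    # the first half of a multistage run always gets Adadelta
    build_simple_adadelta = args['multistage'] and epochs_trained < args['epochs'] // 2
    params = model.parameters()
    if build_simple_adadelta:
        return optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)
    if args['optim'].lower() == 'adamw':
        return optim.AdamW(params, lr=args['learning_rate'])
    raise ValueError("Unknown optimizer: %s" % args['optim'])

if __name__ == '__main__':
    toy = nn.Linear(4, 2)  # hypothetical stand-in for the parser model
    args = {'multistage': True, 'epochs': 8, 'optim': 'adamw', 'learning_rate': 2e-4}
    for epoch in (1, 3, 4, 8):
        opt = choose_optimizer(args, toy, epochs_trained=epoch)
        print(epoch, type(opt).__name__)  # epochs 1 and 3 -> Adadelta; 4 and 8 -> AdamW
```

This matches the expectations encoded in `test_multistage_optimizer` above: checkpoints saved for epochs 1-3 reload with Adadelta, epochs 4-7 with AdamW.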