Diffstat (limited to 'stanza/tests/constituency/test_trainer.py')
-rw-r--r--  stanza/tests/constituency/test_trainer.py | 51
1 file changed, 40 insertions, 11 deletions
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 1e05f249..910b11fb 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -4,6 +4,7 @@ import tempfile
 
 import pytest
 import torch
+from torch import optim
 
 from stanza import Pipeline
 
@@ -161,12 +162,10 @@ class TestTrainer:
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
 
-        save_name = os.path.join(args['save_dir'], args['save_name'])
-        latest_name = os.path.join(args['save_dir'], 'latest.pt')
         each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
-        assert not os.path.exists(save_name)
+        assert not os.path.exists(args['save_name'])
         retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
-        tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        tr = trainer.train(args, None, each_name, retag_pipeline)
         # check that hooks are in the model if expected
         for p in tr.model.parameters():
             if p.requires_grad:
@@ -176,8 +175,8 @@ class TestTrainer:
                 assert p._backward_hooks is None
 
         # check that the model can be loaded back
-        assert os.path.exists(save_name)
-        tr = trainer.Trainer.load(save_name, load_optimizer=True)
+        assert os.path.exists(args['save_name'])
+        tr = trainer.Trainer.load(args['save_name'], load_optimizer=True)
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained >= 1
@@ -185,7 +184,7 @@ class TestTrainer:
             if p.requires_grad:
                 assert p._backward_hooks is None
 
-        tr = trainer.Trainer.load(latest_name, load_optimizer=True)
+        tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True)
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained == num_epochs
@@ -205,12 +204,13 @@ class TestTrainer:
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname)
 
-    def run_multistage_tests(self, wordvec_pretrain_file, use_lattn):
-        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+    def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = ['--multistage', '--pattn_num_layers', '1']
         if use_lattn:
             args += ['--lattn_d_proj', '16']
+        if extra_args:
+            args += extra_args
         args = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
 
         each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -241,7 +241,8 @@ class TestTrainer:
         This should start with no pattn or lattn, have pattn in the middle,
         then lattn at the end
         """
-        self.run_multistage_tests(wordvec_pretrain_file, use_lattn=True)
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
 
     def test_multistage_no_lattn(self, wordvec_pretrain_file):
         """
@@ -249,7 +250,35 @@ class TestTrainer:
         This should start with no pattn or lattn, have pattn in the middle,
         then lattn at the end
         """
-        self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
+
+    def test_multistage_optimizer(self, wordvec_pretrain_file):
+        """
+        Test that the correct optimizers are built for a multistage training process
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            extra_args = ['--optim', 'adamw']
+            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
+
+            # check that the optimizers which get rebuilt when loading
+            # the models are adadelta for the first half of the
+            # multistage, then adamw
+            each_name = os.path.join(tmpdirname, 'each_%02d.pt')
+            for i in range(1, 3):
+                model_name = each_name % i
+                tr = trainer.Trainer.load(model_name, load_optimizer=True)
+                assert tr.epochs_trained == i
+                assert isinstance(tr.optimizer, optim.Adadelta)
+                # double check that this is actually a valid test
+                assert not isinstance(tr.optimizer, optim.AdamW)
+
+            for i in range(4, 8):
+                model_name = each_name % i
+                tr = trainer.Trainer.load(model_name, load_optimizer=True)
+                assert tr.epochs_trained == i
+                assert isinstance(tr.optimizer, optim.AdamW)
+
     def test_hooks(self, wordvec_pretrain_file):
         """
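
Note: the new test_multistage_optimizer encodes an expectation about how the trainer rebuilds optimizers across the two stages of multistage training: Adadelta for the first half of the epochs, then the optimizer named by --optim (AdamW here). A minimal sketch of that schedule follows; build_optimizer and its signature are hypothetical, chosen for illustration, and this is not Stanza's actual implementation.

# Sketch of the optimizer schedule asserted by test_multistage_optimizer.
# Assumption: in multistage mode the first half of training always uses
# Adadelta, and the second half switches to the configured optimizer.
from torch import nn, optim

def build_optimizer(model: nn.Module, optim_name: str,
                    epochs_trained: int, num_epochs: int,
                    multistage: bool) -> optim.Optimizer:
    if multistage and epochs_trained < num_epochs // 2:
        # first stage: plain Adadelta regardless of --optim
        return optim.Adadelta(model.parameters())
    if optim_name == 'adamw':
        return optim.AdamW(model.parameters())
    return optim.Adadelta(model.parameters())

With num_epochs=8 as in the test, checkpoints 1-2 would reload with Adadelta and checkpoints 4-7 with AdamW, matching the two assertion loops; checkpoint 3, presumably at the stage boundary, is left unchecked.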