diff options
Diffstat (limited to 'stanza/tests/constituency/test_trainer.py')
-rw-r--r-- | stanza/tests/constituency/test_trainer.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py index 6cb6a5da..ad94ed22 100644 --- a/stanza/tests/constituency/test_trainer.py +++ b/stanza/tests/constituency/test_trainer.py @@ -166,7 +166,14 @@ class TestTrainer: each_name = os.path.join(args['save_dir'], 'each_%2d.pt') assert not os.path.exists(save_name) retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True) - trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline) + tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline) + # check that hooks are in the model if expected + for p in tr.model.parameters(): + if p.requires_grad: + if args['grad_clipping'] is not None: + assert len(p._backward_hooks) == 1 + else: + assert p._backward_hooks is None # check that the model can be loaded back assert os.path.exists(save_name) @@ -174,6 +181,9 @@ class TestTrainer: assert tr.optimizer is not None assert tr.scheduler is not None assert tr.epochs_trained >= 1 + for p in tr.model.parameters(): + if p.requires_grad: + assert p._backward_hooks is None tr = trainer.Trainer.load(latest_name, load_optimizer=True) assert tr.optimizer is not None @@ -240,3 +250,11 @@ class TestTrainer: This should start with no pattn or lattn, have pattn in the middle, then lattn at the end """ self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False) + + def test_hooks(self, wordvec_pretrain_file): + """ + Verify that grad clipping is not saved with the model, but is attached at training time + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = ['--grad_clipping', '25'] + self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) |