author     John Bauer <horatio@gmail.com>  2022-09-09 09:09:52 +0300
committer  John Bauer <horatio@gmail.com>  2022-09-10 00:11:47 +0300
commit     1548568455a13cb2c37ad1cf053225ee068f63b0 (patch)
tree       cffe01c842519ba1dd320daad354cf04dfce61e7 /stanza
parent     2db43c834bc8adbb8b096cf135f0fab8b8d886cb (diff)
Verify that hooks behave as expected when loading & saving
Diffstat (limited to 'stanza')
-rw-r--r--  stanza/models/constituency/trainer.py      9
-rw-r--r--  stanza/tests/constituency/test_trainer.py  20
2 files changed, 24 insertions, 5 deletions
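
Note: the grad clipping hooks exercised by this commit are per-parameter backward hooks, which live on the Parameter objects rather than in the state_dict, so they are lost on save/load and must be re-attached at training time. A minimal sketch of an add_grad_clipping helper in that style (the name matches the call in the diff below; the body here is an assumption based on standard PyTorch usage, not the repository's exact code):

    import torch

    def add_grad_clipping(trainer, grad_clipping):
        # Clamp each gradient in place through a backward hook on the parameter.
        # register_hook() records the hook in p._backward_hooks, which is what
        # the test below inspects; state_dict() never sees it.
        if grad_clipping is not None:
            for p in trainer.model.parameters():
                if p.requires_grad:
                    p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping))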
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index af703cf5..bd776c1e 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -479,11 +479,13 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
     trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
 
-    iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)
+    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)
 
     if args['wandb']:
         wandb.finish()
 
+    return trainer
+
 TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
 
 class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used'])):
@@ -507,9 +509,6 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
     batch predict the model's interpretation of the current states
     add the errors to the list of things to backprop
     advance the parsing state for each of the trees
-
-    Currently the only method implemented for advancing the parsing state
-    is to use the gold transition.
     """
     model = trainer.model
 
@@ -622,6 +621,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             add_grad_clipping(trainer, args['grad_clipping'])
             model = new_model
 
+    return trainer
+
 def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args):
     interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
     random.shuffle(interval_starts)
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 6cb6a5da..ad94ed22 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -166,7 +166,14 @@ class TestTrainer:
         each_name = os.path.join(args['save_dir'], 'each_%2d.pt')
         assert not os.path.exists(save_name)
         retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
-        trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        # check that hooks are in the model if expected
+        for p in tr.model.parameters():
+            if p.requires_grad:
+                if args['grad_clipping'] is not None:
+                    assert len(p._backward_hooks) == 1
+                else:
+                    assert p._backward_hooks is None
 
         # check that the model can be loaded back
         assert os.path.exists(save_name)
@@ -174,6 +181,9 @@ class TestTrainer:
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained >= 1
+        for p in tr.model.parameters():
+            if p.requires_grad:
+                assert p._backward_hooks is None
 
         tr = trainer.Trainer.load(latest_name, load_optimizer=True)
         assert tr.optimizer is not None
@@ -240,3 +250,11 @@ class TestTrainer:
         This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
         """
         self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+
+    def test_hooks(self, wordvec_pretrain_file):
+        """
+        Verify that grad clipping is not saved with the model, but is attached at training time
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            args = ['--grad_clipping', '25']
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
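
Usage note: the invariant test_hooks relies on can be checked standalone; a sketch assuming PyTorch's private _backward_hooks attribute behaves as the test above expects:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    assert p._backward_hooks is None       # fresh parameter: no hooks yet

    # mirror --grad_clipping 25 from the test
    p.register_hook(lambda grad: grad.clamp(-25, 25))
    assert len(p._backward_hooks) == 1     # exactly one hook after registration

    # Saving goes through the state_dict, which stores only tensor data, so a
    # parameter rebuilt from saved data comes back with no hooks attached.
    reloaded = torch.nn.Parameter(p.detach().clone())
    assert reloaded._backward_hooks is None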