
github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-09-09 09:09:52 +0300
committer  John Bauer <horatio@gmail.com>   2022-09-10 00:11:47 +0300
commit     1548568455a13cb2c37ad1cf053225ee068f63b0 (patch)
tree       cffe01c842519ba1dd320daad354cf04dfce61e7
parent     2db43c834bc8adbb8b096cf135f0fab8b8d886cb (diff)
Verify that hooks behave as expected when loading & saving
-rw-r--r--  stanza/models/constituency/trainer.py       9
-rw-r--r--  stanza/tests/constituency/test_trainer.py   20
2 files changed, 24 insertions(+), 5 deletions(-)
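
The behavior under test: gradient clipping in this parser is attached to each trainable parameter as a backward hook when training starts, rather than being stored inside the checkpoint. The helper add_grad_clipping referenced in the diff below is not shown here; the following is a minimal sketch, assuming it registers one torch.clamp hook per trainable parameter:

import torch

def add_grad_clipping(trainer, grad_clipping):
    # assumed sketch, not the body from trainer.py: register one clamp hook
    # per trainable parameter; register_hook turns p._backward_hooks from
    # None into an OrderedDict with a single entry, which is what the new
    # test asserts after training
    if grad_clipping is None:
        return
    for p in trainer.model.parameters():
        if p.requires_grad:
            p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping))
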
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index af703cf5..bd776c1e 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -479,11 +479,13 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
     trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
-    iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)
+    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)
     if args['wandb']:
         wandb.finish()
+    return trainer
+
 TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
 class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used'])):
@@ -507,9 +509,6 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         batch predict the model's interpretation of the current states
         add the errors to the list of things to backprop
         advance the parsing state for each of the trees
-
-    Currently the only method implemented for advancing the parsing state
-    is to use the gold transition.
     """
     model = trainer.model
@@ -622,6 +621,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             add_grad_clipping(trainer, args['grad_clipping'])
             model = new_model
+    return trainer
+
 def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args):
     interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
     random.shuffle(interval_starts)
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 6cb6a5da..ad94ed22 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -166,7 +166,14 @@ class TestTrainer:
         each_name = os.path.join(args['save_dir'], 'each_%2d.pt')
         assert not os.path.exists(save_name)
         retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
-        trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+        # check that hooks are in the model if expected
+        for p in tr.model.parameters():
+            if p.requires_grad:
+                if args['grad_clipping'] is not None:
+                    assert len(p._backward_hooks) == 1
+                else:
+                    assert p._backward_hooks is None
         # check that the model can be loaded back
         assert os.path.exists(save_name)
@@ -174,6 +181,9 @@ class TestTrainer:
         assert tr.optimizer is not None
         assert tr.scheduler is not None
         assert tr.epochs_trained >= 1
+        for p in tr.model.parameters():
+            if p.requires_grad:
+                assert p._backward_hooks is None
         tr = trainer.Trainer.load(latest_name, load_optimizer=True)
         assert tr.optimizer is not None
@@ -240,3 +250,11 @@ class TestTrainer:
         This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
         """
         self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+
+    def test_hooks(self, wordvec_pretrain_file):
+        """
+        Verify that grad clipping is not saved with the model, but is attached at training time
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            args = ['--grad_clipping', '25']
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
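
The new test_hooks test runs run_train_test with --grad_clipping 25 and, per the hunks above, expects the trained in-memory model to carry one backward hook per trainable parameter while a model reloaded from disk carries none. The same round-trip behavior can be reproduced with plain PyTorch, since backward hooks are not part of a state_dict; a minimal sketch (nn.Linear stands in for the parser model, tmp.pt is a hypothetical path):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -25.0, 25.0))
torch.save(model.state_dict(), "tmp.pt")       # hooks are not serialized with the weights

reloaded = nn.Linear(4, 2)
reloaded.load_state_dict(torch.load("tmp.pt"))
assert all(p._backward_hooks is None for p in reloaded.parameters())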