diff options
author | John Bauer <horatio@gmail.com> | 2022-10-30 09:24:06 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-30 09:20:37 +0300 |
commit | b56c86ed8fabc7f5fdb288443d0a4fa01feb1ee5 (patch) | |
tree | 127fc3b176831332aafc63e57191d7833ddac673 | |
parent | 886365195fa768e84145f9c1334ae089232ac61b (diff) |
Since we just ran into a bug where checkpoints were not correctly loaded, add a test of exactly that functionality
-rw-r--r-- | stanza/tests/constituency/test_trainer.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py index c7104850..f6874725 100644 --- a/stanza/tests/constituency/test_trainer.py +++ b/stanza/tests/constituency/test_trainer.py @@ -1,5 +1,6 @@ from collections import defaultdict import logging +import pathlib import tempfile import pytest @@ -139,7 +140,8 @@ class TestTrainer: def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args): # let's not make the model huge... args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10', - '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname, '--save_dir', tmpdirname, '--save_name', 'test.pt', + '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname, + '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'), '--train_file', train_treebank_file, '--eval_file', eval_treebank_file, '--epoch_size', '6', '--train_batch_size', '3', '--shorthand', 'en_test'] @@ -149,7 +151,7 @@ class TestTrainer: args['wandb'] = None return args - def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False): + def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False): """ Runs a test of the trainer for a few iterations. @@ -165,8 +167,9 @@ class TestTrainer: extra_args += ['--silver_file', str(eval_treebank_file)] args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args) - each_name = os.path.join(args['save_dir'], 'each_%02d.pt') - assert not os.path.exists(args['save_name']) + each_name = args['save_each_name'] + if not exists_ok: + assert not os.path.exists(args['save_name']) retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True) tr = trainer.train(args, None, each_name, retag_pipeline) # check that hooks are in the model if expected @@ -219,6 +222,32 @@ class TestTrainer: with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True) + def test_train_checkpoint(self, wordvec_pretrain_file): + """ + Test the whole thing for a few iterations, then restart + + This tests that the 5th iteration save file is not rewritten + and that the iterations continue to 10 + + TODO: could make it more robust by verifying that only 5 more + epochs are trained. Perhaps a "most recent epochs" could be + saved in the trainer + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False) + save_5 = args['save_each_name'] % 5 + save_10 = args['save_each_name'] % 10 + assert os.path.exists(save_5) + assert not os.path.exists(save_10) + + save_5_stat = pathlib.Path(save_5).stat() + + self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True) + assert os.path.exists(save_5) + assert os.path.exists(save_10) + + assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime + def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None): train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname) args = ['--multistage', '--pattn_num_layers', '1'] |