github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-10-30 09:24:06 +0300
committer John Bauer <horatio@gmail.com>  2022-10-30 09:20:37 +0300
commit    b56c86ed8fabc7f5fdb288443d0a4fa01feb1ee5
tree      127fc3b176831332aafc63e57191d7833ddac673
parent    886365195fa768e84145f9c1334ae089232ac61b
Since we just ran into a bug where checkpoints were not correctly loaded, add a test of exactly that functionality
-rw-r--r--  stanza/tests/constituency/test_trainer.py | 37
1 file changed, 33 insertions(+), 4 deletions(-)
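The functionality under test is the trainer's per-epoch checkpointing: each epoch's model is written to a filename built from a printf-style template ('each_%02d.pt'), and a restarted run should pick up the newest checkpoint instead of retraining from scratch. A minimal sketch of that pattern, with hypothetical helper names (this is not stanza's actual trainer code):

    import os
    import torch

    def run_one_epoch(model):
        """Placeholder for one epoch of actual training."""
        pass

    def train_with_checkpoints(model, num_epochs, each_name):
        # Resume from the newest checkpoint that already exists, if any.
        start_epoch = 0
        for epoch in range(num_epochs, 0, -1):
            if os.path.exists(each_name % epoch):
                model.load_state_dict(torch.load(each_name % epoch))
                start_epoch = epoch
                break
        # Train the remaining epochs, saving one checkpoint per epoch.
        for epoch in range(start_epoch + 1, num_epochs + 1):
            run_one_epoch(model)
            torch.save(model.state_dict(), each_name % epoch)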
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index c7104850..f6874725 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -1,5 +1,6 @@
 from collections import defaultdict
 import logging
+import pathlib
 import tempfile
 
 import pytest
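pathlib is imported for the new checkpoint test below: comparing a file's st_mtime before and after the restarted run proves the epoch-5 checkpoint was not rewritten. In outline (save_5 as in the test):

    import pathlib
    before = pathlib.Path(save_5).stat().st_mtime
    # ... run training a second time in the same save_dir ...
    assert pathlib.Path(save_5).stat().st_mtime == before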
@@ -139,7 +140,8 @@ class TestTrainer:
     def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args):
         # let's not make the model huge...
         args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10',
-                '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname, '--save_dir', tmpdirname, '--save_name', 'test.pt',
+                '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname,
+                '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'),
                 '--train_file', train_treebank_file, '--eval_file', eval_treebank_file,
                 '--epoch_size', '6', '--train_batch_size', '3',
                 '--shorthand', 'en_test']
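The new --save_each_name value is a printf-style template; judging from the test below, the trainer fills it in with the epoch number so that per-epoch checkpoints get distinct names:

    >>> import os
    >>> os.path.join('/tmp/run', 'each_%02d.pt') % 5
    '/tmp/run/each_05.pt'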
@@ -149,7 +151,7 @@ class TestTrainer:
             args['wandb'] = None
         return args
 
-    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False):
+    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False):
         """
         Runs a test of the trainer for a few iterations.
 
@@ -165,8 +167,9 @@ class TestTrainer:
             extra_args += ['--silver_file', str(eval_treebank_file)]
         args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
 
-        each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
-        assert not os.path.exists(args['save_name'])
+        each_name = args['save_each_name']
+        if not exists_ok:
+            assert not os.path.exists(args['save_name'])
         retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
         tr = trainer.train(args, None, each_name, retag_pipeline)
         # check that hooks are in the model if expected
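The new exists_ok flag enables the restart scenario: on a second run in the same directory the save file legitimately already exists, so the fail-fast assertion has to be skipped. The intended call pattern, mirroring the test added below:

    args = self.run_train_test(wordvec_pretrain_file, tmpdirname)  # fresh directory
    self.run_train_test(wordvec_pretrain_file, tmpdirname,
                        num_epochs=10, exists_ok=True)             # restart, resume from checkpoints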
@@ -219,6 +222,32 @@ class TestTrainer:
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
 
+    def test_train_checkpoint(self, wordvec_pretrain_file):
+        """
+        Test the whole thing for a few iterations, then restart
+
+        This tests that the 5th iteration save file is not rewritten
+        and that the iterations continue to 10
+
+        TODO: could make it more robust by verifying that only 5 more
+        epochs are trained. Perhaps a "most recent epochs" could be
+        saved in the trainer
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            args = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False)
+            save_5 = args['save_each_name'] % 5
+            save_10 = args['save_each_name'] % 10
+            assert os.path.exists(save_5)
+            assert not os.path.exists(save_10)
+
+            save_5_stat = pathlib.Path(save_5).stat()
+
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True)
+            assert os.path.exists(save_5)
+            assert os.path.exists(save_10)
+
+            assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime
+
     def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = ['--multistage', '--pattn_num_layers', '1']
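To run just the new test, standard pytest selection syntax works, assuming the constituency test prerequisites (such as the wordvec_pretrain_file fixture data) are available locally:

    pytest stanza/tests/constituency/test_trainer.py -k test_train_checkpoint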