github.com/stanfordnlp/stanza.git
Diffstat (limited to 'stanza/tests/constituency/test_trainer.py')
-rw-r--r--    stanza/tests/constituency/test_trainer.py    51
1 file changed, 40 insertions(+), 11 deletions(-)
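
For orientation, a minimal sketch of the call pattern the updated test exercises: trainer.train is given the per-epoch path template and retag pipeline directly, and the final model and checkpoint are reloaded via paths stored in the args dict. The names come from the hunks below; that trainer.train writes to args['save_name'] and args['checkpoint_save_name'] is an assumption based on the asserts in the test, and the setup lines are illustrative only.

    # illustrative setup; in the test, args comes from self.training_args(...)
    each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
    retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)

    # save_name / latest_name are no longer passed positionally
    tr = trainer.train(args, None, each_name, retag_pipeline)

    # the saved model and the latest checkpoint are found via the args dict
    tr = trainer.Trainer.load(args['save_name'], load_optimizer=True)
    tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True)
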
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 1e05f249..910b11fb 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -4,6 +4,7 @@ import tempfile
import pytest
import torch
+from torch import optim
from stanza import Pipeline
@@ -161,12 +162,10 @@ class TestTrainer:
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
- save_name = os.path.join(args['save_dir'], args['save_name'])
- latest_name = os.path.join(args['save_dir'], 'latest.pt')
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
- assert not os.path.exists(save_name)
+ assert not os.path.exists(args['save_name'])
retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
- tr = trainer.train(args, save_name, None, latest_name, each_name, retag_pipeline)
+ tr = trainer.train(args, None, each_name, retag_pipeline)
# check that hooks are in the model if expected
for p in tr.model.parameters():
if p.requires_grad:
@@ -176,8 +175,8 @@ class TestTrainer:
assert p._backward_hooks is None
# check that the model can be loaded back
- assert os.path.exists(save_name)
- tr = trainer.Trainer.load(save_name, load_optimizer=True)
+ assert os.path.exists(args['save_name'])
+ tr = trainer.Trainer.load(args['save_name'], load_optimizer=True)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained >= 1
@@ -185,7 +184,7 @@ class TestTrainer:
if p.requires_grad:
assert p._backward_hooks is None
- tr = trainer.Trainer.load(latest_name, load_optimizer=True)
+ tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True)
assert tr.optimizer is not None
assert tr.scheduler is not None
assert tr.epochs_trained == num_epochs
@@ -205,12 +204,13 @@ class TestTrainer:
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_train_test(wordvec_pretrain_file, tmpdirname)
- def run_multistage_tests(self, wordvec_pretrain_file, use_lattn):
- with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
args = ['--multistage', '--pattn_num_layers', '1']
if use_lattn:
args += ['--lattn_d_proj', '16']
+ if extra_args:
+ args += extra_args
args = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -241,7 +241,8 @@ class TestTrainer:
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
- self.run_multistage_tests(wordvec_pretrain_file, use_lattn=True)
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
def test_multistage_no_lattn(self, wordvec_pretrain_file):
"""
@@ -249,7 +250,35 @@ class TestTrainer:
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
"""
- self.run_multistage_tests(wordvec_pretrain_file, use_lattn=False)
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
+
+ def test_multistage_optimizer(self, wordvec_pretrain_file):
+ """
+ Test that the correct optimizers are built for a multistage training process
+ """
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+ extra_args = ['--optim', 'adamw']
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
+
+ # check that the optimizers which get rebuilt when loading
+ # the models are adadelta for the first half of the
+ # multistage, then adamw
+ each_name = os.path.join(tmpdirname, 'each_%02d.pt')
+ for i in range(1, 3):
+ model_name = each_name % i
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
+ assert tr.epochs_trained == i
+ assert isinstance(tr.optimizer, optim.Adadelta)
+ # double check that this is actually a valid test
+ assert not isinstance(tr.optimizer, optim.AdamW)
+
+ for i in range(4, 8):
+ model_name = each_name % i
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
+ assert tr.epochs_trained == i
+ assert isinstance(tr.optimizer, optim.AdamW)
+
def test_hooks(self, wordvec_pretrain_file):
"""