| author | John Bauer <horatio@gmail.com> | 2022-09-28 04:58:59 +0300 |
| --- | --- | --- |
| committer | John Bauer <horatio@gmail.com> | 2022-09-29 04:00:17 +0300 |
| commit | e4f6b9a9562e713d390f43d96715eccf5ae127bf (patch) | |
| tree | 4a578896b792048f3dda56e337adce9976ba4d9f | |
| parent | 34a956d61fa8bd11990d2d41525a6c2833c894b5 (diff) | |
Don't remove pattn... we rely on dropout to prevent pattn from going to 0 (branch: con_vector_dropout)
update test to match
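
The reasoning in the subject line is that the earlier multistage code dropped pattn (partitioned attention) entirely during warmup, whereas this change keeps it and relies on dropout over the pattn features so the rest of the network cannot learn to ignore them. A minimal sketch of that idea, assuming a generic PyTorch attention block (the class name, dimensions, and dropout rate here are illustrative, not the stanza LSTMModel internals):

```python
import torch
import torch.nn as nn

class PattnWithDropout(nn.Module):
    # Sketch only: applying dropout to the attention output means the
    # downstream layers cannot rely on those features always being
    # present, so gradient keeps flowing into the attention weights
    # during warmup instead of the pathway collapsing toward zero.
    def __init__(self, d_model=64, num_heads=4, dropout=0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended, _ = self.attn(x, x, x)
        # concatenate the word features with the (dropped-out) pattn features
        return torch.cat([x, self.dropout(attended)], dim=-1)
```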
-rw-r--r--  stanza/models/constituency/trainer.py     |  5 ++---
-rw-r--r--  stanza/tests/constituency/test_trainer.py | 16 +++++++---------
2 files changed, 9 insertions(+), 12 deletions(-)
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index f8015fd8..0da66639 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -415,13 +415,12 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
         scheduler = build_scheduler(args, optimizer)
         trainer = Trainer(args, model, optimizer, scheduler)
     elif args['multistage']:
-        # run adadelta over the model for half the time with no pattn or lattn
+        # run adadelta over the model for half the time with no lattn
         # training then switches to a different optimizer for the rest
         # this works surprisingly well
         logger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
         temp_args = dict(args)
-        # remove the attention layers for the temporary model
-        temp_args['pattn_num_layers'] = 0
+        # remove the lattn layer for the temporary model (need to experiment with varying numbers of pattn, lattn, etc)
         temp_args['lattn_d_proj'] = 0
         temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 910b11fb..364daf8d 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -222,18 +222,16 @@ class TestTrainer:
                 assert tr.epochs_trained == i
                 word_input_sizes[tr.model.word_input_size].append(i)
             if use_lattn:
-                # there should be three stages: no attn, pattn, pattn+lattn
-                assert len(word_input_sizes) == 3
+                # there should be two stages: pattn, pattn+lattn
+                assert len(word_input_sizes) == 2
                 word_input_keys = sorted(word_input_sizes.keys())
-                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
-                assert word_input_sizes[word_input_keys[1]] == [5, 6]
-                assert word_input_sizes[word_input_keys[2]] == [7, 8]
+                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4, 5, 6]
+                assert word_input_sizes[word_input_keys[1]] == [7, 8]
             else:
-                # with no lattn, there are two stages: no attn, pattn
-                assert len(word_input_sizes) == 2
+                # with no lattn, there is just one stage
+                assert len(word_input_sizes) == 1
                 word_input_keys = sorted(word_input_sizes.keys())
-                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
-                assert word_input_sizes[word_input_keys[1]] == [5, 6, 7, 8]
+                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4, 5, 6, 7, 8]
 
     def test_multistage_lattn(self, wordvec_pretrain_file):
         """
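
Taken together, the multistage schedule after this change is: warm up with AdaDelta for the first half of the epochs with lattn disabled (pattn now stays on), then rebuild the model with lattn, carry the warmed-up weights across, and continue with a different optimizer. A self-contained sketch of that pattern, assuming toy stand-ins (build_model, run_epochs, and the AdamW second-stage choice are illustrative, not the actual stanza code):

```python
import torch
import torch.nn as nn

def build_model(use_lattn: bool) -> nn.Module:
    """Toy stand-in for LSTMModel: one optional layer plays the role of lattn."""
    layers = [nn.Linear(32, 32), nn.ReLU()]
    if use_lattn:
        layers.append(nn.Linear(32, 32))  # the "lattn" addition
    return nn.Sequential(*layers)

def run_epochs(model, opt, n):
    """Toy training loop on random data, just to make the sketch runnable."""
    for _ in range(n):
        loss = model(torch.randn(8, 32)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

def multistage_train(epochs=8):
    # stage 1: AdaDelta warmup without lattn (pattn is no longer removed here)
    model = build_model(use_lattn=False)
    run_epochs(model, torch.optim.Adadelta(model.parameters()), epochs // 2)

    # stage 2: rebuild with lattn, reuse the warmed-up weights, switch optimizer
    full = build_model(use_lattn=True)
    full.load_state_dict(model.state_dict(), strict=False)
    run_epochs(full, torch.optim.AdamW(full.parameters()), epochs - epochs // 2)
    return full
```

The test update above reflects exactly this schedule: with lattn enabled the model's word_input_size changes once over training (two stages), and without lattn it never changes (one stage).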