github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>  2022-09-28 04:58:59 +0300
committer  John Bauer <horatio@gmail.com>  2022-09-29 04:00:17 +0300
commit     e4f6b9a9562e713d390f43d96715eccf5ae127bf (patch)
tree       4a578896b792048f3dda56e337adce9976ba4d9f
parent     34a956d61fa8bd11990d2d41525a6c2833c894b5 (diff)
Don't remove pattn... we rely on dropout to prevent pattn from going to 0 (con_vector_dropout)
update test to match
-rw-r--r--  stanza/models/constituency/trainer.py      |  5
-rw-r--r--  stanza/tests/constituency/test_trainer.py  | 16
2 files changed, 9 insertions(+), 12 deletions(-)
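The change keeps the pattn layers during the multistage AdaDelta warm-up and relies on dropout (the con_vector_dropout branch) to keep their contribution from collapsing to zero, rather than removing them outright. A minimal sketch of that idea, assuming PyTorch, with hypothetical module and parameter names (this is not stanza's actual LSTMModel implementation):

import torch
import torch.nn as nn

class WordRepresentation(nn.Module):
    """Toy stand-in for the word-representation path: the pattn output is
    kept and regularized with dropout instead of being removed during warm-up."""
    def __init__(self, base_dim, pattn_dim, pattn_dropout=0.2):
        super().__init__()
        self.pattn = nn.Linear(base_dim, pattn_dim)      # placeholder for the partitioned-attention stack
        self.pattn_dropout = nn.Dropout(pattn_dropout)   # the dropout the commit relies on instead of deleting pattn

    def forward(self, base_vectors):
        pattn_out = self.pattn_dropout(self.pattn(base_vectors))
        # concatenation means word_input_size grows when pattn is present,
        # which is the signal the updated test below keys on
        return torch.cat([base_vectors, pattn_out], dim=-1)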
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index f8015fd8..0da66639 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -415,13 +415,12 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
         scheduler = build_scheduler(args, optimizer)
         trainer = Trainer(args, model, optimizer, scheduler)
     elif args['multistage']:
-        # run adadelta over the model for half the time with no pattn or lattn
+        # run adadelta over the model for half the time with no lattn
         # training then switches to a different optimizer for the rest
         # this works surprisingly well
         logger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
         temp_args = dict(args)
-        # remove the attention layers for the temporary model
-        temp_args['pattn_num_layers'] = 0
+        # remove the lattn layer for the temporary model (need to experiment with varying numbers of pattn, lattn, etc)
         temp_args['lattn_d_proj'] = 0
         temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 910b11fb..364daf8d 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -222,18 +222,16 @@ class TestTrainer:
                 assert tr.epochs_trained == i
                 word_input_sizes[tr.model.word_input_size].append(i)
             if use_lattn:
-                # there should be three stages: no attn, pattn, pattn+lattn
-                assert len(word_input_sizes) == 3
+                # there should be two stages: pattn, pattn+lattn
+                assert len(word_input_sizes) == 2
                 word_input_keys = sorted(word_input_sizes.keys())
-                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
-                assert word_input_sizes[word_input_keys[1]] == [5, 6]
-                assert word_input_sizes[word_input_keys[2]] == [7, 8]
+                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4, 5, 6]
+                assert word_input_sizes[word_input_keys[1]] == [7, 8]
             else:
-                # with no lattn, there are two stages: no attn, pattn
-                assert len(word_input_sizes) == 2
+                # with no lattn, there is just one stage
+                assert len(word_input_sizes) == 1
                 word_input_keys = sorted(word_input_sizes.keys())
-                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
-                assert word_input_sizes[word_input_keys[1]] == [5, 6, 7, 8]
+                assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4, 5, 6, 7, 8]

     def test_multistage_lattn(self, wordvec_pretrain_file):
         """