author    | John Bauer <horatio@gmail.com> | 2022-06-25 19:45:26 +0300
committer | John Bauer <horatio@gmail.com> | 2022-09-24 02:56:24 +0300
commit    | ba5d388904e8be0772faa2983273af8fe426ada7 (patch)
tree      | 2c4e310b78a98697b0dfabff151298495806b2fc
parent    | ffd9ca03eb43b646c76752b3d8ff9432c22b90bd (diff)
Restart transitions when restarting training (con_restart_transitions)
-rw-r--r-- | stanza/models/constituency/lstm_model.py     | 13
-rw-r--r-- | stanza/tests/constituency/test_lstm_model.py | 57
2 files changed, 68 insertions, 2 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index a2c8ddbd..f591ce76 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -468,6 +468,8 @@ class LSTMModel(BaseModel, nn.Module):
         This will rebuild the model in such a way that the outputs
         will be exactly the same as the previous model.
         """
+        transitions_died = (torch.linalg.norm(other.transition_start_embedding).item() < 1e-08 and
+                            torch.linalg.norm(other.transition_embedding.weight).item() < 1e-08)
         for name, other_parameter in other.named_parameters():
             if name.startswith('word_lstm.weight_ih_l0'):
                 # bottom layer shape may have changed from adding a new pattn / lattn block
@@ -477,6 +479,17 @@ class LSTMModel(BaseModel, nn.Module):
                 new_values = torch.zeros_like(my_parameter.data)
                 new_values[:, :copy_size] = other_parameter.data[:, :copy_size]
                 my_parameter.data.copy_(new_values)
+            elif transitions_died and name.startswith('transition_'):
+                # keep our new random layers
+                # the output layer will be zeroed for the transitions
+                # ... it probably already is zero, if this happened
+                continue
+            elif transitions_died and name == 'output_layers.0.weight':
+                my_parameter = self.get_parameter(name)
+                new_values = torch.zeros_like(my_parameter.data)
+                new_values[:, :self.hidden_size] = other_parameter.data[:, :self.hidden_size]
+                new_values[:, -self.hidden_size:] = other_parameter.data[:, -self.hidden_size:]
+                my_parameter.data.copy_(new_values)
             else:
                 self.get_parameter(name).data.copy_(other_parameter.data)
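The transitions_died guard above infers, from the checkpoint alone, that the transition parameters were deliberately zeroed before the restart: both transition embeddings having a norm below 1e-08 is taken as the signal, in which case the new model keeps its fresh random transition layers instead of copying dead ones. A minimal standalone sketch of that norm-based check; block_died and the toy embeddings are illustrative names, not Stanza code:

    import torch
    from torch import nn

    def block_died(*tensors, eps=1e-08):
        # a parameter block counts as "dead" if every tensor in it
        # was zeroed out in the checkpoint (norm below eps)
        return all(torch.linalg.norm(t).item() < eps for t in tensors)

    checkpoint_emb = nn.Embedding(10, 4)
    nn.init.zeros_(checkpoint_emb.weight)  # simulate zeroed-out transitions
    fresh_emb = nn.Embedding(10, 4)        # freshly initialized replacement

    if block_died(checkpoint_emb.weight):
        pass  # keep the fresh random weights, as the elif branch above does
    else:
        fresh_emb.weight.data.copy_(checkpoint_emb.weight.data)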
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 4165b546..80fe7053 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -2,6 +2,7 @@ import os
 
 import pytest
 import torch
+from torch import nn
 
 from stanza.models.common import pretrain
 from stanza.models.common.utils import set_random_seed
@@ -366,12 +367,12 @@ def check_structure_test(pretrain_file, args1, args2):
     model.eval()
 
     assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
-    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert not torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
 
     model.copy_with_new_structure(other)
 
     assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
-    assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
     # the norms will be the same, as the non-zero values are all the same
     assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))
 
@@ -396,6 +397,9 @@ def check_structure_test(pretrain_file, args1, args2):
         assert torch.allclose(i.value.tree_hx, j.value.tree_hx)
         assert torch.allclose(i.lstm_hx, j.lstm_hx)
         assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    model_hx = model(model_states)
+    other_hx = other(other_states)
+    assert torch.allclose(model_hx, other_hx)
 
 def test_copy_with_new_structure_pattn(pretrain_file):
     check_structure_test(pretrain_file,
@@ -411,3 +415,52 @@ def test_copy_with_new_structure_lattn(pretrain_file):
     check_structure_test(pretrain_file,
                          ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],
                          ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
+
+def test_copy_with_new_structure_transitions(pretrain_file):
+    args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '25', '--delta_embedding_dim', '10']
+    set_random_seed(1000, False)
+    other = build_model(pretrain_file, *args)
+    other.eval()
+
+    set_random_seed(1001, False)
+    model = build_model(pretrain_file, *args)
+    model.eval()
+
+    for name, other_parameter in other.named_parameters():
+        if name.startswith("transition_"):
+            nn.init.zeros_(other_parameter.data)
+
+    assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
+    assert not torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
+
+    model.copy_with_new_structure(other)
+
+    assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
+    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
+    # the norms will be the same, as the non-zero values are all the same
+    assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))
+
+    # now, check that applying one transition to an initial state
+    # results in the same values in the output states for both models
+    # as the pattn layer inputs are 0, the output values should be equal
+    shift = [parse_transitions.Shift(), parse_transitions.Shift()]
+    model_states = test_parse_transitions.build_initial_state(model, 2)
+    model_states = parse_transitions.bulk_apply(model, model_states, shift)
+
+    other_states = test_parse_transitions.build_initial_state(other, 2)
+    other_states = parse_transitions.bulk_apply(other, other_states, shift)
+
+    for i, j in zip(other_states[0].word_queue, model_states[0].word_queue):
+        assert torch.allclose(i.hx, j.hx)
+    #for i, j in zip(other_states[0].transitions, model_states[0].transitions):
+    #    assert torch.allclose(i.output, j.output)
+    #    assert torch.allclose(i.lstm_hx, j.lstm_hx)
+    #    assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    for i, j in zip(other_states[0].constituents, model_states[0].constituents):
+        assert torch.allclose(i.tree_hx, j.tree_hx)
+        assert torch.allclose(i.lstm_hx, j.lstm_hx)
+        assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    model_hx = model(model_states)
+    other_hx = other(other_states)
+    assert torch.allclose(model_hx, other_hx)
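When the transitions died, output_layers.0.weight is restored only partially: the leading and trailing hidden_size column blocks are copied from the checkpoint, while the columns fed by the re-randomized transition LSTM stay zero, so the rebuilt model initially scores exactly like the old one; that is what the final model_hx/other_hx assertions in the new test check. A rough sketch of that column slicing, assuming a three-block input of width 3 * hidden_size (the block layout here is an assumption for illustration, not necessarily Stanza's exact ordering):

    import torch
    from torch import nn

    hidden_size = 25
    old_layer = nn.Linear(3 * hidden_size, 40)  # stands in for the checkpoint
    new_layer = nn.Linear(3 * hidden_size, 40)  # stands in for the new model

    new_values = torch.zeros_like(new_layer.weight.data)
    # copy the outer blocks, mirroring the output_layers.0.weight branch above
    new_values[:, :hidden_size] = old_layer.weight.data[:, :hidden_size]
    new_values[:, -hidden_size:] = old_layer.weight.data[:, -hidden_size:]
    new_layer.weight.data.copy_(new_values)

    # the middle block, which reads the restarted transition features, is zero
    assert torch.all(new_layer.weight.data[:, hidden_size:-hidden_size] == 0)

Because that middle block starts at zero, the freshly randomized transition layers contribute nothing to the scores at first and can be relearned from scratch as training resumes.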