github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-06-25 19:45:26 +0300
committer  John Bauer <horatio@gmail.com>   2022-09-24 02:56:24 +0300
commit     ba5d388904e8be0772faa2983273af8fe426ada7 (patch)
tree       2c4e310b78a98697b0dfabff151298495806b2fc
parent     ffd9ca03eb43b646c76752b3d8ff9432c22b90bd (diff)

Restart transitions when restarting training (con_restart_transitions)

-rw-r--r--  stanza/models/constituency/lstm_model.py      13
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py   57
2 files changed, 68 insertions, 2 deletions
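In brief: LSTMModel.copy_with_new_structure now checks whether the source model's transition embeddings have collapsed to near zero. If so, the copy keeps the destination model's freshly initialized transition parameters instead of overwriting them with the dead ones, and rebuilds output_layers.0.weight so that only the first and last hidden_size blocks of columns are carried over. The following is a minimal sketch of the detection heuristic, not part of the commit; the embedding sizes are invented and only the torch calls mirror the new transitions_died check in the diff below.

    # Illustrative sketch, not from the commit: flag "dead" (all-zero)
    # transition weights by their L2 norm, as the new transitions_died
    # check in copy_with_new_structure does.  Sizes are invented.
    import torch
    from torch import nn

    transition_embedding = nn.Embedding(10, 20)
    nn.init.zeros_(transition_embedding.weight)   # simulate a collapsed checkpoint

    transitions_died = torch.linalg.norm(transition_embedding.weight).item() < 1e-08
    print(transitions_died)  # True: nothing here is worth copying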
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index a2c8ddbd..f591ce76 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -468,6 +468,8 @@ class LSTMModel(BaseModel, nn.Module):
         This will rebuild the model in such a way that the outputs will be
         exactly the same as the previous model.
         """
+        transitions_died = (torch.linalg.norm(other.transition_start_embedding).item() < 1e-08 and
+                            torch.linalg.norm(other.transition_embedding.weight).item() < 1e-08)
         for name, other_parameter in other.named_parameters():
             if name.startswith('word_lstm.weight_ih_l0'):
                 # bottom layer shape may have changed from adding a new pattn / lattn block
@@ -477,6 +479,17 @@ class LSTMModel(BaseModel, nn.Module):
                 new_values = torch.zeros_like(my_parameter.data)
                 new_values[:, :copy_size] = other_parameter.data[:, :copy_size]
                 my_parameter.data.copy_(new_values)
+            elif transitions_died and name.startswith('transition_'):
+                # keep our new random layers
+                # the output layer will be zeroed for the transitions
+                # ... it probably already is zero, if this happened
+                continue
+            elif transitions_died and name == 'output_layers.0.weight':
+                my_parameter = self.get_parameter(name)
+                new_values = torch.zeros_like(my_parameter.data)
+                new_values[:, :self.hidden_size] = other_parameter.data[:, :self.hidden_size]
+                new_values[:, -self.hidden_size:] = other_parameter.data[:, -self.hidden_size:]
+                my_parameter.data.copy_(new_values)
             else:
                 self.get_parameter(name).data.copy_(other_parameter.data)
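A note on the second new elif above: when the transitions are dead, output_layers.0.weight is rebuilt by copying only the first hidden_size and the last hidden_size columns from the source weight and leaving the middle columns at zero. The sketch below shows just that slicing pattern; treating the output layer's input as three hidden_size-wide blocks with the transition block in the middle is an inference from the slicing, not something the diff states.

    # Sketch only: copy the first and last hidden_size columns of a weight
    # matrix and leave the middle block zeroed, mirroring the new branch
    # for 'output_layers.0.weight'.  hidden_size and shapes are invented.
    import torch

    hidden_size = 4
    other_weight = torch.randn(3, hidden_size * 3)   # stand-in for the source weight

    new_values = torch.zeros_like(other_weight)
    new_values[:, :hidden_size] = other_weight[:, :hidden_size]
    new_values[:, -hidden_size:] = other_weight[:, -hidden_size:]

    # the middle hidden_size columns stay zero, so the dead transition
    # features cannot affect the copied model's outputs
    assert torch.equal(new_values[:, hidden_size:-hidden_size],
                       torch.zeros(3, hidden_size))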
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 4165b546..80fe7053 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -2,6 +2,7 @@ import os
 import pytest
 import torch
+from torch import nn

 from stanza.models.common import pretrain
 from stanza.models.common.utils import set_random_seed
@@ -366,12 +367,12 @@ def check_structure_test(pretrain_file, args1, args2):
     model.eval()

     assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
-    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert not torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)

     model.copy_with_new_structure(other)

     assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
-    assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)

     # the norms will be the same, as the non-zero values are all the same
     assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))
@@ -396,6 +397,9 @@ def check_structure_test(pretrain_file, args1, args2):
         assert torch.allclose(i.value.tree_hx, j.value.tree_hx)
         assert torch.allclose(i.lstm_hx, j.lstm_hx)
         assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    model_hx = model(model_states)
+    other_hx = other(other_states)
+    assert torch.allclose(model_hx, other_hx)

 def test_copy_with_new_structure_pattn(pretrain_file):
     check_structure_test(pretrain_file,
@@ -411,3 +415,52 @@ def test_copy_with_new_structure_lattn(pretrain_file):
     check_structure_test(pretrain_file,
                          ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],
                          ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
+
+def test_copy_with_new_structure_transitions(pretrain_file):
+    args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '25', '--delta_embedding_dim', '10']
+    set_random_seed(1000, False)
+    other = build_model(pretrain_file, *args)
+    other.eval()
+
+    set_random_seed(1001, False)
+    model = build_model(pretrain_file, *args)
+    model.eval()
+
+    for name, other_parameter in other.named_parameters():
+        if name.startswith("transition_"):
+            nn.init.zeros_(other_parameter.data)
+
+    assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
+    assert not torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
+
+    model.copy_with_new_structure(other)
+
+    assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
+    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
+    assert torch.allclose(model.output_layers[1].weight, other.output_layers[1].weight)
+    # the norms will be the same, as the non-zero values are all the same
+    assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))
+
+    # now, check that applying one transition to an initial state
+    # results in the same values in the output states for both models
+    # as the pattn layer inputs are 0, the output values should be equal
+    shift = [parse_transitions.Shift(), parse_transitions.Shift()]
+    model_states = test_parse_transitions.build_initial_state(model, 2)
+    model_states = parse_transitions.bulk_apply(model, model_states, shift)
+
+    other_states = test_parse_transitions.build_initial_state(other, 2)
+    other_states = parse_transitions.bulk_apply(other, other_states, shift)
+
+    for i, j in zip(other_states[0].word_queue, model_states[0].word_queue):
+        assert torch.allclose(i.hx, j.hx)
+    #for i, j in zip(other_states[0].transitions, model_states[0].transitions):
+    #    assert torch.allclose(i.output, j.output)
+    #    assert torch.allclose(i.lstm_hx, j.lstm_hx)
+    #    assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    for i, j in zip(other_states[0].constituents, model_states[0].constituents):
+        assert torch.allclose(i.tree_hx, j.tree_hx)
+        assert torch.allclose(i.lstm_hx, j.lstm_hx)
+        assert torch.allclose(i.lstm_cx, j.lstm_cx)
+    model_hx = model(model_states)
+    other_hx = other(other_states)
+    assert torch.allclose(model_hx, other_hx)
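
One more note on the shared assertion "the norms will be the same, as the non-zero values are all the same": comparing torch.linalg.norm of word_lstm.weight_ih_l0 rather than the tensors themselves works because zero-padding columns does not change the Frobenius norm, even when the restructured model has a wider input. A tiny standalone illustration, not from the commit, with arbitrary shapes:

    # Illustration only: a matrix padded with zero columns keeps the same
    # Frobenius norm as the original, so comparing norms is a shape-agnostic
    # way to check that the copied non-zero values match.
    import torch

    small = torch.randn(5, 8)
    large = torch.zeros(5, 12)
    large[:, :8] = small                  # the extra columns stay zero

    assert torch.allclose(torch.linalg.norm(small), torch.linalg.norm(large))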