github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-09-18 07:16:43 +0300
committer  John Bauer <horatio@gmail.com>   2022-09-19 06:27:21 +0300
commit     0413d552e89fd401165ae8fb722d6de7e76cd20c (patch)
tree       0c4f4c98b9aa44258827e3945204153ea4df4429
parent     5a50652c50d52468e8bbe22c7029ecb701267fe9 (diff)
Also refactor a constituent_lstm_stack. The unary transitions are a little wonky to get right, but they kind of suck anyway (refactor_lstm)
-rw-r--r--  stanza/models/constituency/lstm_model.py      | 56
-rw-r--r--  stanza/models/constituency/lstm_tree_stack.py | 13
-rw-r--r--  stanza/models/constituency/trainer.py         | 11
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py  |  4
4 files changed, 45 insertions, 39 deletions
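
Before the per-file diffs, a rough sketch of the data structure this commit factors out may help: a persistent (linked) stack whose nodes carry both a payload and the LSTM hidden/cell state reached at that depth, so pushing a constituent only needs one LSTM step seeded from the parent node. The class below is illustrative only (Node, Stack, and TinyLSTMStack are stand-ins, not the stanza classes); the real implementation is LSTMTreeStack in stanza/models/constituency/lstm_tree_stack.py, and its push_states operates on a whole batch of stacks at once.

    # minimal, self-contained sketch -- not stanza code
    from collections import namedtuple

    import torch
    import torch.nn as nn

    Node = namedtuple("Node", ["value", "lstm_hx", "lstm_cx"])
    Stack = namedtuple("Stack", ["value", "parent", "length"])

    class TinyLSTMStack(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers=2):
            super().__init__()
            self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                                num_layers=num_layers)
            self.register_buffer("hidden_zeros",
                                 torch.zeros(num_layers, 1, hidden_size))

        def initial_state(self, initial_value=None):
            # the empty stack starts from an all-zero LSTM state
            return Stack(Node(initial_value, self.hidden_zeros, self.hidden_zeros),
                         None, 1)

        def push(self, stack, value, inputs):
            # inputs: one timestep for one stack, shape (1, 1, input_size)
            _, (hx, cx) = self.lstm(inputs, (stack.value.lstm_hx, stack.value.lstm_cx))
            return Stack(Node(value, hx, cx), stack, stack.length + 1)

    stack_module = TinyLSTMStack(input_size=8, hidden_size=8)
    state = stack_module.initial_state()
    state = stack_module.push(state, "NP", torch.randn(1, 1, 8))
    print(state.length, state.value.value)   # 2 NP
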
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 0a384801..0da5461e 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -57,8 +57,6 @@ WordNode = namedtuple("WordNode", ['value', 'hx'])
# We do this to maintain consistency between the different operations,
# which sometimes result in different shapes
# This will be unsqueezed in order to put into the next layer if needed
-# lstm_hx & lstm_cx are the hidden & cell states of the LSTM going across constituents
-ConstituentNode = namedtuple("ConstituentNode", ['value', 'tree_hx', 'lstm_hx', 'lstm_cx'])
Constituent = namedtuple("Constituent", ['value', 'tree_hx'])
# The sentence boundary vectors are marginally useful at best.
@@ -255,7 +253,6 @@ class LSTMModel(BaseModel, nn.Module):
# also register a buffer of zeros so that we can always get zeros on the appropriate device
self.register_buffer('word_zeros', torch.zeros(self.hidden_size))
- self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
# possibly add a couple vectors for bookends of the sentence
# We put the word_start and word_end here, AFTER counting the
@@ -364,10 +361,14 @@ class LSTMModel(BaseModel, nn.Module):
self.constituent_open_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
embedding_dim = self.hidden_size)
nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)
- if self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING:
- self.register_parameter('constituent_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.hidden_size, requires_grad=True)))
+
# input_size is hidden_size - could introduce a new constituent_size instead if we liked
- self.constituent_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, dropout=self.lstm_layer_dropout)
+ self.constituent_lstm_stack = LSTMTreeStack(input_size=self.hidden_size,
+ hidden_size=self.hidden_size,
+ num_lstm_layers=self.num_lstm_layers,
+ dropout=self.lstm_layer_dropout,
+ uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
+ input_dropout=self.lstm_input_dropout)
if args['combined_dummy_embedding']:
self.dummy_embedding = self.constituent_open_embedding
@@ -613,15 +614,7 @@ class LSTMModel(BaseModel, nn.Module):
"""
Return an initial TreeStack with no constituents
"""
- if self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING:
- constituent_start = self.constituent_start_embedding.unsqueeze(0).unsqueeze(0)
- output, (hx, cx) = self.constituent_lstm(constituent_start)
- constituent_start = output[0, 0, :]
- else:
- constituent_start = self.constituent_zeros[-1, 0, :]
- hx = self.constituent_zeros
- cx = self.constituent_zeros
- return TreeStack(value=ConstituentNode(None, constituent_start, hx, cx), parent=None, length=1)
+ return self.constituent_lstm_stack.initial_state()
def get_word(self, word_node):
return word_node.value
@@ -639,11 +632,15 @@ class LSTMModel(BaseModel, nn.Module):
def unary_transform(self, constituents, labels):
# TODO: this can be faster by stacking things
- top_constituent = constituents.value
+ # the double dereference is because we expect the Constituent
+ # wrapped in an LSTMTreeStack Node
+ top_constituent = constituents.value.value
for label in reversed(labels):
# double nested: the Constituent is in a list of just one child
# and there is just one item in the list (hence the stacking comment)
- top_constituent = self.build_constituents([(label,)], [[top_constituent]])[0]
+ # the fake Constituent is because normally the Constituent
+ # items are wrapped from the LSTMTreeStack
+ top_constituent = self.build_constituents([(label,)], [[Constituent(top_constituent, None)]])[0]
return top_constituent
def build_constituents(self, labels, children_lists):
@@ -654,7 +651,7 @@ class LSTMModel(BaseModel, nn.Module):
children_lists is a list of children that go under each of the new nodes
lists of each are used so that we can stack operations
"""
- node_hx = [[child.tree_hx for child in children] for children in children_lists]
+ node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
if (self.constituency_composition == ConstituencyComposition.BILSTM or
self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
@@ -713,7 +710,7 @@ class LSTMModel(BaseModel, nn.Module):
constituents = []
for idx, (label, children) in enumerate(zip(labels, children_lists)):
- children = [child.value for child in children]
+ children = [child.value.value for child in children]
if isinstance(label, str):
node = Tree(label=label, children=children)
else:
@@ -724,15 +721,6 @@ class LSTMModel(BaseModel, nn.Module):
return constituents
def push_constituents(self, constituent_stacks, constituents):
- current_nodes = [stack.value for stack in constituent_stacks]
-
- constituent_input = torch.stack([x.tree_hx for x in constituents])
- constituent_input = constituent_input.unsqueeze(0)
- constituent_input = self.lstm_input_dropout(constituent_input)
-
- hx = torch.cat([current_node.lstm_hx for current_node in current_nodes], axis=1)
- cx = torch.cat([current_node.lstm_cx for current_node in current_nodes], axis=1)
- output, (hx, cx) = self.constituent_lstm(constituent_input, (hx, cx))
# Another possibility here would be to use output[0, i, :]
# from the constituency lstm for the value of the new node.
# This might theoretically make the new constituent include
@@ -742,9 +730,12 @@ class LSTMModel(BaseModel, nn.Module):
# averaged over 5 trials, had the following loss in accuracy:
# 150 epochs: 0.8971 to 0.8953
# 200 epochs: 0.8985 to 0.8964
- new_stacks = [stack.push(ConstituentNode(constituent.value, constituents[i].tree_hx, hx[:, i:i+1, :], cx[:, i:i+1, :]))
- for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))]
- return new_stacks
+ current_nodes = [stack.value for stack in constituent_stacks]
+
+ constituent_input = torch.stack([x.tree_hx for x in constituents])
+ constituent_input = constituent_input.unsqueeze(0)
+ # the constituents are already Constituent(tree, tree_hx)
+ return self.constituent_lstm_stack.push_states(constituent_stacks, constituents, constituent_input)
def get_top_constituent(self, constituents):
"""
@@ -752,7 +743,8 @@ class LSTMModel(BaseModel, nn.Module):
sequence, even though it has multiple additional pieces of
information
"""
- constituent_node = constituents.value
+ # TreeStack value -> LSTMTreeStack value -> Constituent value
+ constituent_node = constituents.value.value
return constituent_node.value
def push_transitions(self, transition_stacks, transitions):
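
The unary_transform and get_top_constituent hunks above both dereference twice: the stack node now wraps a Constituent, which in turn wraps the tree. A tiny illustration of the layering (field names follow the namedtuples in the diff, but none of this is stanza code):

    from collections import namedtuple

    Constituent = namedtuple("Constituent", ["value", "tree_hx"])   # tree plus its encoding
    Node = namedtuple("Node", ["value", "lstm_hx", "lstm_cx"])      # node held by the LSTM stack
    Stack = namedtuple("Stack", ["value", "parent", "length"])      # persistent stack

    tree = "(NP (DT the) (NN cat))"        # stand-in for a parse Tree object
    constituent = Constituent(tree, None)  # tree_hx omitted in this sketch
    stack = Stack(Node(constituent, None, None), None, 1)

    # TreeStack value -> LSTMTreeStack node value -> Constituent value
    assert stack.value.value is constituent   # what unary_transform unwraps
    assert stack.value.value.value is tree    # what get_top_constituent returns
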
diff --git a/stanza/models/constituency/lstm_tree_stack.py b/stanza/models/constituency/lstm_tree_stack.py
index c53fe8f3..3832596d 100644
--- a/stanza/models/constituency/lstm_tree_stack.py
+++ b/stanza/models/constituency/lstm_tree_stack.py
@@ -18,17 +18,17 @@ class LSTMTreeStack(nn.Module):
self.uses_boundary_vector = uses_boundary_vector
# The start embedding needs to be input_size as we put it through the LSTM
- # A zeros vector needs to be *hidden_size* as we do not put that through the LSTM
if uses_boundary_vector:
self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
else:
- self.register_buffer('zeros', torch.zeros(num_lstm_layers, 1, hidden_size))
+ self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
+ self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
self.input_dropout = input_dropout
- def initial_state(self):
+ def initial_state(self, initial_value=None):
"""
Return an initial state, either based on zeros or based on the initial embedding and LSTM
@@ -43,9 +43,10 @@ class LSTMTreeStack(nn.Module):
output, (hx, cx) = self.lstm(start)
start = output[0, 0, :]
else:
- hx = self.zeros
- cx = self.zeros
- return TreeStack(value=Node(None, hx, cx), parent=None, length=1)
+ start = self.input_zeros
+ hx = self.hidden_zeros
+ cx = self.hidden_zeros
+ return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)
def push_states(self, stacks, values, inputs):
"""
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index a2869d39..b5b31d85 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -107,6 +107,17 @@ class Trainer:
params['model']['transition_lstm_stack.lstm.bias_ih_l1'] = params['model']['transition_lstm.bias_ih_l1']
params['model']['transition_lstm_stack.lstm.bias_hh_l1'] = params['model']['transition_lstm.bias_hh_l1']
+ if 'constituent_start_embedding' in params['model']:
+ params['model']['constituent_lstm_stack.start_embedding'] = params['model']['constituent_start_embedding']
+ params['model']['constituent_lstm_stack.lstm.weight_ih_l0'] = params['model']['constituent_lstm.weight_ih_l0']
+ params['model']['constituent_lstm_stack.lstm.weight_hh_l0'] = params['model']['constituent_lstm.weight_hh_l0']
+ params['model']['constituent_lstm_stack.lstm.bias_ih_l0'] = params['model']['constituent_lstm.bias_ih_l0']
+ params['model']['constituent_lstm_stack.lstm.bias_hh_l0'] = params['model']['constituent_lstm.bias_hh_l0']
+ params['model']['constituent_lstm_stack.lstm.weight_ih_l1'] = params['model']['constituent_lstm.weight_ih_l1']
+ params['model']['constituent_lstm_stack.lstm.weight_hh_l1'] = params['model']['constituent_lstm.weight_hh_l1']
+ params['model']['constituent_lstm_stack.lstm.bias_ih_l1'] = params['model']['constituent_lstm.bias_ih_l1']
+ params['model']['constituent_lstm_stack.lstm.bias_hh_l1'] = params['model']['constituent_lstm.bias_hh_l1']
+
model_type = checkpoint['model_type']
if model_type == 'LSTM':
pt = load_pretrain(saved_args.get('wordvec_pretrain_file', None), foundation_cache)
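
The trainer hunk exists so that checkpoints saved before the refactor still load: the old constituent_lstm.* weights and constituent_start_embedding are copied under their new constituent_lstm_stack.* names. The same remapping written as a loop, purely for illustration (no helper like this exists in stanza; the key names and the two-layer assumption come from the hunk above):

    def remap_constituent_lstm(params, num_layers=2):
        model = params['model']
        if 'constituent_start_embedding' in model:
            model['constituent_lstm_stack.start_embedding'] = model['constituent_start_embedding']
        for layer in range(num_layers):
            for name in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'):
                old_key = 'constituent_lstm.%s_l%d' % (name, layer)
                new_key = 'constituent_lstm_stack.lstm.%s_l%d' % (name, layer)
                if old_key in model:
                    model[new_key] = model[old_key]
        return params
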
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index dce1ce2d..d5ecbe96 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -353,7 +353,9 @@ def check_structure_test(pretrain_file, args1, args2):
assert torch.allclose(i.lstm_hx, j.lstm_hx)
assert torch.allclose(i.lstm_cx, j.lstm_cx)
for i, j in zip(other_states[0].constituents, model_states[0].constituents):
- assert torch.allclose(i.tree_hx, j.tree_hx)
+ assert (i.value is None) == (j.value is None)
+ if i.value is not None:
+ assert torch.allclose(i.value.tree_hx, j.value.tree_hx)
assert torch.allclose(i.lstm_hx, j.lstm_hx)
assert torch.allclose(i.lstm_cx, j.lstm_cx)