author | John Bauer <horatio@gmail.com> | 2022-09-18 07:16:43 +0300
committer | John Bauer <horatio@gmail.com> | 2022-09-19 06:27:21 +0300
commit | 0413d552e89fd401165ae8fb722d6de7e76cd20c (patch)
tree | 0c4f4c98b9aa44258827e3945204153ea4df4429
parent | 5a50652c50d52468e8bbe22c7029ecb701267fe9 (diff)
Also refactor a constituent_lstm_stack. The unary transitions are a little wonky to get right, but they kind of suck anyway. (branch: refactor_lstm)
-rw-r--r-- | stanza/models/constituency/lstm_model.py | 56
-rw-r--r-- | stanza/models/constituency/lstm_tree_stack.py | 13
-rw-r--r-- | stanza/models/constituency/trainer.py | 11
-rw-r--r-- | stanza/tests/constituency/test_lstm_model.py | 4
4 files changed, 45 insertions, 39 deletions
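Note: the heart of this change is that LSTMModel's private ConstituentNode namedtuple and hand-rolled constituent_lstm disappear; constituents now live on the shared LSTMTreeStack, whose stack entries wrap a Constituent. A minimal sketch of the resulting double dereference, using field names that appear in this diff (the Node definition here is inferred from its usage in lstm_tree_stack.py, not copied from the repo):

```python
# Sketch of the new nesting: a TreeStack holds LSTMTreeStack Nodes,
# and each Node's value is a Constituent.  Field names from this diff;
# the Node declaration itself is an assumption.
from collections import namedtuple

Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx'])     # LSTMTreeStack entry
Constituent = namedtuple("Constituent", ['value', 'tree_hx'])  # from lstm_model.py

tree = "(ROOT (NP the dog))"                 # stand-in for a parse Tree
node = Node(Constituent(tree, None), None, None)

# get_top_constituent / unary_transform now dereference twice:
assert node.value.value == tree              # Node -> Constituent -> tree
```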
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 0a384801..0da5461e 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -57,8 +57,6 @@ WordNode = namedtuple("WordNode", ['value', 'hx'])
 # We do this to maintain consistency between the different operations,
 # which sometimes result in different shapes
 # This will be unsqueezed in order to put into the next layer if needed
-# lstm_hx & lstm_cx are the hidden & cell states of the LSTM going across constituents
-ConstituentNode = namedtuple("ConstituentNode", ['value', 'tree_hx', 'lstm_hx', 'lstm_cx'])
 Constituent = namedtuple("Constituent", ['value', 'tree_hx'])
 
 # The sentence boundary vectors are marginally useful at best.
@@ -255,7 +253,6 @@ class LSTMModel(BaseModel, nn.Module):
 
         # also register a buffer of zeros so that we can always get zeros on the appropriate device
         self.register_buffer('word_zeros', torch.zeros(self.hidden_size))
-        self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
 
         # possibly add a couple vectors for bookends of the sentence
         # We put the word_start and word_end here, AFTER counting the
@@ -364,10 +361,14 @@ class LSTMModel(BaseModel, nn.Module):
         self.constituent_open_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
                                                        embedding_dim = self.hidden_size)
         nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)
-        if self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING:
-            self.register_parameter('constituent_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.hidden_size, requires_grad=True)))
 
-        self.constituent_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, dropout=self.lstm_layer_dropout)
+        # input_size is hidden_size - could introduce a new constituent_size instead if we liked
+        self.constituent_lstm_stack = LSTMTreeStack(input_size=self.hidden_size,
+                                                    hidden_size=self.hidden_size,
+                                                    num_lstm_layers=self.num_lstm_layers,
+                                                    dropout=self.lstm_layer_dropout,
+                                                    uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
+                                                    input_dropout=self.lstm_input_dropout)
 
         if args['combined_dummy_embedding']:
             self.dummy_embedding = self.constituent_open_embedding
@@ -613,15 +614,7 @@ class LSTMModel(BaseModel, nn.Module):
         """
         Return an initial TreeStack with no constituents
         """
-        if self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING:
-            constituent_start = self.constituent_start_embedding.unsqueeze(0).unsqueeze(0)
-            output, (hx, cx) = self.constituent_lstm(constituent_start)
-            constituent_start = output[0, 0, :]
-        else:
-            constituent_start = self.constituent_zeros[-1, 0, :]
-            hx = self.constituent_zeros
-            cx = self.constituent_zeros
-        return TreeStack(value=ConstituentNode(None, constituent_start, hx, cx), parent=None, length=1)
+        return self.constituent_lstm_stack.initial_state()
 
     def get_word(self, word_node):
         return word_node.value
@@ -639,11 +632,15 @@ class LSTMModel(BaseModel, nn.Module):
 
     def unary_transform(self, constituents, labels):
         # TODO: this can be faster by stacking things
-        top_constituent = constituents.value
+        # the double dereference is because we expect the Constituent
+        # wrapped in an LSTMTreeStack Node
+        top_constituent = constituents.value.value
         for label in reversed(labels):
             # double nested: the Constituent is in a list of just one child
             # and there is just one item in the list (hence the stacking comment)
-            top_constituent = self.build_constituents([(label,)], [[top_constituent]])[0]
+            # the fake Constituent is because normally the Constituent
+            # items are wrapped from the LSTMTreeStack
+            top_constituent = self.build_constituents([(label,)], [[Constituent(top_constituent, None)]])[0]
         return top_constituent
 
     def build_constituents(self, labels, children_lists):
@@ -654,7 +651,7 @@ class LSTMModel(BaseModel, nn.Module):
         children_lists is a list of children that go under each of the new nodes
         lists of each are used so that we can stack operations
         """
-        node_hx = [[child.tree_hx for child in children] for children in children_lists]
+        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
         if (self.constituency_composition == ConstituencyComposition.BILSTM or
             self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
@@ -713,7 +710,7 @@ class LSTMModel(BaseModel, nn.Module):
 
         constituents = []
         for idx, (label, children) in enumerate(zip(labels, children_lists)):
-            children = [child.value for child in children]
+            children = [child.value.value for child in children]
             if isinstance(label, str):
                 node = Tree(label=label, children=children)
             else:
@@ -724,15 +721,6 @@ class LSTMModel(BaseModel, nn.Module):
         return constituents
 
     def push_constituents(self, constituent_stacks, constituents):
-        current_nodes = [stack.value for stack in constituent_stacks]
-
-        constituent_input = torch.stack([x.tree_hx for x in constituents])
-        constituent_input = constituent_input.unsqueeze(0)
-        constituent_input = self.lstm_input_dropout(constituent_input)
-
-        hx = torch.cat([current_node.lstm_hx for current_node in current_nodes], axis=1)
-        cx = torch.cat([current_node.lstm_cx for current_node in current_nodes], axis=1)
-        output, (hx, cx) = self.constituent_lstm(constituent_input, (hx, cx))
         # Another possibility here would be to use output[0, i, :]
         # from the constituency lstm for the value of the new node.
         # This might theoretically make the new constituent include
@@ -742,9 +730,12 @@ class LSTMModel(BaseModel, nn.Module):
         # averaged over 5 trials, had the following loss in accuracy:
         # 150 epochs: 0.8971 to 0.8953
         # 200 epochs: 0.8985 to 0.8964
-        new_stacks = [stack.push(ConstituentNode(constituent.value, constituents[i].tree_hx, hx[:, i:i+1, :], cx[:, i:i+1, :]))
-                      for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))]
-        return new_stacks
+        current_nodes = [stack.value for stack in constituent_stacks]
+
+        constituent_input = torch.stack([x.tree_hx for x in constituents])
+        constituent_input = constituent_input.unsqueeze(0)
+        # the constituents are already Constituent(tree, tree_hx)
+        return self.constituent_lstm_stack.push_states(constituent_stacks, constituents, constituent_input)
 
     def get_top_constituent(self, constituents):
         """
@@ -752,7 +743,8 @@ class LSTMModel(BaseModel, nn.Module):
         sequence, even though it has multiple additional pieces of information
         """
-        constituent_node = constituents.value
+        # TreeStack value -> LSTMTreeStack value -> Constituent value
+        constituent_node = constituents.value.value
         return constituent_node.value
 
     def push_transitions(self, transition_stacks, transitions):
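The rewritten push_constituents keeps the same batching trick as before: every new constituent's tree_hx is stacked into a single LSTM input of shape (seq_len=1, batch, hidden) before being handed to push_states, which runs the shared LSTM. A small shape-check sketch of that step (the hidden size and batch are arbitrary, not the model's):

```python
# Sketch of the batching in push_constituents: one tree_hx vector per new
# constituent is stacked into a single (1, batch, hidden_size) LSTM input.
import torch

hidden_size = 4
tree_hxs = [torch.randn(hidden_size) for _ in range(3)]  # one per constituent
constituent_input = torch.stack(tree_hxs)                # (3, 4)
constituent_input = constituent_input.unsqueeze(0)       # (1, 3, 4)
assert constituent_input.shape == (1, 3, 4)
```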
diff --git a/stanza/models/constituency/lstm_tree_stack.py b/stanza/models/constituency/lstm_tree_stack.py
index c53fe8f3..3832596d 100644
--- a/stanza/models/constituency/lstm_tree_stack.py
+++ b/stanza/models/constituency/lstm_tree_stack.py
@@ -18,17 +18,17 @@ class LSTMTreeStack(nn.Module):
         self.uses_boundary_vector = uses_boundary_vector
 
         # The start embedding needs to be input_size as we put it through the LSTM
-        # A zeros vector needs to be *hidden_size* as we do not put that through the LSTM
         if uses_boundary_vector:
             self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
         else:
-            self.register_buffer('zeros', torch.zeros(num_lstm_layers, 1, hidden_size))
+            self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
+            self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))
 
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
         self.input_dropout = input_dropout
 
-    def initial_state(self):
+    def initial_state(self, initial_value=None):
         """
         Return an initial state, either based on zeros or based on the initial embedding and LSTM
@@ -43,9 +43,10 @@ class LSTMTreeStack(nn.Module):
             output, (hx, cx) = self.lstm(start)
             start = output[0, 0, :]
         else:
-            hx = self.zeros
-            cx = self.zeros
-        return TreeStack(value=Node(None, hx, cx), parent=None, length=1)
+            start = self.input_zeros
+            hx = self.hidden_zeros
+            cx = self.hidden_zeros
+        return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)
 
     def push_states(self, stacks, values, inputs):
         """
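The split of the old `zeros` buffer into `input_zeros` and `hidden_zeros` matters because an nn.LSTM's input and its hidden/cell states have different widths once input_size differs from hidden_size, so one buffer cannot serve both roles. A standalone sketch of those shape constraints (the sizes are illustrative, not taken from the model):

```python
# Why separate zero buffers: the LSTM input has width input_size while
# its (hx, cx) states have width hidden_size per layer.
import torch
import torch.nn as nn

input_size, hidden_size, num_layers = 6, 4, 2
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

input_zeros = torch.zeros(1, 1, input_size)             # (seq_len, batch, input_size)
hidden_zeros = torch.zeros(num_layers, 1, hidden_size)  # per-layer hx / cx
output, (hx, cx) = lstm(input_zeros, (hidden_zeros, hidden_zeros))
assert hx.shape == (num_layers, 1, hidden_size)
```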
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index a2869d39..b5b31d85 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -107,6 +107,17 @@ class Trainer:
         params['model']['transition_lstm_stack.lstm.bias_ih_l1'] = params['model']['transition_lstm.bias_ih_l1']
         params['model']['transition_lstm_stack.lstm.bias_hh_l1'] = params['model']['transition_lstm.bias_hh_l1']
 
+        if 'constituent_start_embedding' in params['model']:
+            params['model']['constituent_lstm_stack.start_embedding'] = params['model']['constituent_start_embedding']
+
+        params['model']['constituent_lstm_stack.lstm.weight_ih_l0'] = params['model']['constituent_lstm.weight_ih_l0']
+        params['model']['constituent_lstm_stack.lstm.weight_hh_l0'] = params['model']['constituent_lstm.weight_hh_l0']
+        params['model']['constituent_lstm_stack.lstm.bias_ih_l0'] = params['model']['constituent_lstm.bias_ih_l0']
+        params['model']['constituent_lstm_stack.lstm.bias_hh_l0'] = params['model']['constituent_lstm.bias_hh_l0']
+        params['model']['constituent_lstm_stack.lstm.weight_ih_l1'] = params['model']['constituent_lstm.weight_ih_l1']
+        params['model']['constituent_lstm_stack.lstm.weight_hh_l1'] = params['model']['constituent_lstm.weight_hh_l1']
+        params['model']['constituent_lstm_stack.lstm.bias_ih_l1'] = params['model']['constituent_lstm.bias_ih_l1']
+        params['model']['constituent_lstm_stack.lstm.bias_hh_l1'] = params['model']['constituent_lstm.bias_hh_l1']
+
         model_type = checkpoint['model_type']
         if model_type == 'LSTM':
             pt = load_pretrain(saved_args.get('wordvec_pretrain_file', None), foundation_cache)
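This trainer block remaps the old parameter names one key at a time so that checkpoints saved before the refactor still load under the new `constituent_lstm_stack.*` names. A loop-based equivalent (a hypothetical condensation for illustration, not code from the repo):

```python
# Hypothetical helper: copy every old constituent_lstm.* weight to its new
# constituent_lstm_stack.lstm.* name, mirroring the explicit assignments above.
def remap_constituent_lstm(model_params):
    if 'constituent_start_embedding' in model_params:
        model_params['constituent_lstm_stack.start_embedding'] = model_params['constituent_start_embedding']
    for layer in (0, 1):
        for weight in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'):
            old_key = 'constituent_lstm.%s_l%d' % (weight, layer)
            new_key = 'constituent_lstm_stack.lstm.%s_l%d' % (weight, layer)
            model_params[new_key] = model_params[old_key]

# usage: remap_constituent_lstm(params['model']) right after loading a checkpoint
```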
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index dce1ce2d..d5ecbe96 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -353,7 +353,9 @@ def check_structure_test(pretrain_file, args1, args2):
         assert torch.allclose(i.lstm_hx, j.lstm_hx)
         assert torch.allclose(i.lstm_cx, j.lstm_cx)
     for i, j in zip(other_states[0].constituents, model_states[0].constituents):
-        assert torch.allclose(i.tree_hx, j.tree_hx)
+        assert (i.value is None) == (j.value is None)
+        if i.value is not None:
+            assert torch.allclose(i.value.tree_hx, j.value.tree_hx)
         assert torch.allclose(i.lstm_hx, j.lstm_hx)
         assert torch.allclose(i.lstm_cx, j.lstm_cx)
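The test now guards the tree_hx comparison because initial_state() seeds each stack with a sentinel Node whose value is None, so only real entries carry a Constituent. A usage sketch, assuming the constructor keywords shown in the lstm_model.py hunk above and the Node field names implied by the test:

```python
# Build a small stack and inspect its sentinel entry; sizes are arbitrary.
import torch.nn as nn
from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack

stack = LSTMTreeStack(input_size=4, hidden_size=4, num_lstm_layers=2,
                      dropout=0.0, uses_boundary_vector=False,
                      input_dropout=nn.Dropout(0.0))
state = stack.initial_state()
assert state.value.value is None               # no Constituent on the empty stack
assert state.value.lstm_hx.shape == (2, 1, 4)  # per-layer zero hidden state
```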