| author | John Bauer <horatio@gmail.com> | 2022-03-05 02:50:31 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-09-15 01:54:08 +0300 |
| commit | 8e872ab0b645b181b54d5184e31cda0e479fbb8f (patch) | |
| tree | fda5220f22d0bcbc06a6004f4d4b1aee922411ec | |
| parent | b82d2ab9fdee37e22fe168b6a6e698fae6700b2b (diff) | |
Attempt to come up with an initial tree_cx for the TREE_LSTM method (branch: con_tree_lstm)
-rw-r--r-- | stanza/models/constituency/lstm_model.py | 12 |
1 file changed, 9 insertions(+), 3 deletions(-)
```diff
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index e65bb236..3edc54f4 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -261,7 +261,6 @@ class LSTMModel(BaseModel, nn.Module):
         self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
         self.register_buffer('transition_zeros', torch.zeros(self.num_lstm_layers, 1, self.transition_hidden_size))
         self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
-        self.register_buffer('word_constituent_zeros', torch.zeros(self.num_tree_lstm_layers, self.hidden_size))
 
         # possibly add a couple vectors for bookends of the sentence
         # We put the word_start and word_end here, AFTER counting the
@@ -408,6 +407,8 @@ class LSTMModel(BaseModel, nn.Module):
         elif self.constituency_composition == ConstituencyComposition.ATTN:
             self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
         elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
+            self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
+                                                             embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
             self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
         else:
             raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))
@@ -652,7 +653,13 @@ class LSTMModel(BaseModel, nn.Module):
         word_node = state.word_queue[state.word_position]
         word = word_node.value
         if self.constituency_composition == ConstituencyComposition.TREE_LSTM:
-            return Constituent(word, word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size), self.word_constituent_zeros)
+            # the UNK tag will be trained thanks to occasionally dropping out tags
+            tag = word.label
+            tree_hx = word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size)
+            tag_tensor = self.tag_tensors[self.tag_map.get(tag, UNK_ID)]
+            tree_cx = self.constituent_reduce_embedding(tag_tensor)
+            tree_cx = tree_cx.view(self.num_tree_lstm_layers, self.hidden_size)
+            return Constituent(word, tree_hx, tree_cx * tree_hx)
         else:
             return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)
 
@@ -749,7 +756,6 @@ class LSTMModel(BaseModel, nn.Module):
         label_hx = torch.stack(label_hx).unsqueeze(0)
 
         max_length = max(len(children) for children in children_lists)
-        zeros = self.word_constituent_zeros
 
         # stacking will let us do elementwise multiplication faster, hopefully
         node_hx = [[child.tree_hx for child in children] for children in children_lists]
```
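For context, here is a minimal self-contained sketch of the idea behind this change: instead of initializing each word's tree-LSTM cell state from a zeros buffer, a learned embedding is looked up for the word's POS tag and gated elementwise by the word's hidden state. This is an illustration, not the Stanza code itself; `NUM_TAGS`, `tag_map`, and `UNK_ID` below are hypothetical stand-ins for the model's real fields (the actual code sizes the embedding as `len(tags) + 2`).

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the model's real configuration
NUM_TAGS = 40
NUM_TREE_LSTM_LAYERS = 2
HIDDEN_SIZE = 128
UNK_ID = 0
tag_map = {"NN": 1, "VB": 2}  # toy tag -> embedding index mapping

# one learned vector per tag, wide enough to fill every tree-LSTM layer
reduce_embedding = nn.Embedding(NUM_TAGS, NUM_TREE_LSTM_LAYERS * HIDDEN_SIZE)

def initial_tree_cx(tag: str, tree_hx: torch.Tensor) -> torch.Tensor:
    """Build the initial tree-LSTM cell state for one word from its tag."""
    tag_idx = torch.tensor([tag_map.get(tag, UNK_ID)])
    tree_cx = reduce_embedding(tag_idx).view(NUM_TREE_LSTM_LAYERS, HIDDEN_SIZE)
    # gate the learned tag vector elementwise by the word's hidden state,
    # mirroring the commit's `tree_cx * tree_hx`
    return tree_cx * tree_hx

tree_hx = torch.randn(NUM_TREE_LSTM_LAYERS, HIDDEN_SIZE)
print(initial_tree_cx("NN", tree_hx).shape)  # torch.Size([2, 128])
```

One way to read the elementwise product is as a tag-conditioned gate: words sharing a POS tag share a learned initialization pattern, while each word's own hidden state still differentiates them. The comment in the diff notes that occasionally dropping out tags during training is what gives the UNK row useful gradients.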