github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-03-05 02:50:31 +0300
committer John Bauer <horatio@gmail.com>  2022-09-15 01:54:08 +0300
commit    8e872ab0b645b181b54d5184e31cda0e479fbb8f (patch)
tree      fda5220f22d0bcbc06a6004f4d4b1aee922411ec
parent    b82d2ab9fdee37e22fe168b6a6e698fae6700b2b (diff)
Attempt to come up with an initial tree_cx for the TREE_LSTM method (branch: con_tree_lstm)
-rw-r--r--  stanza/models/constituency/lstm_model.py  12
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index e65bb236..3edc54f4 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -261,7 +261,6 @@ class LSTMModel(BaseModel, nn.Module):
         self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
         self.register_buffer('transition_zeros', torch.zeros(self.num_lstm_layers, 1, self.transition_hidden_size))
         self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
-        self.register_buffer('word_constituent_zeros', torch.zeros(self.num_tree_lstm_layers, self.hidden_size))

         # possibly add a couple vectors for bookends of the sentence
         # We put the word_start and word_end here, AFTER counting the
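The buffer deleted above was a constant zero initial cell state for the leaves of the tree LSTM; the rest of the patch replaces it with a learned, tag-dependent initialization. For readers unfamiliar with register_buffer: it attaches a non-trainable tensor to the module, so it follows .to(device) and state_dict handling but never receives gradients. A minimal standalone sketch (toy sizes, hypothetical module name, not part of stanza):

    import torch
    import torch.nn as nn

    class ZeroInitDemo(nn.Module):
        # hypothetical toy module for illustration only
        def __init__(self, num_tree_lstm_layers=2, hidden_size=4):
            super().__init__()
            # a buffer moves with the module across devices and is saved in
            # the state_dict, but the optimizer never updates it
            self.register_buffer('word_constituent_zeros',
                                 torch.zeros(num_tree_lstm_layers, hidden_size))

    demo = ZeroInitDemo()
    print(demo.word_constituent_zeros.shape)  # torch.Size([2, 4])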
@@ -408,6 +407,8 @@ class LSTMModel(BaseModel, nn.Module):
         elif self.constituency_composition == ConstituencyComposition.ATTN:
             self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
         elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
+            self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
+                                                             embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
             self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
         else:
             raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))
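The new nn.Embedding holds one learned vector per tag (the +2 presumably reserves ids for UNK and one other special entry), sized so a single lookup covers every tree LSTM layer at once. A minimal sketch with made-up sizes:

    import torch
    import torch.nn as nn

    # toy sizes; in the real model the vocabulary is the tag set
    num_tags, num_tree_lstm_layers, hidden_size = 10, 2, 4

    # one learned vector per tag id; +2 leaves room for reserved ids such as UNK
    constituent_reduce_embedding = nn.Embedding(
        num_embeddings=num_tags + 2,
        embedding_dim=num_tree_lstm_layers * hidden_size)

    tag_id = torch.tensor(3)
    vec = constituent_reduce_embedding(tag_id)
    print(vec.shape)  # torch.Size([8]) = num_tree_lstm_layers * hidden_size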
@@ -652,7 +653,13 @@ class LSTMModel(BaseModel, nn.Module):
         word_node = state.word_queue[state.word_position]
         word = word_node.value
         if self.constituency_composition == ConstituencyComposition.TREE_LSTM:
-            return Constituent(word, word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size), self.word_constituent_zeros)
+            # the UNK tag will be trained thanks to occasionally dropping out tags
+            tag = word.label
+            tree_hx = word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size)
+            tag_tensor = self.tag_tensors[self.tag_map.get(tag, UNK_ID)]
+            tree_cx = self.constituent_reduce_embedding(tag_tensor)
+            tree_cx = tree_cx.view(self.num_tree_lstm_layers, self.hidden_size)
+            return Constituent(word, tree_hx, tree_cx * tree_hx)
         else:
             return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)
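The heart of the change: instead of a zero cell state, the leaf's initial cell is the hidden state gated elementwise by the looked-up tag vector. A standalone sketch of that computation (names follow the patch; sizes and values here are made up):

    import torch
    import torch.nn as nn

    num_tree_lstm_layers, hidden_size = 2, 4
    embedding = nn.Embedding(12, num_tree_lstm_layers * hidden_size)

    word_hx = torch.randn(num_tree_lstm_layers * hidden_size)  # stand-in for word_node.hx
    tree_hx = word_hx.view(num_tree_lstm_layers, hidden_size)

    tag_tensor = torch.tensor(5)  # stand-in for self.tag_tensors[self.tag_map.get(tag, UNK_ID)]
    tree_cx = embedding(tag_tensor).view(num_tree_lstm_layers, hidden_size)

    # initial cell state: hidden state gated by the learned tag vector
    initial_cell = tree_cx * tree_hx
    print(initial_cell.shape)  # torch.Size([2, 4])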
@@ -749,7 +756,6 @@ class LSTMModel(BaseModel, nn.Module):
         label_hx = torch.stack(label_hx).unsqueeze(0)

         max_length = max(len(children) for children in children_lists)
-        zeros = self.word_constituent_zeros
         # stacking will let us do elementwise multiplication faster, hopefully
         node_hx = [[child.tree_hx for child in children] for children in children_lists]
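The comment in the last hunk refers to batching the per-node tree LSTM inputs: children lists have different lengths, so they are padded and stacked, turning many per-node multiplications into one batched operation. A hypothetical standalone illustration of that trick (not the actual stanza code, which this page truncates):

    import torch

    num_tree_lstm_layers, hidden_size = 2, 4
    # three nodes with 1, 3 and 2 children respectively
    children_hx = [[torch.randn(num_tree_lstm_layers, hidden_size) for _ in range(n)]
                   for n in (1, 3, 2)]

    max_length = max(len(children) for children in children_hx)
    pad = torch.zeros(num_tree_lstm_layers, hidden_size)
    stacked = torch.stack([torch.stack(children + [pad] * (max_length - len(children)))
                           for children in children_hx])
    # one batched elementwise op instead of a Python loop per node
    print(stacked.shape)  # torch.Size([3, 3, 2, 4]): nodes x max_length x layers x hidden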