Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-03-09 06:31:33 +0300
committerJohn Bauer <horatio@gmail.com>2022-03-09 06:31:33 +0300
commitc8c0c395bcf39bd4cb8bd96c0fef1ba1c0e5e3de (patch)
treea8e90b19cded903e5f5c8712c46247eda27fbf84
parentfad0fb7167657e2cad01a7b6fdf11c386c3fb8e3 (diff)
-rw-r--r--stanza/models/constituency/lstm_model.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 45f002a9..42304a65 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -220,7 +220,7 @@ class LSTMModel(BaseModel, nn.Module):
self.lstm_layer_dropout = self.args['lstm_layer_dropout']
# also register a buffer of zeros so that we can always get zeros on the appropriate device
- self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
+ self.register_buffer('word_zeros', torch.zeros(self.hidden_size))
self.register_buffer('transition_zeros', torch.zeros(self.num_lstm_layers, 1, self.transition_hidden_size))
self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
@@ -367,8 +367,6 @@ class LSTMModel(BaseModel, nn.Module):
initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
- self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
- embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
else:
raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))
@@ -725,7 +723,6 @@ class LSTMModel(BaseModel, nn.Module):
cx = self.constituent_zeros
tree_hx = self.constituent_zeros[-1, 0, :]
tree_cx = self.constituent_zeros[-1, 0, :]
- #print(hx.shape, cx.shape, tree_hx.shape, tree_cx.shape)
return TreeStack(value=ConstituentNode(None, hx, cx, tree_hx, tree_cx), parent=None, length=1)
def get_word(self, word_node):