diff options
author | John Bauer <horatio@gmail.com> | 2022-03-09 06:31:33 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-03-09 06:31:33 +0300 |
commit | c8c0c395bcf39bd4cb8bd96c0fef1ba1c0e5e3de (patch) | |
tree | a8e90b19cded903e5f5c8712c46247eda27fbf84 | |
parent | fad0fb7167657e2cad01a7b6fdf11c386c3fb8e3 (diff) |
-rw-r--r-- | stanza/models/constituency/lstm_model.py | 5 |
1 files changed, 1 insertions, 4 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py index 45f002a9..42304a65 100644 --- a/stanza/models/constituency/lstm_model.py +++ b/stanza/models/constituency/lstm_model.py @@ -220,7 +220,7 @@ class LSTMModel(BaseModel, nn.Module): self.lstm_layer_dropout = self.args['lstm_layer_dropout'] # also register a buffer of zeros so that we can always get zeros on the appropriate device - self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers)) + self.register_buffer('word_zeros', torch.zeros(self.hidden_size)) self.register_buffer('transition_zeros', torch.zeros(self.num_lstm_layers, 1, self.transition_hidden_size)) self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size)) @@ -367,8 +367,6 @@ class LSTMModel(BaseModel, nn.Module): initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size) initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size) elif self.constituency_composition == ConstituencyComposition.TREE_LSTM: - self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2, - embedding_dim = self.num_tree_lstm_layers * self.hidden_size) self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout) else: raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition)) @@ -725,7 +723,6 @@ class LSTMModel(BaseModel, nn.Module): cx = self.constituent_zeros tree_hx = self.constituent_zeros[-1, 0, :] tree_cx = self.constituent_zeros[-1, 0, :] - #print(hx.shape, cx.shape, tree_hx.shape, tree_cx.shape) return TreeStack(value=ConstituentNode(None, hx, cx, tree_hx, tree_cx), parent=None, length=1) def get_word(self, word_node): |