author     John Bauer <horatio@gmail.com>            2022-09-29 19:47:49 +0300
committer  John Bauer <horatio@gmail.com>            2022-11-11 09:18:45 +0300
commit     94059289b74f5c79e1321f4823f8722c94a37f48 (patch)
tree       c830813593a728e7dc695e527c59f94ea9501e83
parent     2775fa42763c17f1ace94b7fc4e23aa3c05ea9c5 (diff)
Mix with N vectors instead of just 1 (bert_mix)
-rw-r--r--  stanza/models/constituency/lstm_model.py  10
1 file changed, 5 insertions, 5 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index ada53e22..b19ae4a5 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -374,8 +374,7 @@ class LSTMModel(BaseModel, nn.Module):
                 self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                 nn.init.zeros_(self.bert_layer_mix.weight)
             elif args['bert_mix'] == BertMix.QUERY:
-                self.bert_layer_mix = nn.Linear(self.bert_dim, 1, bias=False)
-                nn.init.zeros_(self.bert_layer_mix.weight)
+                self.register_parameter('bert_query', torch.nn.Parameter(torch.zeros(self.bert_dim, args['bert_hidden_layers'], requires_grad=True)))
             else:
                 raise ValueError("Unhandled BertMix {}".format(args['bert_mix']))
             self.word_input_size = self.word_input_size + self.bert_dim
@@ -736,7 +735,7 @@ class LSTMModel(BaseModel, nn.Module):
             # we will take 1:-1 if we don't care about the endpoints
             bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
                                                       keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
-                                                      num_layers=self.args['bert_hidden_layers'] if self.bert_layer_mix is not None else None)
+                                                      num_layers=self.args['bert_hidden_layers'] if self.args['bert_mix'] != BertMix.NONE else None)
             if self.args['bert_mix'] == BertMix.NONE:
                 pass
             elif self.args['bert_mix'] == BertMix.LINEAR:
@@ -747,8 +746,9 @@ class LSTMModel(BaseModel, nn.Module):
             elif self.args['bert_mix'] == BertMix.QUERY:
                 mixed_bert_embeddings = []
                 for feature in bert_embeddings:
-                    weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))
-                    weighted_feature = torch.softmax(weighted_feature, dim=1)
+                    # result will be num words x num bert layers
+                    weighted_feature = (feature * self.bert_query).sum(dim=1)
+                    weighted_feature = torch.softmax(weighted_feature, dim=1).unsqueeze(2)
                     weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)
                     mixed_bert_embeddings.append(weighted_feature)
                 bert_embeddings = mixed_bert_embeddings
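
For context on what the QUERY branch now computes: instead of one shared nn.Linear mixing vector, the patch learns a bert_dim x num_layers query matrix, so each word gets its own softmax-weighted mix over the BERT hidden layers. Below is a minimal standalone sketch of the shape arithmetic, assuming a per-sentence tensor of shape num_words x bert_dim x num_layers; the concrete sizes and the random input are illustrative assumptions, not stanza's API or defaults.

import torch

# illustrative sizes, not stanza's actual configuration
num_words, bert_dim, num_layers = 7, 768, 4

# stand-in for one sentence's stacked per-layer embeddings
feature = torch.randn(num_words, bert_dim, num_layers)

# the learned query from the patch, zero-initialized
bert_query = torch.nn.Parameter(torch.zeros(bert_dim, num_layers))

# (num_words, bert_dim, num_layers) * (bert_dim, num_layers) broadcasts
# over words; summing out bert_dim leaves num_words x num_layers scores
scores = (feature * bert_query).sum(dim=1)

# per-word softmax over layers, reshaped to num_words x num_layers x 1
weights = torch.softmax(scores, dim=1).unsqueeze(2)

# batched matmul: (num_words, bert_dim, num_layers) @ (num_words, num_layers, 1)
# collapses the layer axis, leaving num_words x bert_dim mixed embeddings
mixed = torch.matmul(feature, weights).squeeze(2)
assert mixed.shape == (num_words, bert_dim)

# with a zero query the softmax is uniform, so the initial mix is
# exactly the mean over layers
assert torch.allclose(mixed, feature.mean(dim=2))

Note the design choice preserved from the old code: the query is zero-initialized, so the first forward pass reduces to a uniform average over layers, the same neutral starting point the zero-initialized nn.Linear gave the LINEAR mix, with per-word weighting only emerging as the query trains.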