Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-09-29 19:47:49 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-11 09:18:45 +0300
commit94059289b74f5c79e1321f4823f8722c94a37f48 (patch)
treec830813593a728e7dc695e527c59f94ea9501e83
parent2775fa42763c17f1ace94b7fc4e23aa3c05ea9c5 (diff)
Mix with N vectors instead of just 1 (branch: bert_mix)
-rw-r--r-- stanza/models/constituency/lstm_model.py (10 lines changed)
1 file changed, 5 insertions, 5 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index ada53e22..b19ae4a5 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -374,8 +374,7 @@ class LSTMModel(BaseModel, nn.Module):
self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
nn.init.zeros_(self.bert_layer_mix.weight)
elif args['bert_mix'] == BertMix.QUERY:
- self.bert_layer_mix = nn.Linear(self.bert_dim, 1, bias=False)
- nn.init.zeros_(self.bert_layer_mix.weight)
+ self.register_parameter('bert_query', torch.nn.Parameter(torch.zeros(self.bert_dim, args['bert_hidden_layers'], requires_grad=True)))
else:
raise ValueError("Unhandled BertMix {}".format(args['bert_mix']))
self.word_input_size = self.word_input_size + self.bert_dim
@@ -736,7 +735,7 @@ class LSTMModel(BaseModel, nn.Module):
# we will take 1:-1 if we don't care about the endpoints
bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
- num_layers=self.args['bert_hidden_layers'] if self.bert_layer_mix is not None else None)
+ num_layers=self.args['bert_hidden_layers'] if self.args['bert_mix'] != BertMix.NONE else None)
if self.args['bert_mix'] == BertMix.NONE:
pass
elif self.args['bert_mix'] == BertMix.LINEAR:
@@ -747,8 +746,9 @@ class LSTMModel(BaseModel, nn.Module):
elif self.args['bert_mix'] == BertMix.QUERY:
mixed_bert_embeddings = []
for feature in bert_embeddings:
- weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))
- weighted_feature = torch.softmax(weighted_feature, dim=1)
+ # result will be num words x num bert layers
+ weighted_feature = (feature * self.bert_query).sum(dim=1)
+ weighted_feature = torch.softmax(weighted_feature, dim=1).unsqueeze(2)
weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)
mixed_bert_embeddings.append(weighted_feature)
bert_embeddings = mixed_bert_embeddings