github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-10-27 01:15:50 +0300
committer John Bauer <horatio@gmail.com>  2022-10-27 01:15:50 +0300
commit    56821bec2fed7c6be5bc8d3fca5908b7cc0a1316 (patch)
tree      e40b29e3c242f16b89ce4bd45cf65c61508bab59
parent    6e300655797372f2f2fa68bbe904c9a0458a8d9e (diff)
Adjust attention masks for vi phobert
 stanza/models/common/bert_embedding.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py
index ce9046ec..f13460bb 100644
--- a/stanza/models/common/bert_embedding.py
+++ b/stanza/models/common/bert_embedding.py
@@ -145,7 +145,14 @@ def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_
     # run only 1 time as the batch size seems to be 30
     for i in range(int(math.ceil(size/128))):
         with torch.no_grad():
-            feature = model(tokenized_sents_padded[128*i:128*i+128].clone().detach().to(device), output_hidden_states=True)
+            padded_input = tokenized_sents_padded[128*i:128*i+128]
+            start_sentence = i * 128
+            end_sentence = start_sentence + padded_input.shape[0]
+            attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)
+            for sent_idx, sent in enumerate(tokenized_sents[start_sentence:end_sentence]):
+                attention_mask[sent_idx, :len(sent)] = 1
+            # TODO: is the clone().detach() necessary?
+            feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)
         features += cloned_feature(feature, num_layers)
 
     assert len(features)==size
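
Context on the change: the PhoBERT inputs are padded to a common length within each batch of 128 sentences, and without an explicit attention mask the model also attends to the pad positions. The added lines build a mask with 1 over each sentence's real tokens and 0 over its padding. Below is a minimal, self-contained sketch of that construction on toy data; build_attention_mask and the toy token ids are illustrative assumptions, not stanza's actual API.

import torch

def build_attention_mask(tokenized_sents, padded_batch, device="cpu"):
    """tokenized_sents: list of unpadded token-id lists;
    padded_batch: tensor of shape (batch, max_len), padded to max_len."""
    # start from all zeros (everything masked out) ...
    mask = torch.zeros(padded_batch.shape[0], padded_batch.shape[1], device=device)
    for sent_idx, sent in enumerate(tokenized_sents):
        # ... then mark each sentence's real (non-pad) positions with 1
        mask[sent_idx, :len(sent)] = 1
    return mask

# toy usage: two sentences of lengths 3 and 5, padded to length 5 with pad id 1
sents = [[0, 17, 2], [0, 23, 44, 56, 2]]
padded = torch.tensor([[0, 17, 2, 1, 1], [0, 23, 44, 56, 2]])
print(build_attention_mask(sents, padded))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])

HuggingFace transformers models accept attention_mask as a 0/1 tensor of shape (batch, seq_len), which is exactly what this loop, like the committed code, produces.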