author | John Bauer <horatio@gmail.com> | 2022-10-27 01:15:50 +0300
---|---|---
committer | John Bauer <horatio@gmail.com> | 2022-10-27 01:15:50 +0300
commit | 56821bec2fed7c6be5bc8d3fca5908b7cc0a1316 (patch) |
tree | e40b29e3c242f16b89ce4bd45cf65c61508bab59 |
parent | 6e300655797372f2f2fa68bbe904c9a0458a8d9e (diff) |
Adjust attention masks for vi phobert
-rw-r--r-- | stanza/models/common/bert_embedding.py | 9 |
1 file changed, 8 insertions(+), 1 deletion(-)
```diff
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py
index ce9046ec..f13460bb 100644
--- a/stanza/models/common/bert_embedding.py
+++ b/stanza/models/common/bert_embedding.py
@@ -145,7 +145,14 @@ def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_
     # run only 1 time as the batch size seems to be 30
     for i in range(int(math.ceil(size/128))):
         with torch.no_grad():
-            feature = model(tokenized_sents_padded[128*i:128*i+128].clone().detach().to(device), output_hidden_states=True)
+            padded_input = tokenized_sents_padded[128*i:128*i+128]
+            start_sentence = i * 128
+            end_sentence = start_sentence + padded_input.shape[0]
+            attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)
+            for sent_idx, sent in enumerate(tokenized_sents[start_sentence:end_sentence]):
+                attention_mask[sent_idx, :len(sent)] = 1
+            # TODO: is the clone().detach() necessary?
+            feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)
             features += cloned_feature(feature, num_layers)
 
     assert len(features)==size
```
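The substance of the patch is the mask construction: without an explicit `attention_mask`, a HuggingFace-style model attends to the padding positions that `tokenized_sents_padded` introduces when variable-length sentences are batched, which skews the hidden states of the shorter sentences. Below is a minimal, self-contained sketch of the same zero-init-then-fill masking pattern; the `sentences` data and `PAD_ID` are illustrative stand-ins, not Stanza's actual tokenizer output.

```python
import torch

PAD_ID = 0  # hypothetical pad token id, for illustration only

# Variable-length token-id sequences, as a tokenizer might produce them.
sentences = [[5, 9, 2], [7, 3, 8, 4, 6], [1, 2]]
max_len = max(len(sent) for sent in sentences)

# Pad every sentence to max_len, mirroring tokenized_sents_padded in the diff.
padded = torch.full((len(sentences), max_len), PAD_ID, dtype=torch.long)
# Start from all zeros, then mark real token positions with 1, as the patch does.
attention_mask = torch.zeros(len(sentences), max_len, dtype=torch.long)
for sent_idx, sent in enumerate(sentences):
    padded[sent_idx, :len(sent)] = torch.tensor(sent)
    attention_mask[sent_idx, :len(sent)] = 1

print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1],
#         [1, 1, 0, 0, 0]])
```

A HuggingFace-style forward pass would then take the mask alongside the ids, as the patched line does: `model(padded, attention_mask=attention_mask, output_hidden_states=True)`.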