Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-07-30 09:36:04 +0300
committerJohn Bauer <horatio@gmail.com>2022-08-14 11:21:19 +0300
commit194e446c200e7c57d01901748c60324be877d6e8 (patch)
tree125b2b83c8b20b0e12c0f66c12dc26a9a88112c0
parent91e9e9c2cd27a8278a6dab4bfec8d0e46b5681be (diff)
Just to compare masked or not-masked for xlnetmasks2
-rw-r--r--stanza/models/common/bert_embedding.py2
1 file changed, 1 insertion, 1 deletion
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py
index 49cd72f3..af7f685f 100644
--- a/stanza/models/common/bert_embedding.py
+++ b/stanza/models/common/bert_embedding.py
@@ -216,7 +216,7 @@ def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_end
features = []
for i in range(int(math.ceil(len(data)/128))):
with torch.no_grad():
- attention_mask = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
+ attention_mask = None
id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
# feature[2] is the same for bert, but it didn't work for