diff options
author | John Bauer <horatio@gmail.com> | 2022-11-10 19:27:39 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-10 19:27:39 +0300 |
commit | 3ee39b47369fc22769fe9805bcd1784cc2c8fbfe (patch) | |
tree | e6571f76d06f9494b6b098f9f22a158e88f57b15 | |
parent | d0a729801412372cb553a3328010675f404a1dca (diff) |
Use the last word piece instead of the first (branch: vi_bert_last)
-rw-r--r-- | stanza/models/common/bert_embedding.py | 2 |
1 file changed, 1 insertion, 1 deletion
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py index 65a7fb16..4247d6ea 100644 --- a/stanza/models/common/bert_embedding.py +++ b/stanza/models/common/bert_embedding.py @@ -197,7 +197,7 @@ def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_ #process the output #only take the vector of the last word piece of a word/ you can do other methods such as first word piece or averaging. # idx2+1 compensates for the start token at the start of a sentence - offsets = [[idx2+1 for idx2, _ in enumerate(list_tokenized[idx]) if (idx2 > 0 and not list_tokenized[idx][idx2-1].endswith("@@")) or (idx2==0)] + offsets = [[idx2+1 for idx2, _ in enumerate(list_tokenized[idx]) if not list_tokenized[idx][idx2].endswith("@@")] for idx, sent in enumerate(processed)] if keep_endpoints: # [0] and [-1] grab the start and end representations as well |