diff options
author | John Bauer <horatio@gmail.com> | 2022-11-06 03:54:38 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-06 03:54:38 +0300 |
commit | 435bc94aba934c1f3c3bb5e0188ea166a67be3df (patch) | |
tree | e7e26c6e3b9e57a764386c169e251cfbf3487097 | |
parent | 647940ab8d0329b68a7344b6736bd086345be981 (diff) |
Update log line & allow list of str instead of list of tuples
-rw-r--r-- | stanza/models/common/bert_embedding.py | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py index 8b849256..627076df 100644 --- a/stanza/models/common/bert_embedding.py +++ b/stanza/models/common/bert_embedding.py @@ -74,7 +74,7 @@ def filter_data(model_name, data, tokenizer = None): filtered_data = [] #eliminate all the sentences that are too long for bert model for sent in data: - sentence = [word[0] for word in sent] + sentence = [word if isinstance(word, str) else word[0] for word in sent] _, tokenized_sent = tokenize_manual(model_name, sentence, tokenizer) if len(tokenized_sent) > tokenizer.model_max_length - 2: @@ -82,7 +82,7 @@ def filter_data(model_name, data, tokenizer = None): filtered_data.append(sent) - logger.info("Eliminated {} datapoints because their length is over maximum size of BERT model. ".format(len(data)-len(filtered_data))) + logger.info("Eliminated %d of %d datapoints because their length is over maximum size of BERT model. ", (len(data)-len(filtered_data)), len(data)) return filtered_data |