diff options
author | John Bauer <horatio@gmail.com> | 2022-10-27 05:42:29 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-27 06:04:42 +0300 |
commit | f6d71816b99402d5e6df45d0a7b9d24a4095d63b (patch) | |
tree | 738fcbaf7285a17ee1bde7d164697621d312096b | |
parent | 206999aac4b7930246451b212bbb381c25ad7fa8 (diff) |
Add an embedding extraction function for BARTpho
The tensors produced by this tokenizer need to be moved with .to(device)
-rw-r--r-- | stanza/models/common/bert_embedding.py | 40 |
1 files changed, 40 insertions, 0 deletions
# Added to stanza/models/common/bert_embedding.py, between cloned_feature
# and extract_phobert_embeddings.
def extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
    """
    Extract one embedding tensor per sentence from a word-level BART model.

    Handles vi-bart (vinai/bartpho-word).  May need testing before using on other bart

    https://github.com/VinAIResearch/BARTpho

    model_name: unused here; kept so the signature matches the other extractors
    tokenizer: callable applied once to all sentences, returning 'input_ids'
      and 'attention_mask' tensors
    model: callable whose output exposes decoder_hidden_states
    data: sequence of sentences, each a sequence of word strings.  Spaces
      inside a word are replaced with "_" before tokenizing.
    device: where the tokenized tensors are moved before running the model
    keep_endpoints: if False, the first and last vectors of each sentence
      (the endpoint tokens) are dropped
    num_layers: if None, the hidden states selected by [-4:-1] are summed and
      divided by 4; otherwise the last num_layers hidden states are stacked
      along a new trailing axis

    Returns a list with one detached tensor per sentence, in input order.

    NOTE(review): the trimming below assumes the tokenizer produces exactly
    one token per word plus two endpoint tokens - confirm for other models.
    """
    # number of sentences pushed through the model at once
    batch_size = 128

    processed = []  # final product, returns the list of list of word representation

    sentences = [" ".join(word.replace(" ", "_") for word in sentence) for sentence in data]
    if not sentences:
        # nothing to tokenize; avoid handing an empty batch to the tokenizer
        return processed

    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)
    # the tokenizer output is not on the model's device by default
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    for start_sentence in range(0, len(sentences), batch_size):
        end_sentence = min(start_sentence + batch_size, len(sentences))
        # Slice into fresh names: rebinding input_ids / attention_mask here
        # would make every batch after the first slice an already-sliced tensor.
        batch_ids = input_ids[start_sentence:end_sentence]
        batch_mask = attention_mask[start_sentence:end_sentence]

        with torch.no_grad():
            features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)
            features = features.decoder_hidden_states
            if num_layers is None:
                features = torch.stack(features[-4:-1], dim=3).sum(dim=3) / 4
            else:
                features = torch.stack(features[-num_layers:], dim=3)
            features = features.clone().detach()

        # pair each row with the sentence from the same batch, not from the start of data
        for feature, sentence in zip(features, data[start_sentence:end_sentence]):
            # +2 for the endpoints
            feature = feature[:len(sentence)+2]
            if not keep_endpoints:
                feature = feature[1:-1]
            processed.append(feature)

    return processed
keep_end if model_name.startswith("vinai/phobert"): return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers) + if model_name == "vinai/bartpho-word": + # not sure this works with any other Bart + return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers) + if isinstance(data, tuple): data = list(data) |