
github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-10-27 05:42:29 +0300
committer John Bauer <horatio@gmail.com>  2022-10-27 06:04:42 +0300
commit    f6d71816b99402d5e6df45d0a7b9d24a4095d63b
tree      738fcbaf7285a17ee1bde7d164697621d312096b
parent    206999aac4b7930246451b212bbb381c25ad7fa8
Add an extraction for bartpho
this tokenizer's output needs .to(device)
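As context for the commit message: a Hugging Face tokenizer returns plain CPU tensors, so its output has to be moved to the model's device before the forward pass. A minimal sketch of that pattern, assuming the model is loaded with the standard transformers auto classes (the example sentence is illustrative):

    import torch
    from transformers import AutoModel, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
    model = AutoModel.from_pretrained("vinai/bartpho-word").to(device)

    # the tokenizer always produces CPU tensors; without .to(device) the
    # forward pass fails whenever the model lives on a GPU
    tokenized = tokenizer(["Chúng_tôi là nghiên_cứu_viên"], return_tensors="pt", padding=True)
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)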
-rw-r--r--  stanza/models/common/bert_embedding.py | 40
1 file changed, 40 insertions(+), 0 deletions(-)
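One detail worth knowing before reading the diff: BARTpho-word expects word-segmented Vietnamese input, with the syllables of a multi-syllable word joined by underscores (following the convention shown in the BARTpho README). That is why the new extraction replaces spaces inside each word with "_" before tokenizing. An illustrative example of that preprocessing step (the sentence is made up):

    # each sentence is a list of (possibly multi-syllable) words;
    # the syllables within one word are joined with "_" before tokenization
    data = [["Chúng tôi", "là", "nghiên cứu viên"]]
    sentences = [" ".join(word.replace(" ", "_") for word in sentence) for sentence in data]
    # sentences == ["Chúng_tôi là nghiên_cứu_viên"]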
diff --git a/stanza/models/common/bert_embedding.py b/stanza/models/common/bert_embedding.py
index f13460bb..8b849256 100644
--- a/stanza/models/common/bert_embedding.py
+++ b/stanza/models/common/bert_embedding.py
@@ -103,6 +103,42 @@ def cloned_feature(feature, num_layers):
         feature = torch.stack(feature[-num_layers:], axis=3)
     return feature.clone().detach()
 
+def extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
+    """
+    Handles vi-bart.  May need testing before using on other bart models
+
+    https://github.com/VinAIResearch/BARTpho
+    """
+    processed = [] # final product: a list of per-sentence word representations
+
+    # BARTpho expects word-segmented input: syllables of one word joined by "_"
+    sentences = [" ".join([word.replace(" ", "_") for word in sentence]) for sentence in data]
+    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)
+    input_ids = tokenized['input_ids'].to(device)
+    attention_mask = tokenized['attention_mask'].to(device)
+
+    # process the sentences in batches of 128
+    for i in range(int(math.ceil(len(sentences) / 128))):
+        start_sentence = i * 128
+        end_sentence = min(start_sentence + 128, len(sentences))
+        with torch.no_grad():
+            # slice the full tensors per batch instead of reassigning them,
+            # so later batches still index into the complete inputs
+            features = model(input_ids[start_sentence:end_sentence],
+                             attention_mask=attention_mask[start_sentence:end_sentence],
+                             output_hidden_states=True)
+            features = features.decoder_hidden_states
+            if num_layers is None:
+                features = torch.stack(features[-4:-1], axis=3).sum(axis=3) / 4
+            else:
+                features = torch.stack(features[-num_layers:], axis=3)
+            features = features.clone().detach()
+
+        # pair each batch of features with the matching slice of the input data
+        for feature, sentence in zip(features, data[start_sentence:end_sentence]):
+            # +2 for the endpoints
+            feature = feature[:len(sentence)+2]
+            if not keep_endpoints:
+                feature = feature[1:-1]
+            processed.append(feature)
+
+    return processed
+
 def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
     """
@@ -267,6 +303,10 @@ def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
     if model_name.startswith("vinai/phobert"):
         return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers)
 
+    if model_name == "vinai/bartpho-word":
+        # not sure this works with any other Bart
+        return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers)
+
     if isinstance(data, tuple):
         data = list(data)
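With the dispatch above in place, callers of extract_bert_embeddings hit the new code path automatically whenever the model name is vinai/bartpho-word. A hedged usage sketch: the loading code below uses the plain transformers auto classes rather than stanza's own loading utilities, and the keyword names follow the signatures visible in this diff:

    import torch
    from transformers import AutoModel, AutoTokenizer

    from stanza.models.common.bert_embedding import extract_bert_embeddings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "vinai/bartpho-word"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # data is a list of sentences, each a list of word-segmented words
    data = [["Chúng_tôi", "là", "nghiên_cứu_viên"]]
    features = extract_bert_embeddings(model_name, tokenizer, model, data, device,
                                       keep_endpoints=False, num_layers=None)
    print(features[0].shape)  # one tensor of word representations per sentence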