diff options
author | John Bauer <horatio@gmail.com> | 2022-10-26 09:56:41 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-04 09:13:23 +0300 |
commit | 79aad63462b14ff351e1f59e9326f4c8725557f9 (patch) | |
tree | d10c4f8c66b9ee302f8021feae308af5ff8d8042 | |
parent | 03942709c8c2cf259151cebb22d2d4a9e733bf58 (diff) |
maybe use a TQDM when scoring stuff
-rw-r--r-- | stanza/models/constituency/trans_lm.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/stanza/models/constituency/trans_lm.py b/stanza/models/constituency/trans_lm.py index b42b2bf7..fc9b99d9 100644 --- a/stanza/models/constituency/trans_lm.py +++ b/stanza/models/constituency/trans_lm.py @@ -18,6 +18,8 @@ from torchtext.vocab import build_vocab_from_iterator from stanza.models.common import utils +tqdm = utils.get_tqdm() + class TransformerModel(nn.Module): def __init__(self, vocab, args): @@ -67,7 +69,7 @@ class TransformerModel(nn.Module): output = self.decoder(output) return output - def score(self, sentences, batch_size=10): + def score(self, sentences, batch_size=10, use_tqdm=False): device = next(self.parameters()).device if isinstance(sentences, str): sentences = [sentences] @@ -76,6 +78,8 @@ class TransformerModel(nn.Module): with torch.no_grad(): data, indices = data_process(self.vocab, self.tokenizer, None, iter(sentences)) data = batchify(data, batch_size, None) + if use_tqdm: + data = tqdm(data, leave=False) # TODO: save this mask max_len = max(max(len(x) for x in y) for y in data) |