Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-10-26 09:56:41 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-04 09:13:23 +0300
commit79aad63462b14ff351e1f59e9326f4c8725557f9 (patch)
treed10c4f8c66b9ee302f8021feae308af5ff8d8042
parent03942709c8c2cf259151cebb22d2d4a9e733bf58 (diff)
maybe use a TQDM when scoring stuff
-rw-r--r--stanza/models/constituency/trans_lm.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/stanza/models/constituency/trans_lm.py b/stanza/models/constituency/trans_lm.py
index b42b2bf7..fc9b99d9 100644
--- a/stanza/models/constituency/trans_lm.py
+++ b/stanza/models/constituency/trans_lm.py
@@ -18,6 +18,8 @@ from torchtext.vocab import build_vocab_from_iterator
from stanza.models.common import utils
+tqdm = utils.get_tqdm()
+
class TransformerModel(nn.Module):
def __init__(self, vocab, args):
@@ -67,7 +69,7 @@ class TransformerModel(nn.Module):
output = self.decoder(output)
return output
- def score(self, sentences, batch_size=10):
+ def score(self, sentences, batch_size=10, use_tqdm=False):
device = next(self.parameters()).device
if isinstance(sentences, str):
sentences = [sentences]
@@ -76,6 +78,8 @@ class TransformerModel(nn.Module):
with torch.no_grad():
data, indices = data_process(self.vocab, self.tokenizer, None, iter(sentences))
data = batchify(data, batch_size, None)
+ if use_tqdm:
+ data = tqdm(data, leave=False)
# TODO: save this mask
max_len = max(max(len(x) for x in y) for y in data)