diff options
Diffstat (limited to 'stanza/pipeline/pos_processor.py')
-rw-r--r-- | stanza/pipeline/pos_processor.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/stanza/pipeline/pos_processor.py b/stanza/pipeline/pos_processor.py index da918fdf..89658ee2 100644 --- a/stanza/pipeline/pos_processor.py +++ b/stanza/pipeline/pos_processor.py @@ -4,12 +4,14 @@ Processor for performing part-of-speech tagging from stanza.models.common import doc from stanza.models.common.pretrain import Pretrain -from stanza.models.common.utils import unsort +from stanza.models.common.utils import get_tqdm, unsort from stanza.models.pos.data import DataLoader from stanza.models.pos.trainer import Trainer from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor +tqdm = get_tqdm() + @register_processor(name=POS) class POSProcessor(UDProcessor): @@ -23,14 +25,21 @@ class POSProcessor(UDProcessor): self._pretrain = Pretrain(config['pretrain_path']) if 'pretrain_path' in config else None # set up trainer self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], use_cuda=use_gpu) + self._tqdm = 'tqdm' in config and config['tqdm'] def process(self, document): batch = DataLoader( document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True, sort_during_eval=True) preds = [] - for i, b in enumerate(batch): - preds += self.trainer.predict(b) + + if self._tqdm: + for i, b in enumerate(tqdm(batch)): + preds += self.trainer.predict(b) + else: + for i, b in enumerate(batch): + preds += self.trainer.predict(b) + preds = unsort(preds, batch.data_orig_idx) batch.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x]) return batch.doc |