Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'stanza/pipeline/pos_processor.py')
-rw-r--r-- stanza/pipeline/pos_processor.py 15
1 files changed, 12 insertions, 3 deletions
diff --git a/stanza/pipeline/pos_processor.py b/stanza/pipeline/pos_processor.py
index da918fdf..89658ee2 100644
--- a/stanza/pipeline/pos_processor.py
+++ b/stanza/pipeline/pos_processor.py
@@ -4,12 +4,14 @@ Processor for performing part-of-speech tagging
from stanza.models.common import doc
from stanza.models.common.pretrain import Pretrain
-from stanza.models.common.utils import unsort
+from stanza.models.common.utils import get_tqdm, unsort
from stanza.models.pos.data import DataLoader
from stanza.models.pos.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
+tqdm = get_tqdm()
+
@register_processor(name=POS)
class POSProcessor(UDProcessor):
@@ -23,14 +25,21 @@ class POSProcessor(UDProcessor):
self._pretrain = Pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
# set up trainer
self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], use_cuda=use_gpu)
+ self._tqdm = 'tqdm' in config and config['tqdm']
def process(self, document):
batch = DataLoader(
document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,
sort_during_eval=True)
preds = []
- for i, b in enumerate(batch):
- preds += self.trainer.predict(b)
+
+ if self._tqdm:
+ for i, b in enumerate(tqdm(batch)):
+ preds += self.trainer.predict(b)
+ else:
+ for i, b in enumerate(batch):
+ preds += self.trainer.predict(b)
+
preds = unsort(preds, batch.data_orig_idx)
batch.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])
return batch.doc