Diffstat (limited to 'stanza/pipeline/tokenize_processor.py')
-rw-r--r-- | stanza/pipeline/tokenize_processor.py | 20
1 files changed, 12 insertions, 8 deletions
diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py
index 4efb71c3..6a50313f 100644
--- a/stanza/pipeline/tokenize_processor.py
+++ b/stanza/pipeline/tokenize_processor.py
@@ -12,6 +12,7 @@ from stanza.pipeline._constants import *
 from stanza.pipeline.processor import UDProcessor
 from stanza.utils.postprocess_vietnamese_tokenizer_data import paras_to_chunks
 from stanza.models.common import doc
+from stanza.utils.jieba import JiebaTokenizer
 from stanza.utils.spacy import SpacyTokenizer
 
 logger = logging.getLogger('stanza')
@@ -30,6 +31,10 @@ class TokenizeProcessor(UDProcessor):
         # set up trainer
         if config.get('pretokenized'):
             self._trainer = None
+        elif config.get('with_jieba', False):
+            self._trainer = None
+            self._jieba_tokenizer = JiebaTokenizer(config.get('lang'))
+            logger.info("Using jieba as tokenizer")
         elif config.get('with_spacy', False):
             self._trainer = None
             self._spacy_tokenizer = SpacyTokenizer(config.get('lang'))
@@ -49,7 +54,7 @@ class TokenizeProcessor(UDProcessor):
 
         document = []
         if isinstance(input_src, str):
-            sentences = [sent.rstrip(' ').split() for sent in input_src.rstrip('\n').split('\n') if sent]
+            sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0]
         elif isinstance(input_src, list):
             sentences = input_src
         idx = 0
@@ -59,7 +64,6 @@ class TokenizeProcessor(UDProcessor):
                 sent.append({doc.ID: str(token_id + 1), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})
                 idx += len(token) + 1
             document.append(sent)
-            idx += 1
         raw_text = ' '.join([' '.join(sentence) for sentence in sentences])
         return raw_text, document
 
@@ -69,24 +73,24 @@ class TokenizeProcessor(UDProcessor):
 
         if self.config.get('pretokenized'):
             raw_text, document = self.process_pre_tokenized_text(document)
+        elif self.config.get('with_jieba', False):
+            return self._jieba_tokenizer.tokenize(document)
         elif self.config.get('with_spacy', False):
             return self._spacy_tokenizer.tokenize(document)
         else:
-            raw_text = document
+            raw_text = '\n\n'.join(document) if isinstance(document, list) else document
            # set up batches
             if self.config.get('lang') == 'vi':
                 # special processing is due for Vietnamese
-                text = '\n\n'.join([x for x in document.split('\n\n')]).rstrip()
+                text = '\n\n'.join([x for x in raw_text.split('\n\n')]).rstrip()
                 dummy_labels = '\n\n'.join(['0' * len(x) for x in text.split('\n\n')])
                 data = paras_to_chunks(text, dummy_labels)
                 batches = DataLoader(self.config, input_data=data, vocab=self.vocab, evaluation=True)
             else:
-                if isinstance(document, list):
-                    document = '\n\n'.join(document)
-                batches = DataLoader(self.config, input_text=document, vocab=self.vocab, evaluation=True)
+                batches = DataLoader(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True)
             # get dict data
             _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
                                                    self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT),
-                                                   orig_text = document,
+                                                   orig_text=raw_text,
                                                    no_ssplit=self.config.get('no_ssplit', False))
         return doc.Document(document, raw_text)
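
A usage sketch, not part of the patch: assuming the usual Stanza convention of prefixing processor options with the processor name, the new with_jieba flag would presumably be passed to the pipeline as tokenize_with_jieba. The language code and example text below are illustrative assumptions.

    import stanza

    # Assumes the Chinese models have already been fetched, e.g. via stanza.download('zh').
    # 'tokenize_with_jieba' is the processor-prefixed form of the 'with_jieba' option
    # introduced in this diff; the exact spelling is an assumption here.
    nlp = stanza.Pipeline('zh', processors='tokenize', tokenize_with_jieba=True)

    doc = nlp('我喜欢自然语言处理。')
    for sentence in doc.sentences:
        # Print the surface form of each token produced by the jieba-backed tokenizer.
        print([token.text for token in sentence.tokens])

Like the existing spaCy path, this branch sets self._trainer to None and returns the jieba tokenizer's output directly, so no neural tokenizer model is loaded or run.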