From 80b96bad477f6c74cf03d6d6d8dc057d1af88c8e Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 8 Sep 2022 11:30:47 -0700 Subject: NER get_known_tags possibly applies to multiple models --- stanza/pipeline/ner_processor.py | 5 +++-- stanza/tests/pipeline/test_pipeline_ner_processor.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/stanza/pipeline/ner_processor.py b/stanza/pipeline/ner_processor.py index a5306abf..72520e14 100644 --- a/stanza/pipeline/ner_processor.py +++ b/stanza/pipeline/ner_processor.py @@ -104,10 +104,11 @@ class NERProcessor(UDProcessor): doc.build_ents() return docs - def get_known_tags(self): + def get_known_tags(self, model_idx=0): """ Return the tags known by this model Removes the S-, B-, etc, and does not include O + Specify model_idx if the processor has more than one model """ - return self._trainer.get_known_tags() + return self.trainers[model_idx].get_known_tags() diff --git a/stanza/tests/pipeline/test_pipeline_ner_processor.py b/stanza/tests/pipeline/test_pipeline_ner_processor.py index a2d7768e..c67089cf 100644 --- a/stanza/tests/pipeline/test_pipeline_ner_processor.py +++ b/stanza/tests/pipeline/test_pipeline_ner_processor.py @@ -142,3 +142,6 @@ class TestMultiNERProcessor: multi_ner = [[token.multi_ner for token in sentence.tokens] for sentence in doc.sentences] assert multi_ner == EXPECTED_MULTI_NER + def test_known_tags(self, pipeline): + assert pipeline.processors["ner"].get_known_tags() == ["DISEASE"] + assert len(pipeline.processors["ner"].get_known_tags(1)) == 18 -- cgit v1.2.3