diff options
author | John Bauer <horatio@gmail.com> | 2022-09-08 21:30:47 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-08 21:30:47 +0300 |
commit | 80b96bad477f6c74cf03d6d6d8dc057d1af88c8e (patch) | |
tree | 12a99d60f046259ffbb05d2a521d50c81499cf88 | |
parent | 19e229565f001da003a1949517ca1f9bafd24920 (diff) |
NER get_known_tags possibly applies to multiple models
-rw-r--r-- | stanza/pipeline/ner_processor.py | 5 | ||||
-rw-r--r-- | stanza/tests/pipeline/test_pipeline_ner_processor.py | 3 |
2 files changed, 6 insertions, 2 deletions
diff --git a/stanza/pipeline/ner_processor.py b/stanza/pipeline/ner_processor.py index a5306abf..72520e14 100644 --- a/stanza/pipeline/ner_processor.py +++ b/stanza/pipeline/ner_processor.py @@ -104,10 +104,11 @@ class NERProcessor(UDProcessor): doc.build_ents() return docs - def get_known_tags(self): + def get_known_tags(self, model_idx=0): """ Return the tags known by this model Removes the S-, B-, etc, and does not include O + Specify model_idx if the processor has more than one model """ - return self._trainer.get_known_tags() + return self.trainers[model_idx].get_known_tags() diff --git a/stanza/tests/pipeline/test_pipeline_ner_processor.py b/stanza/tests/pipeline/test_pipeline_ner_processor.py index a2d7768e..c67089cf 100644 --- a/stanza/tests/pipeline/test_pipeline_ner_processor.py +++ b/stanza/tests/pipeline/test_pipeline_ner_processor.py @@ -142,3 +142,6 @@ class TestMultiNERProcessor: multi_ner = [[token.multi_ner for token in sentence.tokens] for sentence in doc.sentences] assert multi_ner == EXPECTED_MULTI_NER + def test_known_tags(self, pipeline): + assert pipeline.processors["ner"].get_known_tags() == ["DISEASE"] + assert len(pipeline.processors["ner"].get_known_tags(1)) == 18 |