diff options
-rw-r--r-- | stanza/pipeline/constituency_processor.py | 9 | ||||
-rw-r--r-- | stanza/tests/pipeline/test_pipeline_constituency_processor.py | 4 |
2 files changed, 13 insertions, 0 deletions
diff --git a/stanza/pipeline/constituency_processor.py b/stanza/pipeline/constituency_processor.py
index 0ead83ba..b70a7b2d 100644
--- a/stanza/pipeline/constituency_processor.py
+++ b/stanza/pipeline/constituency_processor.py
@@ -66,3 +66,12 @@ class ConstituencyProcessor(UDProcessor):
         trees = trainer.parse_tagged_words(self._model.model, words, self._batch_size)
         document.set(CONSTITUENCY, trees, to_sentence=True)
         return document
+
+    def get_constituents(self):
+        """
+        Return a set of the constituents known by this model
+
+        For a pipeline, this can be queried with
+        pipeline.processors["constituency"].get_constituents()
+        """
+        return set(self._model.model.constituents)
diff --git a/stanza/tests/pipeline/test_pipeline_constituency_processor.py b/stanza/tests/pipeline/test_pipeline_constituency_processor.py
index 0cc01d0f..77a83b48 100644
--- a/stanza/tests/pipeline/test_pipeline_constituency_processor.py
+++ b/stanza/tests/pipeline/test_pipeline_constituency_processor.py
@@ -35,3 +35,7 @@ def test_sorted_two_batch():
     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=2)
     doc = pipe(TEST_TEXT)
     check_results(doc)
+
+def test_get_constituents():
+    pipe = stanza.Pipeline("en", processors="tokenize,pos,constituency")
+    assert "SBAR" in pipe.processors["constituency"].get_constituents()