diff options
author | John Bauer <horatio@gmail.com> | 2022-11-05 06:56:18 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-05 06:56:18 +0300 |
commit | 2f2e1e21c0258ec8f9c2c67f8a49c0a4b150d81f (patch) | |
tree | d004f9107def156f6a1b11524f22823b44ac46aa | |
parent | 758bc862929d93a25303f7121fcddd746b0e3297 (diff) |
Allow unknown compound transitions composed of known transitions in the dev or silver sets. Note that this is impossible to predict for TOP_DOWN_COMPOUND, but it doesn't need to fail
-rw-r--r-- | stanza/models/constituency/parse_transitions.py | 36 | ||||
-rw-r--r-- | stanza/models/constituency/trainer.py | 12 | ||||
-rw-r--r-- | stanza/tests/constituency/test_parse_transitions.py | 20 |
3 files changed, 57 insertions, 11 deletions
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py index d62e7c42..dff3123c 100644 --- a/stanza/models/constituency/parse_transitions.py +++ b/stanza/models/constituency/parse_transitions.py @@ -157,6 +157,15 @@ class Transition(ABC): at parse time, the parser might choose a transition which cannot be made """ + def components(self): + """ + Return a list of transitions which could theoretically make up this transition + + For example, an Open transition with multiple labels would + return a list of Opens with those labels + """ + return [self] + @abstractmethod def short_name(self): """ @@ -275,6 +284,9 @@ class CompoundUnary(Transition): else: return is_root + def components(self): + return [CompoundUnary(label) for label in self.labels] + def short_name(self): return "Unary" @@ -403,6 +415,9 @@ class OpenConstituent(Transition): return True return True + def components(self): + return [OpenConstituent(label) for label in self.label] + def short_name(self): return "Open" @@ -523,6 +538,27 @@ class CloseConstituent(Transition): def __hash__(self): return hash(93) +def check_transitions(train_transitions, other_transitions, treebank_name): + """ + Check that all the transitions in the other dataset are known in the train set + + Weird nested unaries are warned rather than failed as long as the + components are all known + + There is a tree in VLSP, for example, with three (!) nested NP nodes + If this is an unknown compound transition, we won't possibly get it + right when parsing, but at least we don't need to fail + """ + unknown_transitions = set() + for trans in other_transitions: + if trans not in train_transitions: + for component in trans.components(): + if component not in train_transitions: + raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name)) + unknown_transitions.add(trans) + if len(unknown_transitions) > 0: + logger.warning("Found transitions where the components are all valid transitions, but the complete transition is unknown: %s", unknown_transitions) + def bulk_apply(model, state_batch, transitions, fail=False): """ Apply the given list of Transitions to the given list of States, using the model as a reference diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index c4fca06e..07474c54 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -357,14 +357,6 @@ def check_constituents(train_constituents, trees, treebank_name): if con not in train_constituents: raise RuntimeError("Found label {} in the {} set which don't exist in the train set".format(con, treebank_name)) -def check_transitions(train_transitions, other_transitions, treebank_name): - """ - Check that all the transitions in the other dataset are known in the train set - """ - for trans in other_transitions: - if trans not in train_transitions: - raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name)) - def check_root_labels(root_labels, other_trees, treebank_name): """ Check that all the root states in the other dataset are known in the train set @@ -404,9 +396,9 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, logger.info("Total unique transitions in train set: %d", len(train_transitions)) logger.info("Unique transitions in training set: %s", train_transitions) - check_transitions(train_transitions, dev_transitions, "dev") + parse_transitions.check_transitions(train_transitions, dev_transitions, "dev") # theoretically could just train based on the items in the silver dataset - check_transitions(train_transitions, silver_transitions, "silver") + parse_transitions.check_transitions(train_transitions, silver_transitions, "silver") verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit) verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit) diff --git a/stanza/tests/constituency/test_parse_transitions.py b/stanza/tests/constituency/test_parse_transitions.py index 8c184b14..602188df 100644 --- a/stanza/tests/constituency/test_parse_transitions.py +++ b/stanza/tests/constituency/test_parse_transitions.py @@ -2,7 +2,7 @@ import pytest from stanza.models.constituency import parse_transitions from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT -from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent from stanza.tests import * pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @@ -411,3 +411,21 @@ def test_sort(): transitions = set(expected) transitions = sorted(transitions) assert transitions == expected + +def test_check_transitions(): + """ + Test that check_transitions passes or fails a couple simple, small test cases + """ + transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")} + + other = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")} + parse_transitions.check_transitions(transitions, other, "test") + + # This will get a pass because it is a unary made out of existing unaries + other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")} + parse_transitions.check_transitions(transitions, other, "test") + + # This should fail + with pytest.raises(RuntimeError): + other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")} + parse_transitions.check_transitions(transitions, other, "test") |