Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-11-05 06:56:18 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-05 06:56:18 +0300
commit2f2e1e21c0258ec8f9c2c67f8a49c0a4b150d81f (patch)
treed004f9107def156f6a1b11524f22823b44ac46aa
parent758bc862929d93a25303f7121fcddd746b0e3297 (diff)
Allow unknown compound transitions composed of known transitions in the dev or silver sets. Note that this is impossible to predict for TOP_DOWN_COMPOUND, but it doesn't need to fail
-rw-r--r--stanza/models/constituency/parse_transitions.py36
-rw-r--r--stanza/models/constituency/trainer.py12
-rw-r--r--stanza/tests/constituency/test_parse_transitions.py20
3 files changed, 57 insertions, 11 deletions
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py
index d62e7c42..dff3123c 100644
--- a/stanza/models/constituency/parse_transitions.py
+++ b/stanza/models/constituency/parse_transitions.py
@@ -157,6 +157,15 @@ class Transition(ABC):
at parse time, the parser might choose a transition which cannot be made
"""
+ def components(self):
+ """
+ Return a list of transitions which could theoretically make up this transition
+
+ For example, an Open transition with multiple labels would
+ return a list of Opens with those labels
+ """
+ return [self]
+
@abstractmethod
def short_name(self):
"""
@@ -275,6 +284,9 @@ class CompoundUnary(Transition):
else:
return is_root
+ def components(self):
+ return [CompoundUnary(label) for label in self.labels]
+
def short_name(self):
return "Unary"
@@ -403,6 +415,9 @@ class OpenConstituent(Transition):
return True
return True
+ def components(self):
+ return [OpenConstituent(label) for label in self.label]
+
def short_name(self):
return "Open"
@@ -523,6 +538,27 @@ class CloseConstituent(Transition):
def __hash__(self):
return hash(93)
+def check_transitions(train_transitions, other_transitions, treebank_name):
+ """
+ Check that all the transitions in the other dataset are known in the train set
+
+ Weird nested unaries are warned rather than failed as long as the
+ components are all known
+
+ There is a tree in VLSP, for example, with three (!) nested NP nodes
+ If this is an unknown compound transition, we won't possibly get it
+ right when parsing, but at least we don't need to fail
+ """
+ unknown_transitions = set()
+ for trans in other_transitions:
+ if trans not in train_transitions:
+ for component in trans.components():
+ if component not in train_transitions:
+ raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name))
+ unknown_transitions.add(trans)
+ if len(unknown_transitions) > 0:
+ logger.warning("Found transitions where the components are all valid transitions, but the complete transition is unknown: %s", unknown_transitions)
+
def bulk_apply(model, state_batch, transitions, fail=False):
"""
Apply the given list of Transitions to the given list of States, using the model as a reference
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index c4fca06e..07474c54 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -357,14 +357,6 @@ def check_constituents(train_constituents, trees, treebank_name):
if con not in train_constituents:
raise RuntimeError("Found label {} in the {} set which don't exist in the train set".format(con, treebank_name))
-def check_transitions(train_transitions, other_transitions, treebank_name):
- """
- Check that all the transitions in the other dataset are known in the train set
- """
- for trans in other_transitions:
- if trans not in train_transitions:
- raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name))
-
def check_root_labels(root_labels, other_trees, treebank_name):
"""
Check that all the root states in the other dataset are known in the train set
@@ -404,9 +396,9 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
logger.info("Total unique transitions in train set: %d", len(train_transitions))
logger.info("Unique transitions in training set: %s", train_transitions)
- check_transitions(train_transitions, dev_transitions, "dev")
+ parse_transitions.check_transitions(train_transitions, dev_transitions, "dev")
# theoretically could just train based on the items in the silver dataset
- check_transitions(train_transitions, silver_transitions, "silver")
+ parse_transitions.check_transitions(train_transitions, silver_transitions, "silver")
verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit)
verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit)
diff --git a/stanza/tests/constituency/test_parse_transitions.py b/stanza/tests/constituency/test_parse_transitions.py
index 8c184b14..602188df 100644
--- a/stanza/tests/constituency/test_parse_transitions.py
+++ b/stanza/tests/constituency/test_parse_transitions.py
@@ -2,7 +2,7 @@ import pytest
from stanza.models.constituency import parse_transitions
from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
-from stanza.models.constituency.parse_transitions import TransitionScheme
+from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent
from stanza.tests import *
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
@@ -411,3 +411,21 @@ def test_sort():
transitions = set(expected)
transitions = sorted(transitions)
assert transitions == expected
+
+def test_check_transitions():
+ """
+ Test that check_transitions passes or fails a couple simple, small test cases
+ """
+ transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
+
+ other = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
+ parse_transitions.check_transitions(transitions, other, "test")
+
+ # This will get a pass because it is a unary made out of existing unaries
+ other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")}
+ parse_transitions.check_transitions(transitions, other, "test")
+
+ # This should fail
+ with pytest.raises(RuntimeError):
+ other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")}
+ parse_transitions.check_transitions(transitions, other, "test")