Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author John Bauer <horatio@gmail.com> 2022-03-01 00:30:32 +0300
committer John Bauer <horatio@gmail.com> 2022-11-03 19:54:23 +0300
commit a847419660d90ee571fd7895ae2bd8ae4a4a2290 (patch)
tree 611ece99e22544b790d20622abe0f1ca741470e8
parent 19c6210d9f7c3a9ba72225e0e3311b4cf96ac304 (diff)
Add a tag label to a Shift
- when converting trees to transition sequences, add labels on the Shift based on the tags
- labeled Shifts are only legal if the tag is the expected tag

Fix shift transitions in the models that don't already have a label attribute
-rw-r--r-- stanza/models/constituency/parse_transitions.py 23
-rw-r--r-- stanza/models/constituency/parse_tree.py 1
-rw-r--r-- stanza/models/constituency/trainer.py 38
-rw-r--r-- stanza/models/constituency/transition_sequence.py 51
-rw-r--r-- stanza/tests/constituency/test_transition_sequence.py 25
5 files changed, 106 insertions, 32 deletions
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py
index d62e7c42..1422ef16 100644
--- a/stanza/models/constituency/parse_transitions.py
+++ b/stanza/models/constituency/parse_transitions.py
@@ -105,6 +105,12 @@ class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'g
def all_words(self, model):
return [model.get_word(x) for x in self.word_queue]
+ def next_tagged_word(self, model):
+ """
+ Returns a preterminal Tree, eg, has a tag node and word child node
+ """
+ return model.get_word(self.get_word(self.word_position))
+
def to_string(self, model):
return "State(\n buffer:%s\n transitions:%s\n constituents:%s\n word_position:%d num_opens:%d)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens)
@@ -167,6 +173,12 @@ class Transition(ABC):
# put the Shift at the front of a list, and otherwise sort alphabetically
if self == other:
return False
+ if isinstance(self, Shift) and isinstance(other, Shift):
+ if other.label is None:
+ return False
+ if self.label is None:
+ return True
+ return self.label < other.label
if isinstance(self, Shift):
return True
if isinstance(other, Shift):
@@ -174,6 +186,9 @@ class Transition(ABC):
return str(self) < str(other)
class Shift(Transition):
+ def __init__(self, label=None):
+ self.label = label
+
def update_state(self, state, model):
"""
This will handle all aspects of a shift transition
@@ -190,6 +205,8 @@ class Shift(Transition):
"""
if state.empty_word_queue():
return False
+ if self.label is not None and self.label != state.next_tagged_word(model).label:
+ return False
if model.is_top_down():
# top down transition sequences cannot shift if there are currently no
# Open transitions on the stack. in such a case, the new constituent
@@ -227,16 +244,20 @@ class Shift(Transition):
return "Shift"
def __repr__(self):
+ if self.label is not None:
+ return "Shift(%s)" % self.label
return "Shift"
def __eq__(self, other):
if self is other:
return True
if isinstance(other, Shift):
- return True
+ return self.label == other.label
return False
def __hash__(self):
+ if self.label:
+ return hash(self.label) + 37
return hash(37)
class CompoundUnary(Transition):
diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py
index f145dafd..7db70caf 100644
--- a/stanza/models/constituency/parse_tree.py
+++ b/stanza/models/constituency/parse_tree.py
@@ -372,6 +372,7 @@ class Tree(StanzaObject):
threshold = max(int(len(words) * threshold), 1)
return sorted(x[0] for x in words.most_common()[:-threshold-1:-1])
+
@staticmethod
def get_root_labels(trees):
return sorted(set(x.label for x in trees))
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 74992b19..1e645999 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -77,6 +77,18 @@ class Trainer:
logger.info("Model saved to %s", filename)
@staticmethod
+ def fix_shift_transitions(transitions):
+ """
+ Fix models with Shift transitions that have no labels
+ """
+ new_trans = []
+ for t in transitions:
+ if isinstance(t, parse_transitions.Shift) and not hasattr(t, "label"):
+ t.label = None
+ new_trans.append(t)
+ return new_trans
+
+ @staticmethod
def model_from_params(params, args, foundation_cache=None):
"""
Build a new model just from the saved params and some extra args
@@ -108,7 +120,8 @@ class Trainer:
backward_charlm=backward_charlm,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
- transitions=params['transitions'],
+ # TODO: remove this when models are saved with updated transitions
+ transitions=Trainer.fix_shift_transitions(params['transitions']),
constituents=params['constituents'],
tags=params['tags'],
words=params['words'],
@@ -343,6 +356,13 @@ def remove_optimizer(args, model_save_file, model_load_file):
trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False)
trainer.save(model_save_file)
+def convert_trees_to_sequences(trees, tree_type, transition_scheme, known_tags):
+ logger.info("Building {} transition sequences".format(tree_type))
+ if logger.getEffectiveLevel() <= logging.INFO:
+ trees = tqdm(trees)
+ sequences = transition_sequence.build_treebank(trees, transition_scheme, known_tags)
+ return sequences
+
def add_grad_clipping(trainer, grad_clipping):
"""
Adds a torch.clamp hook on each parameter if grad_clipping is not None
@@ -402,9 +422,17 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
max(t.count_unary_depth() for t in dev_trees)) + 1
if silver_trees:
unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
- train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'])
- dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'])
- silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'])
+ train_sequences = convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], tags)
+ # the training transitions will all be labeled with the tags
+ # currently we are just checking correctness
+ # we add an unlabeled Shift so that the model can represent previously unseen tags
+ # at train time we will redo some tags as <UNK> to train the unlabeled Shift
+ # (this also will essentially be a form of dropout)
+ train_transitions = transition_sequence.all_transitions(train_sequences + [[parse_transitions.Shift()]])
+ dev_sequences = convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], tags)
+ dev_transitions = transition_sequence.all_transitions(dev_sequences)
+ silver_sequences = convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], tags)
+ silver_transitions = transition_sequence.all_transitions(silver_sequences)
logger.info("Total unique transitions in train set: %d", len(train_transitions))
logger.info("Unique transitions in training set: %s", train_transitions)
@@ -819,7 +847,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
# the state is build as a bulk operation
gold_trees = [x.tree.dropout_tags(args['tag_dropout']) for x in training_batch]
preterminals = [list(x.yield_preterminals()) for x in gold_trees]
- train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'])
+ train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.tags)
initial_states = model.initial_state_from_preterminals(preterminals, gold_trees)
current_batch = [state._replace(gold_sequence=sequence)
for sequence, state in zip(train_sequences, initial_states)]
diff --git a/stanza/models/constituency/transition_sequence.py b/stanza/models/constituency/transition_sequence.py
index 34209dda..135c5e37 100644
--- a/stanza/models/constituency/transition_sequence.py
+++ b/stanza/models/constituency/transition_sequence.py
@@ -14,7 +14,7 @@ tqdm = utils.get_tqdm()
logger = logging.getLogger('stanza.constituency.trainer')
-def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+def yield_top_down_sequence(tree, known_tags, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
"""
For tree (X A B C D), yield Open(X) A B C D Close
@@ -25,7 +25,11 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN
TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... Close Close
"""
if tree.is_preterminal():
- yield Shift()
+ tag = tree.label
+ if tag in known_tags:
+ yield Shift(tag)
+ else:
+ yield Shift()
return
if tree.is_leaf():
@@ -37,7 +41,7 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN
while not tree.is_preterminal() and len(tree.children) == 1:
labels.append(tree.label)
tree = tree.children[0]
- for transition in yield_top_down_sequence(tree, transition_scheme):
+ for transition in yield_top_down_sequence(tree, known_tags, transition_scheme):
yield transition
yield CompoundUnary(labels)
return
@@ -51,46 +55,54 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN
else:
yield OpenConstituent(tree.label)
for child in tree.children:
- for transition in yield_top_down_sequence(child, transition_scheme):
+ for transition in yield_top_down_sequence(child, known_tags, transition_scheme):
yield transition
yield CloseConstituent()
-def yield_in_order_sequence(tree):
+def yield_in_order_sequence(tree, known_tags):
"""
For tree (X A B C D), yield A Open(X) B C D Close
"""
if tree.is_preterminal():
- yield Shift()
+ tag = tree.label
+ if tag in known_tags:
+ yield Shift(tag)
+ else:
+ yield Shift()
return
if tree.is_leaf():
return
- for transition in yield_in_order_sequence(tree.children[0]):
+ for transition in yield_in_order_sequence(tree.children[0], known_tags):
yield transition
yield OpenConstituent(tree.label)
for child in tree.children[1:]:
- for transition in yield_in_order_sequence(child):
+ for transition in yield_in_order_sequence(child, known_tags):
yield transition
yield CloseConstituent()
-def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, known_tags=None):
"""
Turn a single tree into a list of transitions based on the TransitionScheme
"""
+ if known_tags is None:
+ known_tags = set()
if transition_scheme is TransitionScheme.IN_ORDER:
- return list(yield_in_order_sequence(tree))
+ return list(yield_in_order_sequence(tree, known_tags))
else:
- return list(yield_top_down_sequence(tree, transition_scheme))
+ return list(yield_top_down_sequence(tree, known_tags, transition_scheme))
-def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, known_tags=None):
"""
Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme
"""
- return [build_sequence(tree, transition_scheme) for tree in trees]
+ if known_tags is None:
+ known_tags = set()
+ return [build_sequence(tree, transition_scheme, known_tags) for tree in trees]
def all_transitions(transition_lists):
"""
@@ -101,19 +113,6 @@ def all_transitions(transition_lists):
transitions.update(trans_list)
return sorted(transitions)
-def convert_trees_to_sequences(trees, tree_type, transition_scheme):
- """
- Wrap both build_treebank and all_transitions, possibly with a tqdm
-
- Converts trees to a list of sequences, then returns the list of known transitions
- """
- logger.info("Building {} transition sequences".format(tree_type))
- if logger.getEffectiveLevel() <= logging.INFO:
- trees = tqdm(trees)
- sequences = build_treebank(trees, transition_scheme)
- transitions = all_transitions(sequences)
- return sequences, transitions
-
def main():
"""
Convert a sample tree and print its transitions
diff --git a/stanza/tests/constituency/test_transition_sequence.py b/stanza/tests/constituency/test_transition_sequence.py
index ce34c0e7..b92af064 100644
--- a/stanza/tests/constituency/test_transition_sequence.py
+++ b/stanza/tests/constituency/test_transition_sequence.py
@@ -117,3 +117,28 @@ def test_chinese_tree():
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
assert redone == trees[0]
+
+def test_labeled_shift():
+ """
+ Test that both inorder and preorder transition lists are produced with labeled shifts
+ """
+ trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)
+
+ preterminals = list(trees[0].yield_preterminals())
+ common_tags = set([preterminals[0].label])
+
+ transitions = transition_sequence.build_treebank(trees, known_tags=common_tags, transition_scheme=TransitionScheme.TOP_DOWN)
+ redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)
+ assert redone == trees[0]
+ shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)])
+ assert len(shift_transitions) > 1
+
+ transitions = transition_sequence.build_treebank(trees, known_tags=None, transition_scheme=TransitionScheme.TOP_DOWN)
+ shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)])
+ assert len(shift_transitions) == 1
+
+ transitions = transition_sequence.build_treebank(trees, known_tags=common_tags, transition_scheme=TransitionScheme.IN_ORDER)
+ redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
+ assert redone == trees[0]
+ shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)])
+ assert len(shift_transitions) > 1