diff options
author | John Bauer <horatio@gmail.com> | 2022-03-01 00:30:32 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-03 19:54:23 +0300 |
commit | a847419660d90ee571fd7895ae2bd8ae4a4a2290 (patch) | |
tree | 611ece99e22544b790d20622abe0f1ca741470e8 | |
parent | 19c6210d9f7c3a9ba72225e0e3311b4cf96ac304 (diff) |
Add a tag label to a Shift
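- when converting trees to transition sequences, label each Shift with the tag of the preterminal it shifts, provided that tag is in the set of known tags (otherwise the Shift is left unlabeled)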
- a labeled Shift is only legal if its label matches the tag of the next word in the queue; an unlabeled Shift is not restricted by tag
When building a model from saved params, fix up Shift transitions that don't already have a label attribute by setting their label to None
-rw-r--r-- | stanza/models/constituency/parse_transitions.py | 23 | ||||
-rw-r--r-- | stanza/models/constituency/parse_tree.py | 1 | ||||
-rw-r--r-- | stanza/models/constituency/trainer.py | 38 | ||||
-rw-r--r-- | stanza/models/constituency/transition_sequence.py | 51 | ||||
-rw-r--r-- | stanza/tests/constituency/test_transition_sequence.py | 25 |
5 files changed, 106 insertions, 32 deletions
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py index d62e7c42..1422ef16 100644 --- a/stanza/models/constituency/parse_transitions.py +++ b/stanza/models/constituency/parse_transitions.py @@ -105,6 +105,12 @@ class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'g def all_words(self, model): return [model.get_word(x) for x in self.word_queue] + def next_tagged_word(self, model): + """ + Returns a preterminal Tree, eg, has a tag node and word child node + """ + return model.get_word(self.get_word(self.word_position)) + def to_string(self, model): return "State(\n buffer:%s\n transitions:%s\n constituents:%s\n word_position:%d num_opens:%d)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens) @@ -167,6 +173,12 @@ class Transition(ABC): # put the Shift at the front of a list, and otherwise sort alphabetically if self == other: return False + if isinstance(self, Shift) and isinstance(other, Shift): + if other.label is None: + return False + if self.label is None: + return True + return self.label < other.label if isinstance(self, Shift): return True if isinstance(other, Shift): @@ -174,6 +186,9 @@ class Transition(ABC): return str(self) < str(other) class Shift(Transition): + def __init__(self, label=None): + self.label = label + def update_state(self, state, model): """ This will handle all aspects of a shift transition @@ -190,6 +205,8 @@ class Shift(Transition): """ if state.empty_word_queue(): return False + if self.label is not None and self.label != state.next_tagged_word(model).label: + return False if model.is_top_down(): # top down transition sequences cannot shift if there are currently no # Open transitions on the stack. 
in such a case, the new constituent @@ -227,16 +244,20 @@ class Shift(Transition): return "Shift" def __repr__(self): + if self.label is not None: + return "Shift(%s)" % self.label return "Shift" def __eq__(self, other): if self is other: return True if isinstance(other, Shift): - return True + return self.label == other.label return False def __hash__(self): + if self.label: + return hash(self.label) + 37 return hash(37) class CompoundUnary(Transition): diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py index f145dafd..7db70caf 100644 --- a/stanza/models/constituency/parse_tree.py +++ b/stanza/models/constituency/parse_tree.py @@ -372,6 +372,7 @@ class Tree(StanzaObject): threshold = max(int(len(words) * threshold), 1) return sorted(x[0] for x in words.most_common()[:-threshold-1:-1]) + @staticmethod def get_root_labels(trees): return sorted(set(x.label for x in trees)) diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 74992b19..1e645999 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -77,6 +77,18 @@ class Trainer: logger.info("Model saved to %s", filename) @staticmethod + def fix_shift_transitions(transitions): + """ + Fix models with Shift transitions that have no labels + """ + new_trans = [] + for t in transitions: + if isinstance(t, parse_transitions.Shift) and not hasattr(t, "label"): + t.label = None + new_trans.append(t) + return new_trans + + @staticmethod def model_from_params(params, args, foundation_cache=None): """ Build a new model just from the saved params and some extra args @@ -108,7 +120,8 @@ class Trainer: backward_charlm=backward_charlm, bert_model=bert_model, bert_tokenizer=bert_tokenizer, - transitions=params['transitions'], + # TODO: remove this when models are saved with updated transitions + transitions=Trainer.fix_shift_transitions(params['transitions']), constituents=params['constituents'], 
tags=params['tags'], words=params['words'], @@ -343,6 +356,13 @@ def remove_optimizer(args, model_save_file, model_load_file): trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False) trainer.save(model_save_file) +def convert_trees_to_sequences(trees, tree_type, transition_scheme, known_tags): + logger.info("Building {} transition sequences".format(tree_type)) + if logger.getEffectiveLevel() <= logging.INFO: + trees = tqdm(trees) + sequences = transition_sequence.build_treebank(trees, transition_scheme, known_tags) + return sequences + def add_grad_clipping(trainer, grad_clipping): """ Adds a torch.clamp hook on each parameter if grad_clipping is not None @@ -402,9 +422,17 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, max(t.count_unary_depth() for t in dev_trees)) + 1 if silver_trees: unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees)) - train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme']) - dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme']) - silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme']) + train_sequences = convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], tags) + # the training transitions will all be labeled with the tags + # currently we are just checking correctness + # we add an unlabeled Shift so that the model can represent previously unseen tags + # at train time we will redo some tags as <UNK> to train the unlabeled Shift + # (this also will essentially be a form of dropout) + train_transitions = transition_sequence.all_transitions(train_sequences + [[parse_transitions.Shift()]]) + dev_sequences = convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], tags) + dev_transitions = 
transition_sequence.all_transitions(dev_sequences) + silver_sequences = convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], tags) + silver_transitions = transition_sequence.all_transitions(silver_sequences) logger.info("Total unique transitions in train set: %d", len(train_transitions)) logger.info("Unique transitions in training set: %s", train_transitions) @@ -819,7 +847,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te # the state is build as a bulk operation gold_trees = [x.tree.dropout_tags(args['tag_dropout']) for x in training_batch] preterminals = [list(x.yield_preterminals()) for x in gold_trees] - train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme']) + train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.tags) initial_states = model.initial_state_from_preterminals(preterminals, gold_trees) current_batch = [state._replace(gold_sequence=sequence) for sequence, state in zip(train_sequences, initial_states)] diff --git a/stanza/models/constituency/transition_sequence.py b/stanza/models/constituency/transition_sequence.py index 34209dda..135c5e37 100644 --- a/stanza/models/constituency/transition_sequence.py +++ b/stanza/models/constituency/transition_sequence.py @@ -14,7 +14,7 @@ tqdm = utils.get_tqdm() logger = logging.getLogger('stanza.constituency.trainer') -def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): +def yield_top_down_sequence(tree, known_tags, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): """ For tree (X A B C D), yield Open(X) A B C D Close @@ -25,7 +25,11 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... 
Close Close """ if tree.is_preterminal(): - yield Shift() + tag = tree.label + if tag in known_tags: + yield Shift(tag) + else: + yield Shift() return if tree.is_leaf(): @@ -37,7 +41,7 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN while not tree.is_preterminal() and len(tree.children) == 1: labels.append(tree.label) tree = tree.children[0] - for transition in yield_top_down_sequence(tree, transition_scheme): + for transition in yield_top_down_sequence(tree, known_tags, transition_scheme): yield transition yield CompoundUnary(labels) return @@ -51,46 +55,54 @@ def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UN else: yield OpenConstituent(tree.label) for child in tree.children: - for transition in yield_top_down_sequence(child, transition_scheme): + for transition in yield_top_down_sequence(child, known_tags, transition_scheme): yield transition yield CloseConstituent() -def yield_in_order_sequence(tree): +def yield_in_order_sequence(tree, known_tags): """ For tree (X A B C D), yield A Open(X) B C D Close """ if tree.is_preterminal(): - yield Shift() + tag = tree.label + if tag in known_tags: + yield Shift(tag) + else: + yield Shift() return if tree.is_leaf(): return - for transition in yield_in_order_sequence(tree.children[0]): + for transition in yield_in_order_sequence(tree.children[0], known_tags): yield transition yield OpenConstituent(tree.label) for child in tree.children[1:]: - for transition in yield_in_order_sequence(child): + for transition in yield_in_order_sequence(child, known_tags): yield transition yield CloseConstituent() -def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): +def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, known_tags=None): """ Turn a single tree into a list of transitions based on the TransitionScheme """ + if known_tags is None: + known_tags = set() if transition_scheme is TransitionScheme.IN_ORDER: - return 
list(yield_in_order_sequence(tree)) + return list(yield_in_order_sequence(tree, known_tags)) else: - return list(yield_top_down_sequence(tree, transition_scheme)) + return list(yield_top_down_sequence(tree, known_tags, transition_scheme)) -def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): +def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, known_tags=None): """ Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme """ - return [build_sequence(tree, transition_scheme) for tree in trees] + if known_tags is None: + known_tags = set() + return [build_sequence(tree, transition_scheme, known_tags) for tree in trees] def all_transitions(transition_lists): """ @@ -101,19 +113,6 @@ def all_transitions(transition_lists): transitions.update(trans_list) return sorted(transitions) -def convert_trees_to_sequences(trees, tree_type, transition_scheme): - """ - Wrap both build_treebank and all_transitions, possibly with a tqdm - - Converts trees to a list of sequences, then returns the list of known transitions - """ - logger.info("Building {} transition sequences".format(tree_type)) - if logger.getEffectiveLevel() <= logging.INFO: - trees = tqdm(trees) - sequences = build_treebank(trees, transition_scheme) - transitions = all_transitions(sequences) - return sequences, transitions - def main(): """ Convert a sample tree and print its transitions diff --git a/stanza/tests/constituency/test_transition_sequence.py b/stanza/tests/constituency/test_transition_sequence.py index ce34c0e7..b92af064 100644 --- a/stanza/tests/constituency/test_transition_sequence.py +++ b/stanza/tests/constituency/test_transition_sequence.py @@ -117,3 +117,28 @@ def test_chinese_tree(): redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) assert redone == trees[0] + +def test_labeled_shift(): + """ + Test that both inorder and preorder transition lists are 
produced with labeled shifts + """ + trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) + + preterminals = list(trees[0].yield_preterminals()) + common_tags = set([preterminals[0].label]) + + transitions = transition_sequence.build_treebank(trees, known_tags=common_tags, transition_scheme=TransitionScheme.TOP_DOWN) + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) + assert redone == trees[0] + shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)]) + assert len(shift_transitions) > 1 + + transitions = transition_sequence.build_treebank(trees, known_tags=None, transition_scheme=TransitionScheme.TOP_DOWN) + shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)]) + assert len(shift_transitions) == 1 + + transitions = transition_sequence.build_treebank(trees, known_tags=common_tags, transition_scheme=TransitionScheme.IN_ORDER) + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) + assert redone == trees[0] + shift_transitions = set([x for x in transitions[0] if isinstance(x, Shift)]) + assert len(shift_transitions) > 1 |