Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'stanza/models/constituency/parse_transitions.py')
-rw-r--r--stanza/models/constituency/parse_transitions.py603
1 files changed, 603 insertions, 0 deletions
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py
new file mode 100644
index 00000000..ec815caa
--- /dev/null
+++ b/stanza/models/constituency/parse_transitions.py
@@ -0,0 +1,603 @@
+"""
+Defines a series of transitions (open a constituent, close a constituent, etc
+
+Also defines a State which holds the various data needed to build
+a parse tree out of tagged words.
+"""
+
+from abc import ABC, abstractmethod
+from collections import defaultdict, namedtuple
+from enum import Enum
+import functools
+import logging
+
+from stanza.models.constituency.parse_tree import Tree
+
+logger = logging.getLogger('stanza')
+
+class TransitionScheme(Enum):
+ TOP_DOWN = 1
+ TOP_DOWN_COMPOUND = 2
+ TOP_DOWN_UNARY = 3
+
+ IN_ORDER = 4
+
+UNARY_LIMIT = 4
+
+class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',
+ 'sentence_length', 'num_opens', 'word_position'])):
+ """
+ Represents a partially completed transition parse
+
+ Includes stack/buffers for unused words, already executed transitions, and partially build constituents
+ At training time, also keeps track of the gold data we are reparsing
+
+ num_opens is useful for tracking
+ 1) if the parser is in a stuck state where it is making infinite opens
+ 2) if a close transition is impossible because there are no previous opens
+
+ sentence_length tracks how long the sentence is so we abort if we go infinite
+
+ non-stack information such as sentence_length and num_opens
+ will be copied from the original_state if possible, with the
+ exact arguments overriding the values in the original_state
+
+ gold_tree: the original tree, if made from a gold tree. might be None
+ gold_sequence: the original transition sequence, if available
+ Note that at runtime, gold values will not be available
+
+ word_position tracks where in the word queue we are. cheaper than
+ manipulating the list itself. this can be handled differently
+ from transitions and constituents as it is processed once
+ at the start of parsing
+ """
+ def empty_word_queue(self):
+ # the first element of each stack is a sentinel with no value
+ # and no parent
+ return self.word_position == self.sentence_length
+
+ def empty_transitions(self):
+ # the first element of each stack is a sentinel with no value
+ # and no parent
+ return self.transitions.parent is None
+
+ def has_one_constituent(self):
+ # a length of 1 represents no constituents
+ return len(self.constituents) == 2
+
+ def num_constituents(self):
+ return len(self.constituents) - 1
+
+ def num_transitions(self):
+ # -1 for the sentinel value
+ return len(self.transitions) - 1
+
+ def finished(self, model):
+ return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.get_root_labels()
+
+ def get_tree(self, model):
+ return model.get_top_constituent(self.constituents)
+
+ def all_transitions(self, model):
+ # TODO: rewrite this to be nicer / faster? or just refactor?
+ all_transitions = []
+ transitions = self.transitions
+ while transitions.parent is not None:
+ all_transitions.append(model.get_top_transition(transitions))
+ transitions = transitions.parent
+ return list(reversed(all_transitions))
+
+ def all_constituents(self, model):
+ # TODO: rewrite this to be nicer / faster?
+ all_constituents = []
+ constituents = self.constituents
+ while constituents.parent is not None:
+ all_constituents.append(model.get_top_constituent(constituents))
+ constituents = constituents.parent
+ return list(reversed(all_constituents))
+
+ def all_words(self, model):
+ return [model.get_word(x) for x in self.word_queue]
+
+ def to_string(self, model):
+ return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)))
+
+ def __str__(self):
+ return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents))
+
+def initial_state_from_preterminals(preterminal_lists, model, gold_trees):
+ """
+ what is passed in should be a list of list of preterminals
+ """
+ word_queues = model.initial_word_queues(preterminal_lists)
+ # this is the bottom of the TreeStack and will be the same for each State
+ transitions=model.initial_transitions()
+ constituents=model.initial_constituents()
+ states = [State(sentence_length=len(wq)-1, # -1 because it ends with a sentinel
+ num_opens=0,
+ word_queue=wq,
+ gold_tree=None,
+ gold_sequence=None,
+ transitions=transitions,
+ constituents=constituents,
+ word_position=0)
+ for idx, wq in enumerate(word_queues)]
+ if gold_trees:
+ states = [state._replace(gold_tree=gold_tree) for gold_tree, state in zip(gold_trees, states)]
+ return states
+
+def initial_state_from_words(word_lists, model):
+ # TODO: stop reversing the words
+ preterminal_lists = []
+ for words in word_lists:
+ preterminals = []
+ for word, tag in reversed(words):
+ word_node = Tree(label=word)
+ tag_node = Tree(label=tag, children=[word_node])
+ preterminals.append(tag_node)
+ preterminal_lists.append(preterminals)
+ return initial_state_from_preterminals(preterminal_lists, model, gold_trees=None)
+
+def initial_state_from_gold_trees(trees, model):
+ # reversed so we put the words on the stack backwards
+ preterminal_lists = [[Tree(label=pt.label, children=Tree(label=pt.children[0].label))
+ for pt in tree.yield_reversed_preterminals()]
+ for tree in trees]
+ return initial_state_from_preterminals(preterminal_lists, model, gold_trees=trees)
+
+@functools.total_ordering
+class Transition(ABC):
+ """
+ model is passed in as a dependency injection
+ for example, an LSTM model can update hidden & output vectors when transitioning
+ """
+ @abstractmethod
+ def update_state(self, state, model):
+ """
+ update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent
+
+ the return value should be a tuple:
+ updated word_position
+ updated constituents
+ new constituent to put on the queue and None
+ - note that the constituent shouldn't be on the queue yet
+ that allows putting it on as a batch operation, which
+ saves a significant amount of time in an LSTM, for example
+ OR
+ data used to make a new constituent and the method used
+ - for example, CloseConstituent can return the children needed
+ and itself. this allows a batch operation to build
+ the constituent
+ """
+ pass
+
+ def delta_opens(self):
+ return 0
+
+ def apply(self, state, model):
+ """
+ return a new State transformed via this transition
+ """
+ word_position, constituents, new_constituent, callback = self.update_state(state, model)
+ if callback is not None:
+ new_constituent = callback.build_constituents(model, [new_constituent])[0]
+ constituents = model.push_constituents([constituents], [new_constituent])[0]
+
+ return state._replace(num_opens=state.num_opens + self.delta_opens(),
+ word_position=word_position,
+ transitions=model.push_transitions([state.transitions], [self])[0],
+ constituents=constituents)
+
+ @abstractmethod
+ def is_legal(self, state, model):
+ """
+ assess whether or not this transition is legal in this state
+
+ at parse time, the parser might choose a transition which cannot be made
+ """
+ pass
+
+ def __lt__(self, other):
+ # put the Shift at the front of a list, and otherwise sort alphabetically
+ if self == other:
+ return False
+ if isinstance(self, Shift):
+ return True
+ if isinstance(other, Shift):
+ return False
+ return str(self) < str(other)
+
+class Shift(Transition):
+ def update_state(self, state, model):
+ """
+ This will handle all aspects of a shift transition
+
+ - push the top element of the word queue onto constituents
+ - pop the top element of the word queue
+ """
+ new_constituent = model.transform_word_to_constituent(state)
+ return state.word_position+1, state.constituents, new_constituent, None
+
+ def is_legal(self, state, model):
+ """
+ Disallow shifting when the word queue is empty or there are no opens to eventually eat this word
+ """
+ if state.empty_word_queue():
+ return False
+ if model.is_top_down():
+ # top down transition sequences cannot shift if there are currently no
+ # Open transitions on the stack. in such a case, the new constituent
+ # will never be reduced
+ if state.num_opens == 0:
+ return False
+ if state.num_opens == 1:
+ # there must be at least one transition, since there is an open
+ assert state.transitions.parent is not None
+ if state.transitions.parent.parent is None:
+ # only one transition
+ trans = model.get_top_transition(state.transitions)
+ # must be an Open, since there is one open and one transitions
+ # note that an S, FRAG, etc could happen if we're using unary
+ # and ROOT-S is possible in the case of compound Open
+ # in both cases, Shift is legal
+ # Note that the corresponding problem of shifting after the ROOT-S
+ # has been closed to just ROOT is handled in CloseConstituent
+ if len(trans.label) == 1 and trans.top_label in model.get_root_labels():
+ # don't shift a word at the very start of a parse
+ # we want there to be an extra layer below ROOT
+ return False
+ else:
+ # in-order k==1 (the only other option currently)
+ # can shift ONCE, but note that there is no way to consume
+ # two items in a row if there is no Open on the stack.
+ # As long as there is one or more open transitions,
+ # everything can be eaten
+ if state.num_opens == 0:
+ if state.num_constituents() > 0:
+ return False
+ return True
+
+ def __repr__(self):
+ return "Shift"
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if isinstance(other, Shift):
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(37)
+
+class CompoundUnary(Transition):
+ # TODO: run experiments to see if this is actually useful
+ def __init__(self, labels):
+ # the FIRST label will be the top of the tree
+ # so CompoundUnary that results in root will have root as labels[0], for example
+ if isinstance(labels, str):
+ self.labels = (labels,)
+ else:
+ self.labels = tuple(labels)
+
+ def update_state(self, state, model):
+ # remove the top constituent
+ # apply the labels
+ # put the constituent back on the state
+ constituents = state.constituents
+ new_constituent = model.unary_transform(state.constituents, self.labels)
+ constituents = constituents.pop()
+ return state.word_position, constituents, new_constituent, None
+
+ def is_legal(self, state, model):
+ """
+ Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT
+ """
+ # can't unary transition nothing
+ if model.get_top_constituent(state.constituents) is None:
+ return False
+ # don't unary transition a dummy, dummy
+ # and don't stack CompoundUnary transitions
+ if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)):
+ return False
+ is_root = self.labels[0] in model.get_root_labels()
+ if not state.empty_word_queue() or not state.has_one_constituent():
+ return not is_root
+ else:
+ return is_root
+
+ def __repr__(self):
+ return "CompoundUnary(%s)" % ",".join(self.labels)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, CompoundUnary):
+ return False
+ if self.labels == other.labels:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.labels)
+
+class Dummy():
+ """
+ Takes a space on the constituent stack to represent where an Open transition occurred
+ """
+ def __init__(self, label):
+ self.label = label
+
+ def __str__(self):
+ return "Dummy({})".format(self.label)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, Dummy):
+ return False
+ if self.label == other.label:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.label)
+
+def too_many_unary_nodes(tree):
+ """
+ Return True iff there are UNARY_LIMIT unary nodes in a tree in a row
+
+ helps prevent infinite open/close patterns
+ otherwise, the model can get stuck in essentially an infinite loop
+ """
+ if tree is None:
+ return False
+ for _ in range(UNARY_LIMIT + 1):
+ if len(tree.children) != 1:
+ return False
+ tree = tree.children[0]
+ return True
+
+class OpenConstituent(Transition):
+ def __init__(self, *label):
+ self.label = tuple(label)
+ self.top_label = self.label[0]
+
+ def delta_opens(self):
+ return 1
+
+ def update_state(self, state, model):
+ # open a new constituent which can later be closed
+ # puts a DUMMY constituent on the stack to mark where the constituents end
+ return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None
+
+ def is_legal(self, state, model):
+ """
+ disallow based on the length of the sentence
+ """
+ if state.num_opens > state.sentence_length + 5:
+ # fudge a bit so we don't miss root nodes etc in very small trees
+ return False
+ if model.is_top_down():
+ # If the model is top down, you can't Open if there are
+ # no word to eventually eat
+ if state.empty_word_queue():
+ return False
+ # Also, you can only Open a ROOT iff it is at the root position
+ # The assumption in the unary scheme is there will be no
+ # root open transitions
+ if not model.has_unary_transitions():
+ # TODO: maybe cache this value if this is an expensive operation
+ is_root = self.top_label in model.get_root_labels()
+ if is_root:
+ return state.empty_transitions()
+ else:
+ return not state.empty_transitions()
+ else:
+ # in-order nodes can Open as long as there is at least one thing
+ # on the constituency stack
+ # since closing the in-order involves removing one more
+ # item before the open, and it can close at any time
+ # (a close immediately after the open represents a unary)
+ if state.num_constituents() == 0:
+ return False
+ if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ # consecutive Opens don't make sense in the context of in-order
+ return False
+ # one other restriction - we assume all parse trees
+ # start with (ROOT (first_real_con ...))
+ # therefore ROOT can only occur via Open after everything
+ # else has been pushed and processed
+ # there are no further restrictions
+ is_root = self.top_label in model.get_root_labels()
+ if is_root:
+ # can't make a root node if it will be in the middle of the parse
+ # can't make a root node if there's still words to eat
+ # note that the second assumption wouldn't work,
+ # except we are assuming there will never be multiple
+ # nodes under one root
+ return state.num_opens == 0 and state.empty_word_queue()
+ else:
+ if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents)):
+ # looks like we've been in a loop of lots of unary transitions
+ # note that we check `num_opens > 0` because otherwise we might wind up stuck
+ # in a state where the only legal transition is open, such as if the
+ # constituent stack is otherwise empty, but the open is illegal because
+ # it causes too many unaries
+ # in such a case we can forbid the corresponding close instead...
+ # if empty_word_queue, that means it is trying to make infinitiely many
+ # non-ROOT Open transitions instead of just transitioning ROOT
+ return False
+ return True
+ return True
+
+ def __repr__(self):
+ return "OpenConstituent({})".format(self.label)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, OpenConstituent):
+ return False
+ if self.label == other.label:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.label)
+
+class CloseConstituent(Transition):
+ def delta_opens(self):
+ return -1
+
+ def update_state(self, state, model):
+ # pop constituents until we are done
+ children = []
+ constituents = state.constituents
+ while not isinstance(model.get_top_constituent(constituents), Dummy):
+ # keep the entire value from the stack - the model may need
+ # the whole thing to transform the children into a new node
+ children.append(constituents.value)
+ constituents = constituents.pop()
+ # the Dummy has the label on it
+ label = model.get_top_constituent(constituents).label
+ # pop past the Dummy as well
+ constituents = constituents.pop()
+ if not model.is_top_down():
+ # the alternative to TOP_DOWN_... is IN_ORDER
+ # in which case we want to pop one more constituent
+ children.append(constituents.value)
+ constituents = constituents.pop()
+ # the children are in the opposite order of what we expect
+ children.reverse()
+
+ return state.word_position, constituents, (label, children), CloseConstituent
+
+ @staticmethod
+ def build_constituents(model, data):
+ labels, children_lists = list(map(list, zip(*data)))
+ new_constituents = model.build_constituents(labels, children_lists)
+ return new_constituents
+
+
+ def is_legal(self, state, model):
+ """
+ Disallow if there is no Open on the stack yet
+ in TOP_DOWN, if the previous transition was the Open (nothing built yet)
+ in IN_ORDER, previous transition does not matter, except for one small corner case
+ """
+ if state.num_opens <= 0:
+ return False
+ if model.is_top_down():
+ if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ return False
+ if state.num_opens <= 1 and not state.empty_word_queue():
+ # don't close the last open until all words have been used
+ return False
+ if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND:
+ # when doing TOP_DOWN_COMPOUND, we assume all transitions
+ # at the ROOT level have an S, SQ, FRAG, etc underneath
+ # this is checked when the model is first trained
+ if state.num_opens == 1 and not state.empty_word_queue():
+ return False
+ elif not model.has_unary_transitions():
+ # in fact, we have to leave the top level constituent
+ # under the ROOT open if unary transitions are not possible
+ if state.num_opens == 2 and not state.empty_word_queue():
+ return False
+ else:
+ if not isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ # we're not stuck in a loop of unaries
+ return True
+ if state.num_opens > 1 or state.empty_word_queue():
+ # in either of these cases, the corresponding Open should be eliminated
+ # if we're stuck in a loop of unaries
+ return True
+ node = model.get_top_constituent(state.constituents.pop())
+ if too_many_unary_nodes(node):
+ # at this point, we are in a situation where
+ # - multiple unaries have happened in a row
+ # - there is stuff on the word_queue, so a ROOT open isn't legal
+ # - there's only one constituent on the stack, so the only legal
+ # option once there are no opens left will be an open
+ # this means we'll be stuck having to open again if we do close
+ # this node, so instead we make the Close illegal
+ return False
+ return True
+
+ def __repr__(self):
+ return "CloseConstituent"
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if isinstance(other, CloseConstituent):
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(93)
+
+def bulk_apply(model, tree_batch, transitions, fail=False, max_transitions=1000):
+ remove = set()
+
+ word_positions = []
+ constituents = []
+ new_constituents = []
+ callbacks = defaultdict(list)
+
+ for idx, (tree, transition) in enumerate(zip(tree_batch, transitions)):
+ if not transition:
+ error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model))
+ if fail:
+ raise ValueError(error)
+ else:
+ logger.error(error)
+ remove.add(idx)
+ continue
+
+ if max_transitions and tree.num_transitions() >= max_transitions:
+ # too many transitions
+ if tree.gold_tree:
+ error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model))
+ else:
+ error = "Went infinite!:\nFinal state:\n{}".format(tree.to_string(model))
+ if fail:
+ raise ValueError(error)
+ else:
+ logger.error(error)
+ remove.add(idx)
+ continue
+
+ wq, c, nc, callback = transition.update_state(tree, model)
+
+ word_positions.append(wq)
+ constituents.append(c)
+ new_constituents.append(nc)
+ if callback:
+ # not `idx` in case something was removed
+ callbacks[callback].append(len(new_constituents)-1)
+
+ for key, idxs in callbacks.items():
+ data = [new_constituents[x] for x in idxs]
+ callback_constituents = key.build_constituents(model, data)
+ for idx, constituent in zip(idxs, callback_constituents):
+ new_constituents[idx] = constituent
+
+ tree_batch = [tree for idx, tree in enumerate(tree_batch) if idx not in remove]
+ transitions = [trans for idx, trans in enumerate(transitions) if idx not in remove]
+
+ if len(tree_batch) == 0:
+ return tree_batch
+
+ new_transitions = model.push_transitions([tree.transitions for tree in tree_batch], transitions)
+ new_constituents = model.push_constituents(constituents, new_constituents)
+
+ tree_batch = [state._replace(num_opens=state.num_opens + transition.delta_opens(),
+ word_position=word_position,
+ transitions=transition_stack,
+ constituents=constituents)
+ for (state, transition, word_position, transition_stack, constituents)
+ in zip(tree_batch, transitions, word_positions, new_transitions, new_constituents)]
+
+ return tree_batch