diff options
author | pjwilliams <pjwilliams@1f5c12ca-751b-0410-a591-d2e778427230> | 2011-09-19 13:08:24 +0400 |
---|---|---|
committer | pjwilliams <pjwilliams@1f5c12ca-751b-0410-a591-d2e778427230> | 2011-09-19 13:08:24 +0400 |
commit | a064f799e06e8ee6fdb5f8d579e297b645ea97d8 (patch) | |
tree | 28a964f88ecf8ad63cf3ee2478ecf5493ce25aa7 /scripts | |
parent | 2cf26266f7139c07feec9f2fd7a91ea066c3c07c (diff) |
Add scripts/analysis/extract-target-trees.py
Usage: extract-target-trees.py [FILE]
Reads moses-chart's -T output from FILE or standard input and writes trees to
standard output in Moses' XML tree format.
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4233 1f5c12ca-751b-0410-a591-d2e778427230
Diffstat (limited to 'scripts')
-rwxr-xr-x | scripts/analysis/extract-target-trees.py | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/scripts/analysis/extract-target-trees.py b/scripts/analysis/extract-target-trees.py new file mode 100755 index 000000000..96e6dfe8f --- /dev/null +++ b/scripts/analysis/extract-target-trees.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python + +# Usage: extract-target-trees.py [FILE] +# +# Reads moses-chart's -T output from FILE or standard input and writes trees to +# standard output in Moses' XML tree format. + +import re +import sys + +class Tree: + def __init__(self, label, children): + self.label = label + self.children = children + def is_leaf(self): + return len(self.children) == 0 + + +class Derivation(list): + def find_root(self): + assert len(self) > 0 + root = None + for hypothesis in self: + if hypothesis.span[0] != 0: + continue + if root == None or hypothesis.span[1] > root.span[1]: + root = hypothesis + assert root + return root + + def construct_target_tree(self): + map = {} + for hypothesis in self: + map[hypothesis.span] = hypothesis + root = self.find_root() + return self._buildTree(root, map) + + def _buildTree(self, root, map): + def escapeLabel(label): + s = label.replace("&", "&") + s = s.replace("<", "<") + s = s.replace(">", ">") + return s + + # Build list of NT spans in source order... + nonTermSpans = [] + for item in root.sourceSymbolInfo: + span = item[0] + if span in map: # In map iff symbol is NT + nonTermSpans.append(span) + nonTermSpans.sort() + + # ... then convert to target order. + alignmentPairs = root.ntAlignments[:] + alignmentPairs.sort() + targetOrderNonTermSpans = {} + for i, pair in enumerate(alignmentPairs): + targetOrderNonTermSpans[pair[1]] = nonTermSpans[i] + + children = [] + numNonTerms = 0 + + for i, symbol in enumerate(root.targetRHS): + if i in targetOrderNonTermSpans: + hyp = map[targetOrderNonTermSpans[i]] + children.append(self._buildTree(hyp, map)) + numNonTerms += 1 + else: + children.append(Tree(escapeLabel(symbol), [])) + + assert numNonTerms == len(root.ntAlignments) + + return Tree(root.targetLHS, children) + + +class Hypothesis: + def __init__(self): + self.sentenceNum = None + self.span = None + self.sourceSymbolInfo = None + self.targetLHS = None + self.targetRHS = None + self.ntAlignments = None + + def __str__(self): + return str(self.id) + " " + str(self.component_scores) + + +def readDerivations(input, lineNum): + prevSentenceNum = None + derivation = Derivation() + for line in input: + lineNum += 1 + hypothesis = parseLine(line) + if hypothesis.sentenceNum != prevSentenceNum: + # We've started reading the next derivation... + prevSentenceNum = hypothesis.sentenceNum + if len(derivation): + yield derivation, lineNum + derivation = Derivation() + derivation.append(hypothesis) + if len(derivation): + yield derivation, lineNum + + +# Extract the hypothesis components and return a Hypothesis object. +def parseLine(s): + pattern = r"Trans Opt (\d+) " + \ + r"\[(\d+)\.\.(\d+)\]:" + \ + r"((?: \[\d+\.\.\d+\]=\S+ )+):" + \ + r" (\S+) ->" + \ + r"((?:\S+ )+):" + \ + r"((?:\d+-\d+ )*): pC=" + regexp = re.compile(pattern) + match = regexp.match(s) + if not match: + sys.stderr.write("%s\n" % s) + assert match + group = match.groups() + hypothesis = Hypothesis() + hypothesis.sentenceNum = int(group[0]) + 1 + hypothesis.span = (int(group[1]), int(group[2])) + hypothesis.sourceSymbolInfo = [] + for item in group[3].split(): + pattern = "\[(\d+)\.\.(\d+)\]=(\S+)" + regexp = re.compile(pattern) + match = regexp.match(item) + assert(match) + start, end, symbol = match.groups() + span = (int(start), int(end)) + hypothesis.sourceSymbolInfo.append((span, symbol)) + hypothesis.targetLHS = group[4] + hypothesis.targetRHS = group[5].split() + hypothesis.ntAlignments = [] + for pair in group[6].split(): + match = re.match(r'(\d+)-(\d+)', pair) + assert match + ai = (int(match.group(1)), int(match.group(2))) + hypothesis.ntAlignments.append(ai) + return hypothesis + + +def tree_to_xml(tree): + if tree.is_leaf(): + return tree.label + else: + s = '<tree label="%s"> ' % tree.label + for child in tree.children: + s += tree_to_xml(child) + s += " " + s += '</tree>' + return s + + +def main(): + if len(sys.argv) > 2: + sys.stderr.write("usage: %s [FILE]\n" % sys.argv[0]) + sys.exit(1) + if len(sys.argv) == 1 or sys.argv[1] == "-": + input = sys.stdin + else: + input = open(sys.argv[1]) + lineNum = 0 + for derivation, lineNum in readDerivations(input, lineNum): + try: + tree = derivation.construct_target_tree() + except: + msg = "error processing derivation at line %d\n" % lineNum + sys.stderr.write(msg) + raise + print tree_to_xml(tree) + + +if __name__ == '__main__': + main() |