diff options
author | Phil Williams <philip.williams@mac.com> | 2012-12-17 22:49:50 +0400 |
---|---|---|
committer | Phil Williams <philip.williams@mac.com> | 2012-12-17 22:49:50 +0400 |
commit | 06081f7ddb050696c759441fdb081655d233c749 (patch) | |
tree | 76613c29bce208ccbe6d60d79cd6e7521f654fac /scripts/analysis | |
parent | b275c94dbf690d492a763f03e70939bf86d24b84 (diff) |
extract-target-trees.py: minor fixes, code style
Diffstat (limited to 'scripts/analysis')
-rwxr-xr-x | scripts/analysis/extract-target-trees.py | 103 |
1 files changed, 52 insertions, 51 deletions
diff --git a/scripts/analysis/extract-target-trees.py b/scripts/analysis/extract-target-trees.py index 96e6dfe8f..c799a5c97 100755 --- a/scripts/analysis/extract-target-trees.py +++ b/scripts/analysis/extract-target-trees.py @@ -8,10 +8,12 @@ import re import sys + class Tree: def __init__(self, label, children): self.label = label self.children = children + def is_leaf(self): return len(self.children) == 0 @@ -29,82 +31,82 @@ class Derivation(list): return root def construct_target_tree(self): - map = {} + hypo_map = {} for hypothesis in self: - map[hypothesis.span] = hypothesis + hypo_map[hypothesis.span] = hypothesis root = self.find_root() - return self._buildTree(root, map) + return self._build_tree(root, hypo_map) - def _buildTree(self, root, map): - def escapeLabel(label): + def _build_tree(self, root, hypo_map): + def escape_label(label): s = label.replace("&", "&") s = s.replace("<", "<") s = s.replace(">", ">") return s # Build list of NT spans in source order... - nonTermSpans = [] - for item in root.sourceSymbolInfo: + non_term_spans = [] + for item in root.source_symbol_info: span = item[0] - if span in map: # In map iff symbol is NT - nonTermSpans.append(span) - nonTermSpans.sort() + if span != root.span and span in hypo_map: # In hypo_map iff symbol is NT + non_term_spans.append(span) + non_term_spans.sort() # ... then convert to target order. - alignmentPairs = root.ntAlignments[:] - alignmentPairs.sort() - targetOrderNonTermSpans = {} - for i, pair in enumerate(alignmentPairs): - targetOrderNonTermSpans[pair[1]] = nonTermSpans[i] + alignment_pairs = root.nt_alignments[:] + alignment_pairs.sort() + target_order_non_term_spans = {} + for i, pair in enumerate(alignment_pairs): + target_order_non_term_spans[pair[1]] = non_term_spans[i] children = [] - numNonTerms = 0 + num_non_terms = 0 - for i, symbol in enumerate(root.targetRHS): - if i in targetOrderNonTermSpans: - hyp = map[targetOrderNonTermSpans[i]] - children.append(self._buildTree(hyp, map)) - numNonTerms += 1 + for i, symbol in enumerate(root.target_rhs): + if i in target_order_non_term_spans: + hyp = hypo_map[target_order_non_term_spans[i]] + children.append(self._build_tree(hyp, hypo_map)) + num_non_terms += 1 else: - children.append(Tree(escapeLabel(symbol), [])) + children.append(Tree(escape_label(symbol), [])) - assert numNonTerms == len(root.ntAlignments) + assert num_non_terms == len(root.nt_alignments) - return Tree(root.targetLHS, children) + return Tree(root.target_lhs, children) class Hypothesis: def __init__(self): - self.sentenceNum = None + self.sentence_num = None self.span = None - self.sourceSymbolInfo = None - self.targetLHS = None - self.targetRHS = None - self.ntAlignments = None - - def __str__(self): - return str(self.id) + " " + str(self.component_scores) + self.source_symbol_info = None + self.target_lhs = None + self.target_rhs = None + self.nt_alignments = None -def readDerivations(input, lineNum): - prevSentenceNum = None +def read_derivations(input): + line_num = 0 + start_line_num = None + prev_sentence_num = None derivation = Derivation() for line in input: - lineNum += 1 - hypothesis = parseLine(line) - if hypothesis.sentenceNum != prevSentenceNum: + line_num += 1 + hypothesis = parse_line(line) + if hypothesis.sentence_num != prev_sentence_num: # We've started reading the next derivation... - prevSentenceNum = hypothesis.sentenceNum + prev_sentence_num = hypothesis.sentence_num if len(derivation): - yield derivation, lineNum + yield derivation, start_line_num derivation = Derivation() + start_line_num = line_num derivation.append(hypothesis) if len(derivation): - yield derivation, lineNum + yield derivation, start_line_num # Extract the hypothesis components and return a Hypothesis object. -def parseLine(s): +def parse_line(s): pattern = r"Trans Opt (\d+) " + \ r"\[(\d+)\.\.(\d+)\]:" + \ r"((?: \[\d+\.\.\d+\]=\S+ )+):" + \ @@ -118,9 +120,9 @@ def parseLine(s): assert match group = match.groups() hypothesis = Hypothesis() - hypothesis.sentenceNum = int(group[0]) + 1 + hypothesis.sentence_num = int(group[0]) + 1 hypothesis.span = (int(group[1]), int(group[2])) - hypothesis.sourceSymbolInfo = [] + hypothesis.source_symbol_info = [] for item in group[3].split(): pattern = "\[(\d+)\.\.(\d+)\]=(\S+)" regexp = re.compile(pattern) @@ -128,15 +130,15 @@ def parseLine(s): assert(match) start, end, symbol = match.groups() span = (int(start), int(end)) - hypothesis.sourceSymbolInfo.append((span, symbol)) - hypothesis.targetLHS = group[4] - hypothesis.targetRHS = group[5].split() - hypothesis.ntAlignments = [] + hypothesis.source_symbol_info.append((span, symbol)) + hypothesis.target_lhs = group[4] + hypothesis.target_rhs = group[5].split() + hypothesis.nt_alignments = [] for pair in group[6].split(): match = re.match(r'(\d+)-(\d+)', pair) assert match ai = (int(match.group(1)), int(match.group(2))) - hypothesis.ntAlignments.append(ai) + hypothesis.nt_alignments.append(ai) return hypothesis @@ -160,12 +162,11 @@ def main(): input = sys.stdin else: input = open(sys.argv[1]) - lineNum = 0 - for derivation, lineNum in readDerivations(input, lineNum): + for derivation, line_num in read_derivations(input): try: tree = derivation.construct_target_tree() except: - msg = "error processing derivation at line %d\n" % lineNum + msg = "error processing derivation starting at line %d\n" % line_num sys.stderr.write(msg) raise print tree_to_xml(tree) |