Welcome to mirror list, hosted at ThFree Co, Russian Federation.

extract-target-trees.py « analysis « scripts - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 5dd097ff05b788d9dc14e34b70e1065c0872faab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#!/usr/bin/env python
#
# This file is part of moses.  Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.

"""Usage: extract-target-trees.py [FILE]

Reads moses-chart's -T output from FILE or standard input and writes trees to
standard output in Moses' XML tree format.
"""

import re
import sys


class Tree:
    """A node of a target-side tree.

    Attributes:
        label: the node's label.
        children: sequence of child nodes; an empty sequence marks a leaf.
    """

    def __init__(self, label, children):
        self.label, self.children = label, children

    def is_leaf(self):
        """Return True when this node has no children."""
        return not len(self.children)


class Derivation(list):
    """The list of hypotheses that make up the derivation of one sentence.

    Each element is expected to provide the attributes ``span``,
    ``source_symbol_info``, ``target_lhs``, ``target_rhs`` and
    ``nt_alignments`` (see the line parsers that fill them in).
    """

    @staticmethod
    def _escape_label(label):
        """Return ``label`` with the XML special characters escaped.

        "&" has to be replaced first; otherwise the ampersands introduced by
        the "&lt;" and "&gt;" entities would themselves be escaped again.
        """
        escaped = label.replace("&", "&amp;")
        escaped = escaped.replace("<", "&lt;")
        escaped = escaped.replace(">", "&gt;")
        return escaped

    def find_root(self):
        """Return the hypothesis that starts at position 0 and ends furthest
        to the right.

        On ties the hypothesis that appears first in the list wins.  Fails
        with an AssertionError if the derivation is empty or contains no
        hypothesis whose span starts at 0.
        """
        assert len(self) > 0
        root = None
        for hypothesis in self:
            if hypothesis.span[0] != 0:
                continue
            if root is None or hypothesis.span[1] > root.span[1]:
                root = hypothesis
        assert root
        return root

    def construct_target_tree(self):
        """Build and return the target-side Tree for this derivation."""
        # Index hypotheses by span; if two hypotheses share a span, the one
        # that appears later in the list is kept.
        hypo_map = {}
        for hypothesis in self:
            hypo_map[hypothesis.span] = hypothesis
        root = self.find_root()
        return self._build_tree(root, hypo_map)

    def _build_tree(self, root, hypo_map):
        """Recursively build the Tree rooted at hypothesis ``root``.

        ``hypo_map`` maps spans to hypotheses.  Terminal symbols of
        ``root.target_rhs`` become leaf nodes with XML-escaped labels;
        positions that are the target side of a non-terminal alignment are
        replaced by the subtree built from the aligned hypothesis.  The
        returned node is labelled with ``root.target_lhs`` (used verbatim).
        """
        # Build list of NT spans in source order...
        # (a span is in hypo_map iff the symbol is a non-terminal)
        non_term_spans = sorted(
            item[0] for item in root.source_symbol_info
            if item[0] != root.span and item[0] in hypo_map)

        # ... then convert to target order: after sorting, the i-th alignment
        # pair belongs to the i-th non-terminal span, and pair[1] is the
        # position of that non-terminal in the target right-hand side.
        alignment_pairs = sorted(root.nt_alignments)
        target_order_non_term_spans = {}
        for i, pair in enumerate(alignment_pairs):
            target_order_non_term_spans[pair[1]] = non_term_spans[i]

        children = []
        num_non_terms = 0

        for i, symbol in enumerate(root.target_rhs):
            if i in target_order_non_term_spans:
                hyp = hypo_map[target_order_non_term_spans[i]]
                children.append(self._build_tree(hyp, hypo_map))
                num_non_terms += 1
            else:
                children.append(Tree(self._escape_label(symbol), []))

        # Every aligned non-terminal must have been placed in the tree.
        assert num_non_terms == len(root.nt_alignments)

        return Tree(root.target_lhs, children)


class Hypothesis:
    """Plain record for one hypothesis line of the decoder output.

    Every field starts out as None and is filled in by the line parsers.
    """

    def __init__(self):
        # sentence_num: sentence number of the derivation this belongs to
        # span: (start, end) pair covered by the hypothesis
        # source_symbol_info: list of (span, symbol) pairs
        # target_lhs: left-hand-side label on the target side
        # target_rhs: list of target right-hand-side symbols
        # nt_alignments: list of integer pairs taken from the alignment field
        self.sentence_num = self.span = self.source_symbol_info = None
        self.target_lhs = self.target_rhs = self.nt_alignments = None


def read_derivations(input):
    """Yield ``(derivation, start_line_num)`` pairs read from ``input``.

    ``input`` is an iterable of lines.  Consecutive lines whose hypotheses
    share a sentence number are collected into one Derivation;
    ``start_line_num`` is the 1-based number of the first line of that group.
    """
    current = Derivation()
    first_line_num = None
    last_sentence_num = None
    for line_num, line in enumerate(input, 1):
        hypothesis = parse_line(line)
        sentence_num = hypothesis.sentence_num
        if sentence_num != last_sentence_num:
            # A new sentence begins: hand out whatever has been collected.
            last_sentence_num = sentence_num
            if current:
                yield current, first_line_num
                current = Derivation()
            first_line_num = line_num
        current.append(hypothesis)
    # Flush the final derivation, if there is one.
    if current:
        yield current, first_line_num


def parse_line(s):
    """Parse one line of decoder output and return the resulting hypothesis.

    Lines beginning with "Trans Opt" go to the old-format parser, everything
    else to the new-format parser.
    """
    parser = (parse_line_old_format if s.startswith("Trans Opt")
              else parse_line_new_format)
    return parser(s)


# Extract the hypothesis components and return a Hypothesis object.
def parse_line_old_format(s):
    """Parse a "Trans Opt ..." line and return a Hypothesis.

    The line must consist of, in order: the sentence index, the covered span
    "[i..j]", one or more "[i..j]=SYMBOL" source items, the target left-hand
    side, the target right-hand-side symbols and zero or more "a-b"
    non-terminal alignment pairs, followed by ": c=".

    The stored sentence number is the parsed index plus one.  If the line
    does not match, it is echoed to standard error and an AssertionError is
    raised.
    """
    line_regexp = re.compile(
        r"Trans Opt (\d+) "
        r"\[(\d+)\.\.(\d+)\]:"
        r"((?: \[\d+\.\.\d+\]=\S+  )+):"
        r" (\S+) ->\S+  -> "
        r"((?:\S+ )+):"
        r"((?:\d+-\d+ )*): c=")
    # Raw strings: "\[" and "\d" are not valid escapes in ordinary string
    # literals.  Both patterns are compiled once, outside the loops below.
    item_regexp = re.compile(r"\[(\d+)\.\.(\d+)\]=(\S+)")
    pair_regexp = re.compile(r'(\d+)-(\d+)')

    match = line_regexp.match(s)
    if not match:
        # Show the offending line before failing.
        sys.stderr.write("%s\n" % s)
    assert match
    group = match.groups()

    hypothesis = Hypothesis()
    hypothesis.sentence_num = int(group[0]) + 1
    hypothesis.span = (int(group[1]), int(group[2]))

    # Source side: list of ((start, end), symbol) pairs.
    hypothesis.source_symbol_info = []
    for item in group[3].split():
        match = item_regexp.match(item)
        assert match
        start, end, symbol = match.groups()
        span = (int(start), int(end))
        hypothesis.source_symbol_info.append((span, symbol))

    hypothesis.target_lhs = group[4]
    hypothesis.target_rhs = group[5].split()

    # Non-terminal alignments: list of integer pairs.
    hypothesis.nt_alignments = []
    for pair in group[6].split():
        match = pair_regexp.match(pair)
        assert match
        ai = (int(match.group(1)), int(match.group(2)))
        hypothesis.nt_alignments.append(ai)
    return hypothesis


# Extract the hypothesis components and return a Hypothesis object.
def parse_line_new_format(s):
    """Parse a "|||"-separated hypothesis line and return a Hypothesis.

    The fields are, in order: sentence index, "[LHS] -> source symbols",
    "[LHS] -> target symbols", "a-b" non-terminal alignment pairs and
    "i..j" spans (one per source symbol).  The stored sentence number is the
    parsed index plus one, and the hypothesis span runs from the start of the
    first listed span to the end of the last one.  If the line does not
    match, it is echoed to standard error and an AssertionError is raised.
    """
    line_regexp = re.compile(
        r"(\d+) \|\|\|"
        r" (\[\S+\]) -> ((?:\S+ )+)\|\|\|"
        r" (\[\S+\]) -> ((?:\S+ )+)\|\|\|"
        r" ((?:\d+-\d+ )*)\|\|\|"
        r"((?: \d+\.\.\d+)*)")
    found = line_regexp.match(s)
    if not found:
        sys.stderr.write("%s\n" % s)
    assert found
    (number, _source_lhs, source_rhs, target_lhs, target_rhs,
     alignment_field, span_field) = found.groups()

    result = Hypothesis()
    result.sentence_num = int(number) + 1

    # One (start, end) pair per token of the span field.
    spans = []
    for token in span_field.split():
        span_match = re.match(r'(\d+)\.\.(\d+)', token)
        assert span_match
        spans.append((int(span_match.group(1)), int(span_match.group(2))))
    result.span = (spans[0][0], spans[-1][1])

    # Pair every source symbol with the span at the same position.
    result.source_symbol_info = [
        (spans[index], strip_brackets(symbol))
        for index, symbol in enumerate(source_rhs.split())]

    result.target_lhs = strip_brackets(target_lhs)
    result.target_rhs = target_rhs.split()

    result.nt_alignments = []
    for token in alignment_field.split():
        pair_match = re.match(r'(\d+)-(\d+)', token)
        assert pair_match
        result.nt_alignments.append(
            (int(pair_match.group(1)), int(pair_match.group(2))))
    return result


def strip_brackets(symbol):
    """Return ``symbol`` without its enclosing square brackets.

    Only a symbol that both starts with "[" and ends with "]" is stripped;
    anything else, including the empty string, is returned unchanged.
    """
    # startswith/endswith are safe on "", unlike indexing symbol[0].
    if symbol.startswith('[') and symbol.endswith(']'):
        return symbol[1:-1]
    return symbol


def tree_to_xml(tree):
    """Return the XML string for ``tree``.

    A leaf is rendered as its bare label.  Any other node is rendered as
    '<tree label="LABEL"> ', followed by each rendered child plus one space,
    followed by '</tree>'.
    """
    if not tree.is_leaf():
        opening = '<tree label="%s"> ' % tree.label
        body = "".join(tree_to_xml(child) + " " for child in tree.children)
        return opening + body + '</tree>'
    return tree.label


def main():
    """Command-line entry point.

    Reads derivations from the file named by the single optional argument,
    or from standard input when the argument is missing or "-", and writes
    one XML tree per derivation to standard output.  With more than one
    argument a usage message is written to standard error and the process
    exits with status 1.
    """
    if len(sys.argv) > 2:
        sys.stderr.write("usage: %s [FILE]\n" % sys.argv[0])
        sys.exit(1)
    read_stdin = len(sys.argv) == 1 or sys.argv[1] == "-"
    stream = sys.stdin if read_stdin else open(sys.argv[1])
    try:
        for derivation, line_num in read_derivations(stream):
            try:
                tree = derivation.construct_target_tree()
            except Exception:
                # Say where the bad derivation starts, then let the original
                # error propagate with its traceback.
                sys.stderr.write(
                    "error processing derivation starting at line %d\n"
                    % line_num)
                raise
            # write() instead of the print statement so the script runs
            # under both Python 2 and Python 3.
            sys.stdout.write(tree_to_xml(tree) + "\n")
    finally:
        # Close only what was opened here; leave sys.stdin alone.
        if not read_stdin:
            stream.close()


# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()