Welcome to mirror list, hosted at ThFree Co, Russian Federation.

flexibility_score.py « training « scripts - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 56d4f94255a1a609493736d41402a724b41c00c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# author: Rico Sennrich
#
# This file is part of moses.  Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.

"""Add flexibility scores to a phrase table half.

You usually don't have to call this script directly; to add flexibility
scores to your model, run train-model.perl with the option
"--flexibility-score" (will only affect steps 5 and 6).

Usage:
    python flexibility_score.py extract.context(.inv).sorted \
        [--Inverse] [--Hierarchical] < phrasetable > output_file
"""

from __future__ import division
from __future__ import unicode_literals

import sys
import gzip
from collections import defaultdict


class FlexScore:
    """Append flexibility scores to the lines of a phrase-table half.

    The phrase table and the context extract file are both consumed as
    byte lines and walked in lockstep, one sort key at a time, so only the
    phrase pairs and context counts for a single key are held in memory.

    For every phrase pair, the count of its '<' and '>' context lines (and,
    in hierarchical mode, its 'v' context lines), normalized over all
    targets of the same source phrase, is appended to field index 3 of the
    phrase-table line.

    All text handling is done on byte strings, without ``bytes.format`` or
    mixed bytes/str literals, so the class works on Python 2 and Python 3.
    """

    def __init__(self, inverted, hierarchical):
        """Create a scorer.

        :param inverted: True when processing the inverse phrase-table half;
            the phrase table is then grouped on its second field and the two
            phrases are swapped when stored (see ``store_pt``).
        :param hierarchical: True to also append the score derived from
            'v' context markers.
        """
        self.inverted = inverted
        self.hierarchical = hierarchical
        # Initialise per-block state so the object is usable before the
        # first call to traverse_incrementally().
        self._start_block()

    def _start_block(self):
        """Reset the phrase pairs and context counters of the current block."""
        self.phrase_pairs = defaultdict(dict)
        self.context_counts = defaultdict(lambda: defaultdict(int))
        self.context_counts_l = defaultdict(lambda: defaultdict(int))
        self.context_counts_r = defaultdict(lambda: defaultdict(int))
        self.context_counts_d = defaultdict(lambda: defaultdict(int))

    @staticmethod
    def _sort_key(phrase):
        """Return the key used to order and compare phrases.

        The b' |' suffix does change the order. Bytewise, b'a b |' sorts
        before b'a |' (b'b' < b'|'), whereas b'a' sorts before b'a b'. The
        suffix therefore mirrors a bytewise sort of whole lines in which each
        phrase is followed by ' |||'. This is presumably how the input files
        were sorted (e.g. LC_ALL=C sort), so output order and the lockstep
        comparisons stay consistent with the inputs.
        """
        return phrase + b' |'

    @staticmethod
    def _format_score(value):
        """Format a probability with 6 significant digits, as bytes."""
        # str.format + encode works on Python 2 and 3; bytes.format does
        # not exist on Python 3.
        return "{0:.6g}".format(value).encode('ascii')

    @staticmethod
    def _write_error(message):
        """Write the byte string ``message`` to stderr on Python 2 or 3."""
        buffer = getattr(sys.stderr, 'buffer', None)
        if buffer is not None:
            # Python 3 text stream: flush pending text, then write raw bytes.
            sys.stderr.flush()
            buffer.write(message)
            buffer.flush()
        elif bytes is str:
            # Python 2: file objects accept byte strings directly.
            sys.stderr.write(message)
        else:
            # Python 3 stream without a byte buffer (e.g. StringIO).
            sys.stderr.write(message.decode('utf-8', 'replace'))

    def store_pt(self, obj):
        """Store a split phrase-table line under [phrase][other phrase].

        If we work with the inverted phrase table, swap the two phrases so
        that the first key is the phrase the table is grouped on.
        """
        src, target = obj[0], obj[1]

        if self.inverted:
            src, target = target, src

        self.phrase_pairs[src][target] = obj

    def update_contextcounts(self, obj):
        """Count one split context line for its phrase pair.

        The context marker is the first character of the last field: '<'
        and '>' feed the left/right counters, 'v' the 'd' counter.

        :raises ValueError: for any other marker (the offending line is
            reported on stderr first).
        """
        src, target = obj[0], obj[1]
        self.context_counts[src][target] += 1
        marker = obj[-1]
        if marker.startswith(b'<'):
            self.context_counts_l[src][target] += 1
        elif marker.startswith(b'>'):
            self.context_counts_r[src][target] += 1
        elif marker.startswith(b'v'):
            self.context_counts_d[src][target] += 1
        else:
            self._write_error(
                b"\nERROR in line: " + b' ||| '.join(obj) + b"\n")
            self._write_error(
                b"ERROR: expecting one of '<, >, v' as context marker "
                b"in context extract file.\n")
            raise ValueError(
                "unexpected context marker in context extract file")

    def traverse_incrementally(self, phrasetable, flexfile):
        """Traverse phrase table and phrase extract file (with context
        information) incrementally without storing all in memory.

        Both arguments must be iterators over byte lines; re-iterable
        containers such as lists would restart on every pass. Each ``yield``
        leaves one block in ``self.phrase_pairs`` and ``self.context_counts*``.
        The first block is always empty.
        """
        # Field the phrase table is grouped/sorted on.
        sort_pt = 1 if self.inverted else 0

        increment = b''
        old_increment = None  # sentinel so the first loop test succeeds
        pending_pt = None     # read-ahead phrase-table line (next block)
        pending_ctx = None    # read-ahead context line, not yet counted

        while old_increment != increment:
            old_increment = increment
            bound = self._sort_key(old_increment)
            self._start_block()

            if pending_pt is not None:
                self.store_pt(pending_pt)
                pending_pt = None

            for line in phrasetable:
                stripped = line.rstrip()
                if not stripped:
                    continue  # tolerate blank lines
                fields = stripped.split(b' ||| ')
                if fields[sort_pt] != increment:
                    increment = fields[sort_pt]
                    pending_pt = fields
                    break
                self.store_pt(fields)

            # Only consume the read-ahead context line if it belongs to this
            # block or an earlier one. If it sorts after the current key
            # (the key has no context lines), keep it for a later block
            # instead of misattributing it here.
            if (pending_ctx is not None
                    and self._sort_key(pending_ctx[0]) <= bound):
                self.update_contextcounts(pending_ctx)
                pending_ctx = None

            if pending_ctx is None:
                for line in flexfile:
                    stripped = line.rstrip()
                    if not stripped:
                        continue  # tolerate blank lines
                    fields = stripped.split(b' ||| ')
                    if self._sort_key(fields[0]) <= bound:
                        self.update_contextcounts(fields)
                    else:
                        pending_ctx = fields
                        break

            yield 1

    def main(self, phrasetable, flexfile, output_object):
        """Score every phrase pair and write the extended lines.

        :param phrasetable: iterator over phrase-table byte lines.
        :param flexfile: iterator over context-extract byte lines.
        :param output_object: object whose ``write`` accepts bytes.

        A progress dot is written to stderr for the first pair and then
        once every 1,000,000 pairs.
        """
        written = 0
        sys.stderr.write(
            "Incrementally loading phrase table "
            "and adding flexibility score...")
        for _ in self.traverse_incrementally(phrasetable, flexfile):

            self.flexprob_l = normalize(self.context_counts_l)
            self.flexprob_r = normalize(self.context_counts_r)
            self.flexprob_d = normalize(self.context_counts_d)

            for src in sorted(self.phrase_pairs, key=self._sort_key):
                targets = self.phrase_pairs[src]
                for target in sorted(targets, key=self._sort_key):

                    if written % 1000000 == 0:
                        sys.stderr.write('.')
                    written += 1

                    output_object.write(self.write_phrase_table(src, target))
        sys.stderr.write('done\n')

    def write_phrase_table(self, src, target):
        """Return the stored line for (src, target) with scores appended.

        The left and right flexibility probabilities (and, in hierarchical
        mode, the 'v' probability, or b'1' if the pair has none) are
        appended to field index 3 of the stored line. That field is
        presumably the score field of the phrase-table half; verify against
        the scorer's output format. The stored list is mutated in place.

        :raises KeyError: if the pair has no left or right probability.
        """
        line = self.phrase_pairs[src][target]
        scores = [self._format_score(self.flexprob_l[src][target]),
                  self._format_score(self.flexprob_r[src][target])]

        if self.hierarchical:
            try:
                prob_d = self.flexprob_d[src][target]
            except KeyError:
                scores.append(b"1")
            else:
                scores.append(self._format_score(prob_d))

        line[3] += b' ' + b' '.join(scores)
        return b' ||| '.join(line) + b'\n'


def normalize(d):
    """Turn nested counts into per-source relative frequencies.

    :param d: mapping ``source -> {target: count}``.
    :returns: ``defaultdict(dict)`` mapping ``source -> {target: count / total}``,
        where ``total`` is the sum of that source's counts. Sources whose inner
        mapping is empty are left out; looking them up yields ``{}``.
    """
    probabilities = defaultdict(dict)

    for source, target_counts in d.items():
        if not target_counts:
            continue
        denominator = sum(target_counts.values())
        probabilities[source] = {
            target: count / denominator
            for target, count in target_counts.items()
        }

    return probabilities


if __name__ == '__main__':

    # sys.argv[0] is the script name, so the context file must be argv[1].
    if len(sys.argv) < 2:
        sys.stderr.write(
            "Usage: "
            "python flexibility_score.py extract.context(.inv).sorted "
            "[--Inverse] [--Hierarchical] < phrasetable > output_file\n")
        sys.exit(1)

    flexfile = sys.argv[1]
    inverted = '--Inverse' in sys.argv
    hierarchical = '--Hierarchical' in sys.argv

    # FlexScore splits and writes byte strings. On Python 3, stdin/stdout
    # are text streams, so use their underlying byte buffers. On Python 2
    # they already carry bytes.
    phrasetable = getattr(sys.stdin, 'buffer', sys.stdin)
    output = getattr(sys.stdout, 'buffer', sys.stdout)

    FS = FlexScore(inverted, hierarchical)
    # The context extract file is gzip-compressed and read in binary mode;
    # the with-block guarantees it is closed.
    with gzip.open(flexfile, 'rb') as flex_handle:
        FS.main(phrasetable, flex_handle, output)
    output.flush()