Welcome to mirror list, hosted at ThFree Co, Russian Federation.

extract.py « bilingual-lm « training « scripts - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 876fba9ee7fa3ffb0bc077598e53d469f57ad2c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python
#
# This file is part of moses.  Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.

from collections import Counter
import logging
import sys

LOG = logging.getLogger(__name__)

BOS = "<s>"
EOS = "</s>"
UNK = "<unk>"


def replace_tags(tokens, tags, vocab):
    for i, t in enumerate(tokens):
        if t not in vocab:
            if i < len(tags):
                tokens[i] = tags[i]
            else:
                print "Error: missing tags for index i:", i
                print ' '.join(tokens)
                print ' '.join(tags)
                tokens[i] = UNK


def replace_unks(tokens, vocab):
    for i, t in enumerate(tokens):
        if t not in vocab:
            tokens[i] = UNK


def numberize(line, m, n, svocab, tvocab):
    line = line.split()
    source_words = line[:2 * m + 1]
    target_words = line[-n:]

    line = ' '.join([str(svocab[item]) for item in source_words]) + ' '
    line += ' '.join([str(tvocab[item]) for item in target_words]) + '\n'

    return line


def get_ngrams(corpus_stem, align_file, tagged_stem, svocab, tvocab, slang,
               tlang, m, n, ofh):
    """
      m - source context
      n - target context

      returns set of tags used
    """
    tags = Counter()
    sfh = open(corpus_stem + "." + slang)
    tfh = open(corpus_stem + "." + tlang)
    afh = open(align_file)
    fhs = [sfh, tfh, afh]
    if tagged_stem:
        fhs.append(open(tagged_stem + "." + slang))
        fhs.append(open(tagged_stem + "." + tlang))

    count = 0
    ngrams = 0
    LOG.info("Extracting ngrams")
    for lines in zip(*fhs):
        stokens = lines[0][:-1].split()
        ttokens = lines[1][:-1].split()
        stokens.append(EOS)
        ttokens.append(EOS)
        if tagged_stem:
            stags = lines[3][:-1].split()
            ttags = lines[4][:-1].split()
            stags.append(EOS)
            ttags.append(EOS)
            tags.update(stags)
            tags.update(ttags)
            replace_tags(stokens, stags, svocab)
            replace_tags(ttokens, ttags, tvocab)
        else:
            replace_unks(stokens, svocab)
            replace_unks(ttokens, tvocab)
        # List aligns for each target.
        # Note: align specifies source -> target
        target_aligns = [[] for t in range(len(ttokens))]
        for atoken in lines[2][:-1].split():
            spos, tpos = atoken.split("-")
            spos, tpos = int(spos), int(tpos)
            target_aligns[tpos].append(spos)

        # EOS alignment.
        target_aligns[-1] = [len(stokens) - 1]

        for tpos, spos_list in enumerate(target_aligns):
            # Affiliation heuristics - see Devlin t al. p1371
            if not spos_list:
                # tpos has no alignment, look right, then left, then
                # right-right, then left-left etc.
                rpos = tpos + 1
                lpos = tpos - 1
                while rpos < len(ttokens) or lpos >= 0:
                    if rpos < len(ttokens) and target_aligns[rpos]:
                        spos_list = target_aligns[rpos]
                        break
                    if lpos >= 0 and target_aligns[lpos]:
                        spos_list = target_aligns[lpos]
                        break
                    rpos += 1
                    lpos -= 1

            if not spos_list:
                raise Exception(
                    "No alignments in sentence \nSRC: " +
                    lines[0][:-1] + "\nTGT: " + lines[1][:-1])
            midpos = (len(spos_list) - 1) / 2
            spos = sorted(spos_list)[midpos]

            # source-context, target-context, predicted word
            for i in range(max(0, m - spos)):
                print>>ofh, BOS,
                # print [spos-m/2,spos+m/2+1], stokens[spos-m/2:spos+m/2+1]
            print>>ofh, " ".join(
                [s for s in stokens[max(0, spos - m):spos + m + 1]]),
            for i in range(max(0, spos + m + 1 - len(stokens))):
                print>>ofh, EOS,
            for i in range(max(0, n - (tpos + 1))):
                print>>ofh, BOS,
            print>>ofh, " ".join(
                [t for t in ttokens[max(0, tpos + 1 - n):tpos + 1]]),
            print>>ofh
            ngrams += 1

        count += 1
        if count % 1000 == 0:
            sys.stderr.write(".")
        if count % 50000 == 0:
            sys.stderr.write(" [%d]\n" % count)
    ofh.close()
    sys.stderr.write("\n")
    LOG.info("Extracted %d ngrams" % ngrams)
    return tags