Welcome to mirror list, hosted at ThFree Co, Russian Federation.

translate.lua - github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 34c7794d99e8fdefadcbcdae1314621ab0cc0ace (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
require('onmt.init')
local tokenizer = require 'tools.utils.tokenizer'
local BPE = require ('tools.utils.BPE')

-- Command-line definition for translate.lua. This section only declares
-- options on `cmd`; they are parsed later, in main(), via cmd:parse(arg).
local cmd = onmt.utils.ExtendedCmdLine.new('translate.lua')

-- Data options, registered below under the 'Data' heading.
-- Each entry is { name, default, help text [, extra settings such as a
-- `valid` check] }.
local options = {
  {
    '-src', '',
    [[Source sequences to translate.]],
    {
      valid = onmt.utils.ExtendedCmdLine.fileExists
    }
  },
  {
    '-tgt', '',
    [[Optional true target sequences.]]
  },
  {
    '-output', 'pred.txt',
    [[Output file.]]
  },
  {
    '-save_attention', '',
    [[Optional attention output file.]]
  },
  {
    '-batch_size', 30,
    [[Batch size.]],
    {
      valid = onmt.utils.ExtendedCmdLine.isInt(1)
    }
  },
  {
    '-idx_files', false,
    [[If set, source and target files are 'key value' with key match between source and target.]]
  },
  {
    '-detokenize_output', false,
    [[Detokenize output.]]
  }
}

cmd:setCmdLineOptions(options, 'Data')

onmt.translate.Translator.declareOpts(cmd)
tokenizer.declareOpts(cmd)
-- Add, on the fly, any extra options required by the hook selected on the
-- command line. This inspects the raw `arg` table, before main() parses it.
onmt.utils.HookManager.updateOpt(arg, cmd)

-- Expand options per side of the bitext (tokenization and "mpreprocess" hook
-- options). NOTE(review): main() reads the result back as tok_src_*/tok_tgt_*
-- and mpr_src_*/mpr_tgt_* keys; confirm against Translator.expandOpts.
onmt.translate.Translator.expandOpts(cmd, "bitext")

onmt.utils.Cuda.declareOpts(cmd)
onmt.utils.HookManager.declareOpts(cmd)

onmt.utils.Logger.declareOpts(cmd)

-- Help-text section header for the options declared after it.
cmd:text('')
cmd:text('Other options')
cmd:text('')

-- When set, main() accumulates translation time with a torch.Timer and logs
-- the per-sentence average at the end.
cmd:option('-time', false, [[Measure average translation time.]])

--- Log the average per-word score and the perplexity of one score stream.
-- Perplexity is computed as exp(-scoreTotal / wordsTotal). When no words were
-- scored (for example, an empty input file), both figures are undefined, so
-- "N/A" is logged instead of dividing by zero and printing nan/inf.
-- @tparam string name label prefixed to each figure (e.g. 'PRED', 'GOLD')
-- @tparam number scoreTotal accumulated score of all scored hypotheses
-- @tparam number wordsTotal number of words those scores cover
local function reportScore(name, scoreTotal, wordsTotal)
  if wordsTotal <= 0 then
    _G.logger:info(name .. " AVG SCORE: N/A, " .. name .. " PPL: N/A (no words scored)")
    return
  end

  local avgScore = scoreTotal / wordsTotal
  _G.logger:info(name .. " AVG SCORE: %.2f, " .. name .. " PPL: %.2f",
                 avgScore,
                 math.exp(-avgScore))
end

--- Collect the options whose names start with `prefix` into per-side tables.
-- `<prefix>tgt_<name>` goes to side 2 as `<name>`. `<prefix>src_<name>` and
-- plain `<prefix><name>` go to side 1 as `<name>`.
-- @tparam table opt parsed command-line options
-- @tparam string prefix option prefix, including the trailing underscore
-- @treturn table `{ sourceOptions, targetOptions }`
local function splitSideOptions(opt, prefix)
  local sides = { {}, {} }
  local prefixLen = #prefix
  for key, value in pairs(opt) do
    if key:sub(1, prefixLen) == prefix then
      local name = key:sub(prefixLen + 1)
      local side = 1
      local marker = name:sub(1, 4)
      if marker == 'tgt_' then
        side = 2
        name = name:sub(5)
      elseif marker == 'src_' then
        name = name:sub(5)
      end
      sides[side][name] = value
    end
  end
  return sides
end

--- Build the BPE model for one side of the bitext.
-- @tparam table opt parsed command-line options
-- @tparam string side 'src' or 'tgt'
-- @return a BPE instance, or nil when `tok_<side>_bpe_model` is ''
local function buildSideBPE(opt, side)
  local prefix = 'tok_' .. side .. '_'
  if opt[prefix .. 'bpe_model'] == '' then
    return nil
  end
  return BPE.new({
    bpe_model = opt[prefix .. 'bpe_model'],
    bpe_EOT_marker = opt[prefix .. 'bpe_EOT_marker'],
    bpe_BOT_marker = opt[prefix .. 'bpe_BOT_marker'],
    joiner_new = opt[prefix .. 'joiner_new'],
    joiner_annotate = opt[prefix .. 'joiner_annotate'],
    bpe_mode = opt[prefix .. 'bpe_mode'],
    bpe_case_insensitive = opt[prefix .. 'bpe_case_insensitive']
  })
end

--- Percentage of `part` over `total`, truncated to one decimal place.
-- Returns 0 when `total` is 0, instead of nan/inf.
local function ratioPercent(part, total)
  if total == 0 then
    return 0
  end
  return math.floor(part * 1000 / total) / 10
end

--- Write the attention record of one hypothesis to `attFile`.
-- The header is `sentId ||| sentence ||| score ||| targetLength` for
-- feature-vector sources. For word sources it is
-- `sentId ||| sentence ||| score ||| source ||| sourceLength targetLength`.
-- One line of space-separated weights per target step follows, then a blank
-- separator line.
local function writeAttention(attFile, translator, sentId, sentence, pred, srcInput)
  local attentions = pred.attention
  local targetLength = #attentions

  if translator:srcFeat() then
    attFile:write(string.format('%d ||| %s ||| %f ||| %d\n',
                                sentId, sentence, pred.score, targetLength))
  else
    local source = translator:buildOutput(srcInput)
    local sourceLength = #srcInput.words
    attFile:write(string.format('%d ||| %s ||| %f ||| %s ||| %d %d\n',
                                sentId, sentence, pred.score, source,
                                sourceLength, targetLength))
  end

  for _, attention in ipairs(attentions) do
    attFile:write(table.concat(torch.totable(attention), ' '))
    attFile:write('\n')
  end

  attFile:write('\n')
end

--- Entry point: translate `opt.src` batch by batch.
-- Each source line (and, with -tgt, the matching gold line) goes through the
-- "mpreprocess" hook and the on-the-fly tokenizer. Lines are grouped into
-- batches of `opt.batch_size`, translated, and the hypotheses are written to
-- `opt.output` (optionally detokenized). Attention goes to
-- `opt.save_attention` when set. Scores and unknown-word statistics are
-- logged at the end.
local function main()
  local opt = cmd:parse(arg)

  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level, opt.log_tag)
  _G.profiler = onmt.utils.Profiler.new()
  _G.hookManager = onmt.utils.HookManager.new(opt)

  onmt.utils.Cuda.init(opt)

  local translator = onmt.translate.Translator.new(opt)

  local srcReader = onmt.utils.FileReader.new(opt.src, opt.idx_files, translator:srcFeat())
  local srcBatch = {}
  local srcIdBatch = {}

  -- Tokenization ('tok_*') and preprocessing ('mpr_*') options, split per
  -- side: index 1 holds the source options, index 2 the target options.
  local optTok = splitSideOptions(opt, 'tok_')
  local optMPr = splitSideOptions(opt, 'mpr_')
  local bpes = {}
  bpes[1] = buildSideBPE(opt, 'src')
  bpes[2] = buildSideBPE(opt, 'tgt')

  for i = 1, 2 do
    _G.logger:info("Using on-the-fly '"..optTok[i]["mode"].."' tokenization for input "..i)
  end

  -- Feature-vector sources are not text, so they are never tokenized.
  if translator:srcFeat() then
    optTok[1] = nil
  end

  local goldReader
  local goldBatch

  local withGoldScore = opt.tgt:len() > 0
  local withAttention = opt.save_attention:len() > 0

  if withGoldScore then
    goldReader = onmt.utils.FileReader.new(opt.tgt, opt.idx_files)
    goldBatch = {}
  end

  local outFile = onmt.utils.Error.assert(io.open(opt.output, 'w'))
  local attFile
  if withAttention then
    attFile = onmt.utils.Error.assert(io.open(opt.save_attention, 'w'))
  end

  local sentId = 1

  local predScoreTotal = 0
  local predWordsTotal = 0
  local goldScoreTotal = 0
  local goldWordsTotal = 0

  local globalUnkCountSrc = 0
  local globalTotalCountSrc = 0
  local globalUnkCountTgt = 0
  local globalTotalCountTgt = 0

  local timer
  if opt.time then
    timer = torch.Timer()
    timer:stop()
    timer:reset()
  end

  while true do
    local srcSeq, srcSeqId = srcReader:next(false)

    -- The gold reader advances one line per source line read.
    local goldOutputSeq
    if withGoldScore then
      goldOutputSeq = goldReader:next(false)
      if goldOutputSeq then
        goldOutputSeq = _G.hookManager:call("mpreprocess", optMPr[2], goldOutputSeq) or goldOutputSeq
        goldOutputSeq = tokenizer.tokenize(optTok[2], goldOutputSeq, bpes[2])
      end
    end

    if srcSeq then
      if srcSeq:len() > 0 then
        srcSeq = _G.hookManager:call("mpreprocess", optMPr[1], srcSeq) or srcSeq
        if optTok[1] then
          srcSeq = tokenizer.tokenize(optTok[1], srcSeq, bpes[1])
        end
      else
        srcSeq = {}
      end
      table.insert(srcBatch, translator:buildInput(srcSeq))
      table.insert(srcIdBatch, srcSeqId)

      if withGoldScore then
        table.insert(goldBatch, translator:buildInputGold(goldOutputSeq))
      end
    elseif #srcBatch == 0 then
      break
    end

    -- Translate when the batch is full, or flush the partial batch at EOF.
    if srcSeq == nil or #srcBatch == opt.batch_size then
      if opt.time then
        timer:resume()
      end

      local results, unkCountSrc, totalCountSrc = translator:translate(srcBatch, goldBatch)

      globalUnkCountSrc = globalUnkCountSrc + unkCountSrc
      globalTotalCountSrc = globalTotalCountSrc + totalCountSrc

      if opt.time then
        timer:stop()
      end

      for b = 1, #results do
        if (srcBatch[b].words and #srcBatch[b].words == 0
            or srcBatch[b].vectors and srcBatch[b].vectors:dim() == 0) then
          _G.logger:warning('Line ' .. sentId .. ' is empty.')
          outFile:write('\n')
        else
          if srcBatch[b].words then
            _G.logger:info('SENT %d: %s', sentId, translator:buildOutput(srcBatch[b]))
          else
            _G.logger:info('FEATS %d: IDX - %s - SIZE %d', sentId, srcIdBatch[b], srcBatch[b].vectors:size(1))
          end

          if withGoldScore then
            _G.logger:info('GOLD %d: %s', sentId, translator:buildOutput(goldBatch[b]))
            _G.logger:info("GOLD SCORE: %.2f", results[b].goldScore)
            goldScoreTotal = goldScoreTotal + results[b].goldScore
            goldWordsTotal = goldWordsTotal + #goldBatch[b].words
          end

          if opt.dump_input_encoding then
            outFile:write(sentId, ' ', table.concat(torch.totable(results[b]), " "), '\n')
          else
            for n = 1, #results[b].preds do
              local pred = results[b].preds[n]

              -- count target unknown words and words generated on 1-best
              if n == 1 then
                globalTotalCountTgt = globalTotalCountTgt + #pred.words
                for _, w in ipairs(pred.words) do
                  globalUnkCountTgt = globalUnkCountTgt + (w==onmt.Constants.UNK_WORD and 1 or 0)
                end
              end

              local sentence
              if opt.detokenize_output then
                sentence = tokenizer.detokenize(optTok[2],
                                                pred.words,
                                                pred.features)
              else
                sentence = translator:buildOutput(pred)
              end
              outFile:write(sentence .. '\n')

              if withAttention then
                writeAttention(attFile, translator, sentId, sentence, pred, srcBatch[b])
              end

              if n == 1 then
                predScoreTotal = predScoreTotal + pred.score
                predWordsTotal = predWordsTotal + #pred.words

                if #results[b].preds > 1 then
                  _G.logger:info('')
                  _G.logger:info('BEST HYP:')
                end
              end

              if #results[b].preds > 1 then
                _G.logger:info("[%.2f] %s", pred.score, sentence)
              else
                _G.logger:info("PRED %d: %s", sentId, sentence)
                _G.logger:info("PRED SCORE: %.2f", pred.score)
              end
            end
          end
        end
        _G.logger:info('')
        sentId = sentId + 1
      end

      if srcSeq == nil then
        break
      end

      srcBatch = {}
      srcIdBatch = {}
      if withGoldScore then
        goldBatch = {}
      end
      collectgarbage()
    end
  end

  -- NOTE(review): the "coverage" figures below are unknown-word percentages.
  _G.logger:info("Translated "..globalTotalCountSrc.." words, src unk count: "..globalUnkCountSrc..", coverage: "..
                 ratioPercent(globalUnkCountSrc, globalTotalCountSrc).."%, "..
                 "tgt words: "..globalTotalCountTgt.." words, tgt unk count: "..globalUnkCountTgt..", coverage: "..
                 ratioPercent(globalUnkCountTgt, globalTotalCountTgt).."%, ")

  if opt.time then
    local time = timer:time()
    local sentenceCount = sentId-1
    _G.logger:info("Average sentence translation time (in seconds):\n")
    _G.logger:info("avg real\t" .. time.real / sentenceCount .. "\n")
    _G.logger:info("avg user\t" .. time.user / sentenceCount .. "\n")
    _G.logger:info("avg sys\t" .. time.sys / sentenceCount .. "\n")
  end

  if opt.dump_input_encoding == false then
    reportScore('PRED', predScoreTotal, predWordsTotal)

    if withGoldScore then
      reportScore('GOLD', goldScoreTotal, goldWordsTotal)
    end
  end

  if opt.save_beam_to:len() > 0 then
    translator:saveBeamHistories(opt.save_beam_to)
  end

  outFile:close()
  if attFile then
    attFile:close()
  end
  _G.logger:shutDown()
end

main()