apply_embeddings.lua « tools - github.com/OpenNMT/OpenNMT.git
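
--[[
  apply_embeddings.lua

  For every sentence in a tokenized source file, looks up each token's vector
  in an embedding tensor produced by embeddings.lua and writes the vectors to
  <save_prefix>.src as an "IDX<n> [ ... ]" block (one line of values per
  token). If an aligned target file is given, its sentences are copied to
  <save_prefix>.tgt under the same IDX<n> keys.
]]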
require('torch')
require('onmt.init')

local cmd = onmt.utils.ExtendedCmdLine.new('apply_embeddings.lua')

cmd:setCmdLineOptions(
  {
    {
      '-txt_src', '',
      [[Tokenized source text file to apply the embeddings to.]],
      {
        valid = onmt.utils.ExtendedCmdLine.fileExists
      }
    },
    {
      '-txt_tgt', '',
      [[Optional aligned target file; copied to <save_prefix>.tgt alongside the source embeddings.]],
      {
        valid = onmt.utils.ExtendedCmdLine.fileNullOrExists
      }
    },
    {
      '-dict', '',
      [[Dictionary file matching the embedding data.]],
      {
        valid = onmt.utils.ExtendedCmdLine.fileExists
      }
    },
    {
      '-embed_data', '',
      [[Embedding data generated with embeddings.lua; rows must correspond to the dictionary entries.]],
      {
        valid = onmt.utils.ExtendedCmdLine.fileExists
      }
    },
    {
      '-save_prefix', '',
      [[Output file prefix; <prefix>.src (and <prefix>.tgt if -txt_tgt is set) will be written.]],
      {
        valid = onmt.utils.ExtendedCmdLine.nonEmpty
      }
    }
  }, 'Data')

onmt.utils.Logger.declareOpts(cmd)

local opt = cmd:parse(arg)

local function main()
  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)

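  -- Load the embedding tensor produced by embeddings.lua (one row per dictionary entry).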
  local embeddingWeights = torch.load(opt.embed_data)

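  -- Build the source vocabulary from opt.dict so word indices line up with the embedding rows.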
  local Vocabulary = onmt.data.Vocabulary
  local dict = Vocabulary.init('source',
                               opt.txt_src,
                               opt.dict,
                               { 50000 },
                               {},
                               '',
                               function() return true end,
                               false,
                               false)

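  -- Sanity check: the dictionary and the embedding tensor must have the same number of entries.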
  assert(dict.words:size(1) == embeddingWeights:size(1))
  local fsrc = io.open(opt.save_prefix .. ".src", "w")
  local ftgt
  if opt.txt_tgt ~= '' then
    ftgt = io.open(opt.save_prefix .. ".tgt", "w")
  end

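  -- Lookup layer initialized with the pretrained embedding weights.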
  local wordEmbedding = onmt.WordEmbedding.new(dict.words:size(1),
                                               embeddingWeights:size(2),
                                               embeddingWeights)

  local readerSrc = onmt.utils.FileReader.new(opt.txt_src)

  local readerTgt
  if opt.txt_tgt ~= '' then
    readerTgt = onmt.utils.FileReader.new(opt.txt_tgt)
  end

  local count = 1

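  -- Stream over the source file, writing one embedding block per sentence.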
  while true do
    local tokensSrc = readerSrc:next()

    if tokensSrc == nil then
      break
    end
    local IDX = 'IDX' .. count
    local words, feats = onmt.utils.Features.extract(tokensSrc)
    local vec = dict.words:convertToIdx(words, onmt.Constants.UNK_WORD)
    assert(#feats == 0, 'word features are not supported')
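    -- Write "IDX<n> [", then one line of embedding values per token; the final line ends with " ]".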
    fsrc:write(IDX .. ' [\n')
    for i = 1, vec:size(1) do
      local we = wordEmbedding:forward(torch.LongTensor(1):fill(vec[i]))[1]
      for j = 1, embeddingWeights:size(2) do
        if j > 1 then
          fsrc:write(" ")
        end
        fsrc:write(string.format("%.4f", we[j]))
      end
      if i == vec:size(1) then
        fsrc:write(" ]")
      end
      fsrc:write("\n")
    end
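    -- Copy the aligned target sentence under the same IDX<n> key.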
    if ftgt then
      local tokensTgt = readerTgt:next()
      ftgt:write(IDX .. ' ' .. table.concat(tokensTgt, ' ') .. '\n')
    end
    count = count + 1
  end

  -- Close the output files.
  fsrc:close()
  if ftgt then
    ftgt:close()
  end
end

main()
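
-- Example invocation (file names are illustrative); run from the OpenNMT root
-- so that require('onmt.init') resolves:
--
--   th tools/apply_embeddings.lua \
--     -txt_src data/src-val.txt \
--     -dict data/demo.src.dict \
--     -embed_data data/src-embeddings.t7 \
--     -save_prefix data/val-embedded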