Welcome to mirror list, hosted at ThFree Co, Russian Federation.

preprocess.lua - github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 38b2602fe269265b7cb4fe8687aee1a3c0e0e72b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
require('onmt.init')

-- Command-line parser for this script; every option below is registered on it.
local cmd = onmt.utils.ExtendedCmdLine.new('preprocess.lua')

-- The data type is read straight from the raw argument list, before the full
-- parse, because it is passed to the Preprocessor option declaration and
-- expansion calls below. Accepted values (see the enum in `options`):
-- bitext/monotext/feattext - default is bitext.
local dataType = cmd.getArgument(arg, '-data_type') or 'bitext'

-- Options declaration.
-- Each entry is a list of: option flag, default value, help text and an
-- optional table of constraints (`enum`, `depends`, `valid`).
-- NOTE(review): the exact contract of `depends`/`valid` is defined by
-- ExtendedCmdLine, not in this file - confirm there.
local options = {
  {
    '-data_type', 'bitext',
    [[Type of data to preprocess. Use 'monotext' for monolingual data.
      This option impacts all options choices.]],
    {
      enum = {'bitext', 'monotext', 'feattext'},
      -- Holds unless the data type is 'feattext' while `opt.idx_files` is
      -- unset (nil/false).
      depends = function(opt) return opt.data_type ~= 'feattext' or opt.idx_files end
    }
  },
  {
    '-dry_run', false,
    [[If set, this will only prepare the preprocessor. Useful when using file sampling to
      test distribution rules.]]
  },
  {
    '-save_data', '',
    [[Output file for the prepared data.]],
    {
      -- Holds when `-dry_run` is set or `-save_data` is non-empty; the second
      -- returned value is the accompanying message string.
      depends = function(opt)
        return opt.dry_run or opt.save_data ~= '', "option `-save_data` is required"
      end
    }
  }
}

cmd:setCmdLineOptions(options, 'Preprocess')

-- Register the preprocessor options for the selected data type.
onmt.data.Preprocessor.declareOpts(cmd, dataType)
-- insert on the fly the option depending if there is a hook selected
onmt.utils.HookManager.updateOpt(arg, cmd)

-- expand options depending on source or target (tokenization, mpreprocessing)
onmt.data.Preprocessor.expandOpts(cmd, dataType)

-- Register the hook manager and logger options.
onmt.utils.HookManager.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)

-- Options registered under the 'Other' group.
local otherOptions = {
  {
    '-seed', 3425,
    [[Random seed.]],
    {
      valid = onmt.utils.ExtendedCmdLine.isUInt()
    }
  }
}
cmd:setCmdLineOptions(otherOptions, 'Other')

-- Parsed options; read by `main` below.
local opt = cmd:parse(arg)

-- Script entry point.
-- Seeds torch with `opt.seed`, installs the global logger and hook manager and
-- creates the preprocessor. Unless `opt.dry_run` is set, it then builds the
-- vocabularies and the train/valid sets, writes each vocabulary whose
-- corresponding path option is empty, and saves the resulting table to
-- `<save_data>-train.t7`. The logger is shut down in both cases.
local function main()
  torch.manualSeed(opt.seed)

  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)
  _G.hookManager = onmt.utils.HookManager.new(opt)

  local preprocessor = onmt.data.Preprocessor.new(opt, dataType)

  if not opt.dry_run then
    -- True when the given option string is empty.
    local function isBlank(value)
      return value:len() == 0
    end

    local bundle = { dataType=dataType }

    -- the parsed options are stored alongside the data for traceability
    bundle.opt = opt

    _G.logger:info('Preparing vocabulary...')
    bundle.dicts = preprocessor:getVocabulary()

    _G.logger:info('Preparing training data...')
    bundle.train = preprocessor:makeData('train', bundle.dicts)
    _G.logger:info('')

    _G.logger:info('Preparing validation data...')
    bundle.valid = preprocessor:makeData('valid', bundle.dicts)
    _G.logger:info('')

    local dicts = bundle.dicts
    local prefix = opt.save_data

    if dataType == 'monotext' then
      -- single side: word and feature vocabularies of the source
      if isBlank(opt.vocab) then
        onmt.data.Vocabulary.save('source', dicts.src.words, prefix .. '.dict')
      end
      if isBlank(opt.features_vocabs_prefix) then
        onmt.data.Vocabulary.saveFeatures('source', dicts.src.features, prefix)
      end
    elseif dataType == 'feattext' then
      -- only the target side vocabularies are written here
      if isBlank(opt.tgt_vocab) then
        onmt.data.Vocabulary.save('target', dicts.tgt.words, prefix .. '.tgt.dict')
      end
      if isBlank(opt.features_vocabs_prefix) then
        onmt.data.Vocabulary.saveFeatures('target', dicts.tgt.features, prefix)
      end
    else
      -- any other data type: both sides, with distinct file prefixes
      if isBlank(opt.src_vocab) then
        onmt.data.Vocabulary.save('source', dicts.src.words, prefix .. '.src.dict')
      end
      if isBlank(opt.tgt_vocab) then
        onmt.data.Vocabulary.save('target', dicts.tgt.words, prefix .. '.tgt.dict')
      end
      if isBlank(opt.features_vocabs_prefix) then
        onmt.data.Vocabulary.saveFeatures('source', dicts.src.features, prefix..'.source')
        onmt.data.Vocabulary.saveFeatures('target', dicts.tgt.features, prefix..'.target')
      end
    end

    _G.logger:info('Saving data to \'' .. prefix .. '-train.t7\'...')
    torch.save(prefix .. '-train.t7', bundle, 'binary', false)
  end

  _G.logger:shutDown()
end

main()