Welcome to mirror list, hosted at ThFree Co, Russian Federation.

release_model.lua « tools - github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: fc8e1cc6ac3bba488903a1f38702caba7d4fd457 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
require('onmt.init')

local path = require('pl.path')

-- Command-line parser for this tool.
local cmd = onmt.utils.ExtendedCmdLine.new('release_model.lua')

-- Options specific to this script. Each entry appears to follow the layout
-- { flag, default value, help text [, extra settings] } expected by
-- ExtendedCmdLine (NOTE(review): confirm against ExtendedCmdLine).
local options = {
  {
    '-model', '',
    [[Path to the trained model to release.]],
    {
      valid = onmt.utils.ExtendedCmdLine.fileExists
    }
  },
  {
    '-output_model', '',
    [[Path to the released model. If not set, the `release` suffix will be automatically
      added to the model filename.]]
  },
  {
    '-force', false,
    [[Force output model creation even if the target file exists.]]
  }
}

-- Register the options above under the 'Model' group.
cmd:setCmdLineOptions(options, 'Model')

-- Let the Cuda and Logger utilities add their own options to the same parser.
onmt.utils.Cuda.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)

-- Parsed options, read by main() below.
local opt = cmd:parse(arg)

--[[ Tells whether `object` is a plain table exposing a `modules` field.

  Returns `false` when `torch.type(object)` is not 'table'; otherwise returns
  `object.modules` itself (truthy when the field is set, `nil` when it is not).
]]
local function isModel(object)
  if torch.type(object) ~= 'table' then
    return false
  end
  return object.modules
end

--[[ Prepares a single module for release, in place.

  In order: calls `object:release()` when the module defines it, calls
  `object:float(...)` with the tensor cache, calls `object:clearState()`, then
  visits the module through `object:apply(...)` to clear the 'gradWeight' and
  'gradBias' fields and to remove every field holding a function.

  Parameters:

    * `object` - the module to release.
    * `tensorCache` - optional table forwarded to `object:float()`; a new table
      is used when omitted.
]]
local function releaseModule(object, tensorCache)
  local cache = tensorCache or {}

  if object.release then
    object:release()
  end

  object:float(cache)
  object:clearState()

  -- Applied to each module visited by `object:apply`.
  local function strip(module)
    nn.utils.clear(module, 'gradWeight', 'gradBias')
    -- Assigning nil to an existing key while iterating with pairs is allowed.
    for key, value in pairs(module) do
      if type(value) == 'function' then
        module[key] = nil
      end
    end
  end

  object:apply(strip)
end

--[[ Releases every entry of `model.modules`, in place.

  Entries recognized by `isModel` are processed recursively; all others are
  handed to `releaseModule`. The same tensor cache is passed down to every call.

  Parameters:

    * `model` - a table with a `modules` field.
    * `tensorCache` - optional table shared across calls; a new table is used
      when omitted.
]]
local function releaseModel(model, tensorCache)
  local cache = tensorCache or {}

  for _, child in pairs(model.modules) do
    -- Both candidates are functions (always truthy), so and/or is safe here.
    local handler = isModel(child) and releaseModel or releaseModule
    handler(child, cache)
  end
end

--[[ Runs the release procedure: loads the checkpoint stored at `opt.model`,
  converts every object of `checkpoint.models` with `releaseModel` or
  `releaseModule`, and saves the result to `opt.output_model`.

  Raises an error when the input model does not exist, when the output file
  already exists and `-force` is not set, or when the checkpoint cannot be loaded.
]]
local function main()
  assert(path.exists(opt.model), 'model \'' .. opt.model .. '\' does not exist.')

  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)

  -- Derive the output path from the input path when none was given.
  if opt.output_model:len() == 0 then
    if opt.model:sub(-3) == '.t7' then
      opt.output_model = opt.model:sub(1, -4) -- copy input model without '.t7' extension
    else
      opt.output_model = opt.model
    end
    opt.output_model = opt.output_model .. '_release.t7'
  end

  if not opt.force then
    assert(not path.exists(opt.output_model),
           'output model already exists; use -force to overwrite.')
  end

  onmt.utils.Cuda.init(opt)

  _G.logger:info('Loading model \'' .. opt.model .. '\'...')

  -- Trust the status flag returned by pcall rather than the truthiness of the
  -- error value, and stringify the error since it may not be a string.
  local ok, result = pcall(torch.load, opt.model)
  if not ok then
    error('unable to load the model (' .. tostring(result) .. '). If you are releasing a GPU model, it needs to be loaded on the GPU first (set -gpuid > 0)')
  end
  local checkpoint = result

  _G.logger:info('... done.')

  _G.logger:info('Converting model...')
  -- The `info` entry is dropped from the saved checkpoint.
  checkpoint.info = nil
  for _, object in pairs(checkpoint.models) do
    if isModel(object) then
      releaseModel(object)
    else
      releaseModule(object)
    end
  end
  _G.logger:info('... done.')

  _G.logger:info('Releasing model \'' .. opt.output_model .. '\'...')
  torch.save(opt.output_model, checkpoint)
  _G.logger:info('... done.')

  _G.logger:shutDown()
end

-- Script entry point: run the release procedure when the file is executed.
main()