-- gmodule.lua - github.com/torch/nngraph

local nesting = require('nngraph.nesting')
local utils = require('nngraph.utils')
local istensor = torch.isTensor
local istable = utils.istable
local istorchclass = utils.istorchclass

local function getTotalGradOutput(node)
   local gradOutput = node.data.gradOutput
   assert(istable(gradOutput), "expecting gradients to sum")
   if #gradOutput > 1 then
      -- Check whether we can bypass the allocation, for the special case
      -- where all gradOutputs but one are zero tensors backed by a
      -- one-element storage. Note that when we cannot bypass the
      -- allocation, this check is performed only once.
      if not node.data.gradOutputBuffer then
         local count = 0
         local idx = 1
         -- Count the gradOutputs that are not one-element zero tensors,
         -- remembering the index of the last such gradOutput
         for i=1,#gradOutput do
            local zero = torch.isTensor(gradOutput[i]) and
                         gradOutput[i]:storage() ~= nil and
                         gradOutput[i]:storage():size() == 1 and
                         gradOutput[i]:storage()[1] == 0
            if not zero then
               idx = i
               count = count + 1
            end
         end
         if count < 2 then
            -- Return the only non-zero one, or the first one
            -- if they are all zero
            return gradOutput[idx]
         end
      end
      node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
      local gobuff = node.data.gradOutputBuffer
      nesting.resizeNestedAs(gobuff, gradOutput[1])
      nesting.copyNested(gobuff, gradOutput[1])
      for i=2,#gradOutput do
         nesting.addNestedTo(gobuff, gradOutput[i])
      end
      gradOutput = gobuff
   else
      gradOutput = gradOutput[1]
   end
   return gradOutput
end
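
--[[ Illustration (hypothetical names): a node whose output fans out to two
consumers receives one gradOutput per consumer during backprop, e.g.

   node.data.gradOutput = {gradFromA, gradFromB}

getTotalGradOutput then returns their sum, accumulated into
node.data.gradOutputBuffer (or, when all but one are one-element zero
tensors, simply the single non-zero gradOutput). ]]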

-- The gModule allows a general acyclic graph of modules.
--
-- Each node of the graph can have multiple inputs.
-- The order of inputs is remembered in node.data.mapindex.
--
-- Each node has exactly one output.
-- The output can also be a table.
-- To route parts of the output table to different modules,
-- use the node:split(nOutputs) function.
-- The split creates subnodes with narrowed outputs.
--
-- Implementation details:
-- The node.data.input holds a list of inputs.
-- If a module expects only one input, node.data.input[1] is used.
--
-- The node.data.gradOutput holds the to-be-summed gradOutputs.
-- Each node has only one output, so we need only one gradOutput.
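--
--[[ A minimal usage sketch (illustrative only; the module choices and sizes
below are assumptions, not part of this file):

   local x = nn.Identity()()
   local h = nn.Tanh()(nn.Linear(20, 10)(x))
   local net = nn.gModule({x}, {h})
   local out = net:forward(torch.randn(20))   -- a 10-dimensional tensor

Routing parts of a table-valued output with split():

   local y = nn.Identity()()
   local a, b = nn.SplitTable(1)(y):split(2)  -- one subnode per table entry
   local net2 = nn.gModule({y}, {nn.CAddTable()({a, b})})
   net2:forward(torch.randn(2, 5))            -- a 5-dimensional tensor
]]
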
local gModule, parent = torch.class('nn.gModule','nn.Container')

function gModule:__init(inputs,outputs)
   parent.__init(self)
   -- the graph is defined backwards: we receive the output modules as inputs here.
   -- We define a dummy output node that connects all output modules
   -- into itself. This node is the output of the forward graph and
   -- the input of the backward graph.
   local node
   local outnode = nngraph.Node({input={}})
   for i = 1, utils.tableMaxN(outputs) do
      node = outputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'outputs', i))
      end
      outnode:add(node, true)
   end
   for i = 1, utils.tableMaxN(inputs) do
      node = inputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'inputs', i))
      end
   end
   -- We also add a dummy input node.
   -- The input node will be split to feed the passed input nodes.
   local innode = nngraph.Node({input={}})
   assert(#inputs > 0, "a gModule requires at least one input")
   if #inputs == 1 then
      inputs[1]:add(innode,true)
   else
      local splits = {innode:split(#inputs)}
      for i = 1, #inputs do
         assert(#inputs[i].children == 0, "an input should have no inputs")
      end
      for i = 1, #inputs do
         inputs[i]:add(splits[i],true)
      end
   end

   -- the backward graph (bg) is for gradients
   -- the forward graph (fg) is for function evaluation
   self.bg = outnode:graph()
   self.fg = self.bg:reverse()

   -- the complete graph is constructed
   -- now regenerate the graphs with the additional nodes

   local roots = self.fg:roots()
   -- if there is more than one root in the forward graph, make sure the
   -- extra roots are parameter nodes
   if #roots > 1 then
      local innodeRoot = nil
      -- first find our innode
      for _, root in ipairs(roots) do
         if root.data == innode.data then
            assert(innodeRoot == nil, 'more than one matching input node found in leaves')
            innodeRoot = root
         else
            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
         end
      end
      assert(innodeRoot ~= nil, 'input node not found among roots')
      self.innode = innodeRoot
   else
      assert(#roots == 1, "expecting only one start")
      self.innode = roots[1]
   end

   assert(self.innode.data == innode.data, "expecting the forward innode")
   self.outnode = outnode
   self.verbose = false
   self.nInputs = #inputs

   -- computation on the graph is done through topsort of forward and backward graphs
   self.forwardnodes = self.fg:topsort()
   self.backwardnodes = self.bg:topsort()

   -- iterate over all nodes: check, tag and add to container
   for i,node in ipairs(self.forwardnodes) do
      -- check for unused inputs or unused split() outputs
      if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
         local nUnused = node.data.nSplitOutputs - #node.children
         local debugLabel = node.data.annotations._debugLabel
         local errStr =
            "%s of split(%s) outputs from the node declared at %s are unused"
         error(string.format(errStr, nUnused, node.data.nSplitOutputs,
                             debugLabel))
      end

      -- Check whether any nodes were defined as taking this node as an input,
      -- but then left dangling and don't connect to the output. If this is
      -- the case, then they won't be present in forwardnodes, so error out.
      for successor, _ in pairs(node.data.reverseMap) do
         local successorIsInGraph = false

         -- Only need to check the part of forwardnodes from i onwards; the
         -- topological sort guarantees a successor cannot be in the first part.
         for j = i+1, #self.forwardnodes do
            -- Compare equality of data tables, as new Node objects have been
            -- created by processes such as the topological sort, but the
            -- underlying .data table is shared.
            if self.forwardnodes[j].data == successor.data then
               successorIsInGraph = true
               break
            end
         end
         local errStr =
            "node declared on %s does not connect to gmodule output"
         assert(successorIsInGraph,
                string.format(errStr, successor.data.annotations._debugLabel))
      end

      -- set data.forwardNodeId for node:label() output
      node.data.forwardNodeId = node.id

      -- add module to container
      if node.data.module then
         self:add(node.data.module)
      end
   end

   self.output = nil
   self.gradInput = nil
   if #self.outnode.children > 1 then
      self.output = self.outnode.data.input
   end
end

function gModule:replace(callback)
    local out = callback(self)
    local revmodules = {}
    for i,m in ipairs(self.modules) do
        revmodules[m] = i
    end
    for i,node in ipairs(self.forwardnodes) do
        if node.data.module then
            local m = node.data.module
            node.data.module = m:replace(callback)
            self.modules[revmodules[m]] = node.data.module
        end
    end
    return out
end
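
--[[ Usage sketch (illustrative; assumes the graph contains nn.Dropout):

   net:replace(function(module)
      if torch.typename(module) == 'nn.Dropout' then
         return nn.Identity()    -- e.g. strip dropout for deployment
      end
      return module
   end)
]]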

function gModule:map(gm, func)
   for i,node in ipairs(self.forwardnodes) do
      local gmnode = gm.forwardnodes[i]
      assert(gmnode, 'trying to map another gModule with a different structure')
      if node.data.module then
         assert(gmnode.data.module, 'trying to map another gModule with a different structure')
         func(node.data.module, gmnode.data.module)
      end
   end
end
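
--[[ Usage sketch (illustrative): make a clone share its parameters with the
original by walking both graphs in lockstep:

   local clone = net:clone()
   net:map(clone, function(m1, m2)
      m2:share(m1, 'weight', 'bias', 'gradWeight', 'gradBias')
   end)
]]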

--[[ Recursively applies type(type_str) to any tensors in the argument. If the
argument is a tensor, type(type_str) is applied; if the argument is an array,
this function recurses into it. ]]
local function recursiveType(param, type_str)
   if torch.type(param) == 'table' then
      for i = 1, #param do
         param[i] = recursiveType(param[i], type_str)
      end
   elseif torch.typename(param) and
      torch.typename(param):find('torch%..+Tensor') then
      param = param:type(type_str)
   end
   return param
end

function gModule:type(type, tensorCache)
   if not type then
      return self._type
   end

   tensorCache = tensorCache or {}

   local function applyTypeToTable(tbl)
      for key, value in pairs(tbl) do
         tbl[key] = recursiveType(value, type)
      end
   end

   -- Convert any stored data in self, and in the in and out nodes
   applyTypeToTable(self)
   if self.innode then applyTypeToTable(self.innode.data) end
   if self.outnode then applyTypeToTable(self.outnode.data) end

   -- Loop through modules and convert data
   for _, m in ipairs(self.modules) do
      m:type(type, tensorCache)
   end

   for i,node in ipairs(self.backwardnodes) do
      if node.data.gradOutputBuffer ~= nil then
         node.data.gradOutputBuffer =
            recursiveType(node.data.gradOutputBuffer, type)
      end
      for k, child in ipairs(node.children) do
         applyTypeToTable(child.data)
      end
   end

   for i,node in ipairs(self.forwardnodes) do
      if node.data.input ~= nil then
         node.data.input = recursiveType(node.data.input, type)
      end
      for k, child in ipairs(node.children) do
         applyTypeToTable(child.data)
      end
   end

   self._type = type
   return self
end
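
--[[ Usage sketch (illustrative): convert the whole graph, including cached
node inputs and gradOutput buffers, to another tensor type:

   net:type('torch.FloatTensor')
   net:float()   -- the equivalent nn.Module shortcut
]]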

function gModule:updateOutput(input)
   return self:runForwardFunction('updateOutput',input)
end

function gModule:clearState()
   local ret = parent.clearState(self)
   for _,node in ipairs(self.backwardnodes) do
      node.data.gradOutput = nil
      node.data.gradOutputBuffer = nil
   end
   for _,node in ipairs(self.forwardnodes) do
      node.data.input = nil
   end
   return ret
end

function gModule:runForwardFunction(func,input)
   if type(func) == "string" then
      local func_name = func
      func = function(module,input) return module[func_name](module,input) end
   end
   -- For backward compatibility, we allow self.nInputs to be missing.
   local nInputs = self.nInputs or #self.innode.children
   -- We see the input as a list of inputs.
   if nInputs <= 1 then
      input={input}
   elseif type(input) ~= "table" then
      error(string.format("expecting table of %s inputs", nInputs))
   end
   local function neteval(node)
      local function propagate(node,x)
         for i,child in ipairs(node.children) do
            child.data.input = child.data.input or {}
            local mapindex = child.data.mapindex[node.data]
            assert(not child.data.input[mapindex], "each input should have one source")
            child.data.input[mapindex] = x
         end
      end
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         local input = node.data.input
         assert(#input == 1, "only the split node should be the input")
         assert(istable(input[1]), "the input for a split should be a table")
         input = input[1][node.data.selectindex]
         propagate(node,input)
      else
         local input = node.data.input

         -- a parameter node has no graph inputs; feed it an empty table
         if input == nil and node.data.module ~= nil then
            input = {}
         end
         if #input == 1 then
            input = input[1]
         end
         -- forward through this node
         -- If no module is present, the node behaves like nn.Identity.
         local output
         if not node.data.module then
            output = input
         else
            output = func(node.data.module,input)
         end
         if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
            error(string.format("split(%s) cannot split %s outputs",
                                node.data.nSplitOutputs, #output))
         end
         -- propagate the output to children
         propagate(node,output)
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end

   local innode = self.innode
   if #input ~= nInputs then
      error(string.format('Got %s inputs instead of %s', #input, nInputs))
   end
   -- first clear the input states
   for _,node in ipairs(self.forwardnodes) do
      local input = node.data.input
      while input and #input>0 do
         table.remove(input)
      end
   end
   -- Set the starting input.
   -- We copy rather than modify the passed input.
   innode.data.input = innode.data.input or {}
   for i, item in ipairs(input) do
      innode.data.input[i] = item
   end

   -- run the forward pass
   for i,node in ipairs(self.forwardnodes) do
      neteval(node)
   end

   self.output = self.outnode.data.input
   if #self.outnode.children == 1 then
      self.output = self.output[1]
   end
   return self.output
end
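
--[[ Usage sketch (illustrative): updateOutput routes through this function,
but any per-module evaluation with a (module, input) signature works too:

   local out = net:runForwardFunction('updateOutput', input)
   -- or, with an explicit function:
   net:runForwardFunction(function(module, input)
      return module:updateOutput(input)
   end, input)
]]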

function gModule:updateGradInput(input,gradOutput)
   local function neteval(node)
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         assert(#node.children == 1, "only the split node should be the input")
         local child = node.children[1]
         local go = getTotalGradOutput(node)
         child.data.gradOutput = child.data.gradOutput or {}
         assert(#child.data.gradOutput <= 1, "the split node should be used only once")
         -- The data.gradOutput holds the to-be-summed gradients.
         child.data.gradOutput[1] = child.data.gradOutput[1] or {}
         assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
         child.data.gradOutput[1][node.data.selectindex] = go
      else
         local gradOutput = getTotalGradOutput(node)
         -- updateGradInput through this node
         -- If no module is present, the node behaves like nn.Identity.
         local gradInput
         if not node.data.module then
            gradInput = gradOutput
         else
            local input = node.data.input
            -- a parameter node has no graph inputs; feed it an empty table
            if input == nil and node.data.module ~= nil then
               input = {}
            end
            if #input == 1 then
               input = input[1]
            end
            local module = node.data.module
            gradInput = module:updateGradInput(input,gradOutput)
         end
         -- propagate the gradInput to children
         for i,child in ipairs(node.children) do
            child.data.gradOutput = child.data.gradOutput or {}
            local mapindex = node.data.mapindex[child.data]
            local gi
            if #node.children == 1 then
               gi = gradInput
            else
               gi = gradInput[mapindex]
            end
            table.insert(child.data.gradOutput,gi)
         end
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   for _,node in ipairs(self.backwardnodes) do
      local gradOutput = node.data.gradOutput
      while gradOutput and #gradOutput >0 do
         table.remove(gradOutput)
      end
   end
   -- Set the starting gradOutput.
   outnode.data.gradOutput = outnode.data.gradOutput or {}
   outnode.data.gradOutput[1] = gradOutput

   for i,node in ipairs(self.backwardnodes) do
      neteval(node)
   end

   assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
   self.gradInput = self.innode.data.gradOutput[1]
   return self.gradInput
end
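
--[[ Illustrative call sequence (gradOut is assumed to match the structure of
the graph output: a tensor for a single output node, a table of tensors
otherwise):

   local out = net:forward(input)
   local gi  = net:updateGradInput(input, gradOut)
]]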

function gModule:accGradParameters(input,gradOutput,lr)
   local function neteval(node)
      if node.data.module then
         local module = node.data.module
         local gradOutput = node.data.gradOutput[1]
         if #node.data.gradOutput > 1 then
            gradOutput = node.data.gradOutputBuffer
         end
         local input = node.data.input
         -- a parameter node has no graph inputs; feed it an empty table
         if input == nil and node.data.module ~= nil then
            input = {}
         end
         if #input == 1 then
            input = input[1]
         end
         -- accGradParameters through this node
         module:accGradParameters(input,gradOutput,lr)
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   for i,node in ipairs(self.backwardnodes) do
      neteval(node)
   end
end
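
--[[ Note (illustrative): training code rarely calls these two steps directly;
the inherited nn.Module:backward invokes updateGradInput followed by
accGradParameters:

   net:zeroGradParameters()
   net:forward(input)
   net:backward(input, gradOut)
   net:updateParameters(learningRate)
]]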

function gModule:read(file)
   local data = file:readObject()
   for k, v in pairs(data) do
      self[k] = v
   end

   -- Initialize the modules table if necessary.
   if not self.modules then
      self.modules = {}
      for _, node in ipairs(self.forwardnodes) do
         if node.data.module then
            table.insert(self.modules, node.data.module)
         end
      end
   end
end
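
--[[ Usage sketch (illustrative): serialization goes through the standard
torch object I/O, which invokes read() on load:

   torch.save('net.t7', net)
   local restored = torch.load('net.t7')
]]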


function gModule:__tostring__()
   return self.name or torch.type(self)
end