diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2013-07-18 17:28:27 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2013-07-18 21:45:06 +0400 |
commit | e66979a7a35b1b892d343fe157fc77b67f23aab2 (patch) | |
tree | 8cd275ed7a36844ab9ba2c00deb5d90d23258b50 | |
parent | b56cfc2836fb8a8547be98a0675ac07dee673ea7 (diff) |
Used split on innode.
-rw-r--r-- | gmodule.lua | 215 | ||||
-rw-r--r-- | node.lua | 10 |
2 files changed, 83 insertions, 142 deletions
diff --git a/gmodule.lua b/gmodule.lua index fe2d47b..92286d9 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -5,7 +5,6 @@ local istable = utils.istable local istorchclass = utils.istorchclass local function getTotalGradOutput(node) - local module = node.data.module local gradOutput = node.data.gradOutput assert(istable(gradOutput), "expecting gradients to sum") if #gradOutput > 1 then @@ -34,12 +33,20 @@ function gModule:__init(inputs,outputs) for i,n in ipairs(outputs) do outnode:add(n,true) end - local innode = nngraph.Node({data={},gradOutput={}}) - for i,n in ipairs(inputs) do - n:add(innode,true) - -- fix the mapindex for the input data node - table.insert(innode.data.mapindex,n.data) - innode.data.mapindex[n.data] = #innode.data.mapindex + -- We add also a dummy input node. + -- The input node will be split to feed the passed input nodes. + local innode = nngraph.Node({input={}}) + assert(#inputs > 0, "no inputs are not supported") + if #inputs == 1 then + inputs[1]:add(innode,true) + else + local splits = {innode:split(#inputs)} + for i = 1, #inputs do + assert(#inputs[i].children == 0, "an input should have no inputs") + end + for i = 1, #inputs do + inputs[i]:add(splits[i],true) + end end -- the backward graph (bg) is for gradients @@ -49,7 +56,9 @@ function gModule:__init(inputs,outputs) -- the complete graph is constructed -- now regenerate the graphs with the additional nodes + assert(#self.fg:roots() == 1, "expecting only one start") self.innode = self.fg:roots()[1] + assert(self.innode.data == innode.data, "expecting the forward innode") self.outnode = outnode self.verbose = false @@ -80,70 +89,50 @@ function gModule:runForwardFunction(func,input) func = function(module,input) return module[func_name](module,input) end end -- We see the input as a list of inputs. - if #self.innode.data.mapindex <= 1 then + if #self.innode.children <= 1 then input={input} elseif type(input) ~= "table" then - error(string.format("expecting %s inputs", #self.innode.data.mapindex)) + error(string.format("expecting %s inputs", #self.innode.children)) end local function neteval(node) local function propagate(node,x) for i,child in ipairs(node.children) do child.data.input = child.data.input or {} local mapindex = child.data.mapindex[node.data] + assert(not child.data.input[mapindex], "each input should have one source") child.data.input[mapindex] = x end end - if node.data.data then - -- then this is a data node, just propagate into - -- its children - -- this is different from a regular data node - -- the input is expected to be a table of things - -- where each thing goes into the input of - -- corresponding children. So this is like a - -- dispatcher - -- the mapindex in a data node indexes the child data - -- so that this node can distribute its data to corresponding inputs - for i,child in ipairs(node.children) do - local mapindex = node.data.mapindex[child.data] - if child.data.input then - table.insert(child.data.input,node.data.data[mapindex]) - else - child.data.input = {node.data.data[mapindex]} - end - end - elseif not node.data.module and node.data.input then - -- then this is a data node, just propagate into - -- its children - local input = #node.data.input == 1 and node.data.input[1] or node.data.input - if node.data.selectindex then - input = input[node.data.selectindex] - end + if node.data.selectindex then + assert(not node.data.module, "the selectindex-handling nodes should have no module") + local input = node.data.input + assert(#input == 1, "only the splitted node should be the input") + input = input[1][node.data.selectindex] propagate(node,input) - elseif node.data.module then - local module = node.data.module + else local input = node.data.input if #input == 1 then input = input[1] end -- forward through this node - local output = func(module,input) + -- If no module is present, the node behaves like nn.Identity. + local output + if not node.data.module then + output = input + else + output = func(node.data.module,input) + end -- propagate the output to children propagate(node,output) - else - error('weird node: ' .. node.data) end if self.verbose then print(' V : ' .. node:label()) end end - -- set the data field to current input local innode = self.innode - innode.data.data=input - if #input ~= #innode.data.mapindex then - print('#inputs =' .. #input) - print('#mapindices =' .. #innode.data.mapindex) - error('Number of inputs do not match my graph') + if #input ~= #innode.children then + error(string.format('Got %s inputs instead of %s', #input, #innode.children)) end -- first clear the input states innode:bfs(function(node) @@ -152,6 +141,12 @@ function gModule:runForwardFunction(func,input) table.remove(input) end end) + -- Set the starting input. + -- We do copy instead of modifying the passed input. + innode.data.input = innode.data.input or {} + for i, item in ipairs(input) do + innode.data.input[i] = item + end -- the run forward for i,node in ipairs(self.forwardnodes) do @@ -166,78 +161,53 @@ function gModule:runForwardFunction(func,input) end function gModule:updateGradInput(input,gradOutput) - -- We see the gradOutput as a list of gradOutputs - if #self.outnode.children <= 1 then - gradOutput={gradOutput} - elseif type(gradOutput) ~= "table" then - error(string.format("expecting %s gradOutputs", #self.outnode.children)) - end - local outputs = {} local function neteval(node) - if node.data.data then - -- then this is a data node, just propagate into - -- its children - -- this is different from a regular data node - -- the input is expected to be a table of things - -- where each thing goes into the input of - -- corresponding children. So this is like a - -- dispatcher - -- First we need to fix the order of stuff in our - -- gradOutput table. - for i,child in ipairs(node.children) do - child.data.gradOutput = child.data.gradOutput or {} - local mapindex = node.data.mapindex[child.data] - table.insert(child.data.gradOutput,node.data.data[mapindex]) - end - elseif not node.data.module and node.data.gradOutput then - -- then this is a data node, just propagate into - -- its children - for i,child in ipairs(node.children) do - child.data.gradOutput = child.data.gradOutput or {} - local go = getTotalGradOutput(node) - if node.data.selectindex then - assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") - -- The data.gradOutput holds the to-be-summed gradients. - child.data.gradOutput[1] = child.data.gradOutput[1] or {} - child.data.gradOutput[1][node.data.selectindex] = go - else - table.insert(child.data.gradOutput,go) - end - end - elseif node.data.module then - local module = node.data.module + if node.data.selectindex then + assert(not node.data.module, "the selectindex-handling nodes should have no module") + assert(#node.children == 1, "only the splitted node should be the input") + local child = node.children[1] + local go = getTotalGradOutput(node) + child.data.gradOutput = child.data.gradOutput or {} + assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") + -- The data.gradOutput holds the to-be-summed gradients. + child.data.gradOutput[1] = child.data.gradOutput[1] or {} + assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet") + child.data.gradOutput[1][node.data.selectindex] = go + else local gradOutput = getTotalGradOutput(node) - local input = node.data.input - if #input == 1 then - input = input[1] - end -- updateGradInput through this node - local gradInput = module:updateGradInput(input,gradOutput) + -- If no module is present, the node behaves like nn.Identity. + local gradInput + if not node.data.module then + gradInput = gradOutput + else + local input = node.data.input + if #input == 1 then + input = input[1] + end + local module = node.data.module + gradInput = module:updateGradInput(input,gradOutput) + end -- propagate the output to children for i,child in ipairs(node.children) do child.data.gradOutput = child.data.gradOutput or {} local mapindex = node.data.mapindex[child.data] local gi - if #node.children ~= 1 then --istable(gradInput) and istable(input) then - gi = gradInput[mapindex] - else + if #node.children == 1 then gi = gradInput + else + gi = gradInput[mapindex] end table.insert(child.data.gradOutput,gi) end - else - error('weird node: ' .. node.data) end if self.verbose then print(' V : ' .. node:label()) end end local outnode = self.outnode - outnode.data.data=gradOutput - if #gradOutput ~= #outnode.children then - print('#outputs =' .. #outnode.children) - print('#gradients =' .. #gradOutput) - error('Number of gradients do not match my graph') + if #outnode.children > 1 and #gradOutput ~= #outnode.children then + error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) end outnode:bfs(function(node) local gradOutput = node.data.gradOutput @@ -245,46 +215,22 @@ function gModule:updateGradInput(input,gradOutput) table.remove(gradOutput) end end) + -- Set the starting gradOutput. + outnode.data.gradOutput = outnode.data.gradOutput or {} + outnode.data.gradOutput[1] = gradOutput + for i,node in ipairs(self.backwardnodes) do neteval(node) end - -- now fix the order of gradInput - -- The innode is used as the input by all the input nodes. - -- The data.gradOutput holds the gradients to-be-summed. - -- Here, instead of summing them, we reorder them based on the mapindex. - self.gradInput = self.innode.data.gradOutput - local gi = {} - for i,child in ipairs(self.innode.children) do - local mi = self.innode.data.mapindex[child.data] - table.insert(gi,self.gradInput[mi]) - end - while #self.gradInput > 0 do - table.remove(self.gradInput) - end - for i,v in ipairs(gi) do - table.insert(self.gradInput,v) - end - - if #self.innode.children == 1 then - self.gradInput = self.gradInput[1] - end - + assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once") + self.gradInput = self.innode.data.gradOutput[1] return self.gradInput end function gModule:accGradParameters(input,gradOutput,lr) - -- We see the gradOutput as a list of gradOutputs - if #self.outnode.children <= 1 then - gradOutput={gradOutput} - elseif type(gradOutput) ~= "table" then - error(string.format("expecting %s gradOutputs", #self.outnode.children)) - end - local outputs = {} local function neteval(node) - if node.data.data then - elseif not node.data.module and node.data.gradOutput then - elseif node.data.module then + if node.data.module then local module = node.data.module local gradOutput = node.data.gradOutput[1] if #node.data.gradOutput > 1 then @@ -296,19 +242,14 @@ function gModule:accGradParameters(input,gradOutput,lr) end -- accGradParameters through this node module:accGradParameters(input,gradOutput,lr) - else - error('weird node: ' .. node.data) end if self.verbose then print(' V : ' .. node:label()) end end local outnode = self.outnode - outnode.data.data=gradOutput - if #gradOutput ~= #outnode.children then - print('#outputs =' .. #outnode.children) - print('#gradients =' .. #gradOutput) - error('Number of gradients do not match my graph') + if #outnode.children > 1 and #gradOutput ~= #outnode.children then + error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) end for i,node in ipairs(self.backwardnodes) do neteval(node) @@ -21,10 +21,9 @@ function nnNode:add(child,domap) if domap then local mapindex = self.data.mapindex local data = child.data - if not mapindex[data] then - table.insert(mapindex,data) - mapindex[data] = #mapindex - end + assert(not mapindex[data], "Don't pass the same input twice.") + table.insert(mapindex,data) + mapindex[data] = #mapindex end end @@ -32,6 +31,7 @@ end -- that each take a single component of the output of this -- node in the order they are returned. function nnNode:split(noutput) + assert(noutput >= 2, "splitting to one output is not supported") local mnode = self local selectnodes = {} for i=1,noutput do @@ -79,7 +79,7 @@ function nnNode:label() end for k,v in pairs(self.data) do - vstr = '' + local vstr = '' if k=='mapindex' then vstr = getmapindexstr(v) else |