Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Ivo Danihelka <ivo@danihelka.net> 2013-07-18 17:28:27 +0400
committer: Ivo Danihelka <ivo@danihelka.net> 2013-07-18 21:45:06 +0400
commit: e66979a7a35b1b892d343fe157fc77b67f23aab2 (patch)
tree: 8cd275ed7a36844ab9ba2c00deb5d90d23258b50
parent: b56cfc2836fb8a8547be98a0675ac07dee673ea7 (diff)
Used split on innode.
-rw-r--r-- gmodule.lua 215
-rw-r--r-- node.lua 10
2 files changed, 83 insertions, 142 deletions
diff --git a/gmodule.lua b/gmodule.lua
index fe2d47b..92286d9 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -5,7 +5,6 @@ local istable = utils.istable
local istorchclass = utils.istorchclass
local function getTotalGradOutput(node)
- local module = node.data.module
local gradOutput = node.data.gradOutput
assert(istable(gradOutput), "expecting gradients to sum")
if #gradOutput > 1 then
@@ -34,12 +33,20 @@ function gModule:__init(inputs,outputs)
for i,n in ipairs(outputs) do
outnode:add(n,true)
end
- local innode = nngraph.Node({data={},gradOutput={}})
- for i,n in ipairs(inputs) do
- n:add(innode,true)
- -- fix the mapindex for the input data node
- table.insert(innode.data.mapindex,n.data)
- innode.data.mapindex[n.data] = #innode.data.mapindex
+ -- We add also a dummy input node.
+ -- The input node will be split to feed the passed input nodes.
+ local innode = nngraph.Node({input={}})
+ assert(#inputs > 0, "no inputs are not supported")
+ if #inputs == 1 then
+ inputs[1]:add(innode,true)
+ else
+ local splits = {innode:split(#inputs)}
+ for i = 1, #inputs do
+ assert(#inputs[i].children == 0, "an input should have no inputs")
+ end
+ for i = 1, #inputs do
+ inputs[i]:add(splits[i],true)
+ end
end
-- the backward graph (bg) is for gradients
@@ -49,7 +56,9 @@ function gModule:__init(inputs,outputs)
-- the complete graph is constructed
-- now regenerate the graphs with the additional nodes
+ assert(#self.fg:roots() == 1, "expecting only one start")
self.innode = self.fg:roots()[1]
+ assert(self.innode.data == innode.data, "expecting the forward innode")
self.outnode = outnode
self.verbose = false
@@ -80,70 +89,50 @@ function gModule:runForwardFunction(func,input)
func = function(module,input) return module[func_name](module,input) end
end
-- We see the input as a list of inputs.
- if #self.innode.data.mapindex <= 1 then
+ if #self.innode.children <= 1 then
input={input}
elseif type(input) ~= "table" then
- error(string.format("expecting %s inputs", #self.innode.data.mapindex))
+ error(string.format("expecting %s inputs", #self.innode.children))
end
local function neteval(node)
local function propagate(node,x)
for i,child in ipairs(node.children) do
child.data.input = child.data.input or {}
local mapindex = child.data.mapindex[node.data]
+ assert(not child.data.input[mapindex], "each input should have one source")
child.data.input[mapindex] = x
end
end
- if node.data.data then
- -- then this is a data node, just propagate into
- -- its children
- -- this is different from a regular data node
- -- the input is expected to be a table of things
- -- where each thing goes into the input of
- -- corresponding children. So this is like a
- -- dispatcher
- -- the mapindex in a data node indexes the child data
- -- so that this node can distribute its data to corresponding inputs
- for i,child in ipairs(node.children) do
- local mapindex = node.data.mapindex[child.data]
- if child.data.input then
- table.insert(child.data.input,node.data.data[mapindex])
- else
- child.data.input = {node.data.data[mapindex]}
- end
- end
- elseif not node.data.module and node.data.input then
- -- then this is a data node, just propagate into
- -- its children
- local input = #node.data.input == 1 and node.data.input[1] or node.data.input
- if node.data.selectindex then
- input = input[node.data.selectindex]
- end
+ if node.data.selectindex then
+ assert(not node.data.module, "the selectindex-handling nodes should have no module")
+ local input = node.data.input
+ assert(#input == 1, "only the splitted node should be the input")
+ input = input[1][node.data.selectindex]
propagate(node,input)
- elseif node.data.module then
- local module = node.data.module
+ else
local input = node.data.input
if #input == 1 then
input = input[1]
end
-- forward through this node
- local output = func(module,input)
+ -- If no module is present, the node behaves like nn.Identity.
+ local output
+ if not node.data.module then
+ output = input
+ else
+ output = func(node.data.module,input)
+ end
-- propagate the output to children
propagate(node,output)
- else
- error('weird node: ' .. node.data)
end
if self.verbose then
print(' V : ' .. node:label())
end
end
- -- set the data field to current input
local innode = self.innode
- innode.data.data=input
- if #input ~= #innode.data.mapindex then
- print('#inputs =' .. #input)
- print('#mapindices =' .. #innode.data.mapindex)
- error('Number of inputs do not match my graph')
+ if #input ~= #innode.children then
+ error(string.format('Got %s inputs instead of %s', #input, #innode.children))
end
-- first clear the input states
innode:bfs(function(node)
@@ -152,6 +141,12 @@ function gModule:runForwardFunction(func,input)
table.remove(input)
end
end)
+ -- Set the starting input.
+ -- We do copy instead of modifying the passed input.
+ innode.data.input = innode.data.input or {}
+ for i, item in ipairs(input) do
+ innode.data.input[i] = item
+ end
-- the run forward
for i,node in ipairs(self.forwardnodes) do
@@ -166,78 +161,53 @@ function gModule:runForwardFunction(func,input)
end
function gModule:updateGradInput(input,gradOutput)
- -- We see the gradOutput as a list of gradOutputs
- if #self.outnode.children <= 1 then
- gradOutput={gradOutput}
- elseif type(gradOutput) ~= "table" then
- error(string.format("expecting %s gradOutputs", #self.outnode.children))
- end
- local outputs = {}
local function neteval(node)
- if node.data.data then
- -- then this is a data node, just propagate into
- -- its children
- -- this is different from a regular data node
- -- the input is expected to be a table of things
- -- where each thing goes into the input of
- -- corresponding children. So this is like a
- -- dispatcher
- -- First we need to fix the order of stuff in our
- -- gradOutput table.
- for i,child in ipairs(node.children) do
- child.data.gradOutput = child.data.gradOutput or {}
- local mapindex = node.data.mapindex[child.data]
- table.insert(child.data.gradOutput,node.data.data[mapindex])
- end
- elseif not node.data.module and node.data.gradOutput then
- -- then this is a data node, just propagate into
- -- its children
- for i,child in ipairs(node.children) do
- child.data.gradOutput = child.data.gradOutput or {}
- local go = getTotalGradOutput(node)
- if node.data.selectindex then
- assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
- -- The data.gradOutput holds the to-be-summed gradients.
- child.data.gradOutput[1] = child.data.gradOutput[1] or {}
- child.data.gradOutput[1][node.data.selectindex] = go
- else
- table.insert(child.data.gradOutput,go)
- end
- end
- elseif node.data.module then
- local module = node.data.module
+ if node.data.selectindex then
+ assert(not node.data.module, "the selectindex-handling nodes should have no module")
+ assert(#node.children == 1, "only the splitted node should be the input")
+ local child = node.children[1]
+ local go = getTotalGradOutput(node)
+ child.data.gradOutput = child.data.gradOutput or {}
+ assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
+ -- The data.gradOutput holds the to-be-summed gradients.
+ child.data.gradOutput[1] = child.data.gradOutput[1] or {}
+ assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
+ child.data.gradOutput[1][node.data.selectindex] = go
+ else
local gradOutput = getTotalGradOutput(node)
- local input = node.data.input
- if #input == 1 then
- input = input[1]
- end
-- updateGradInput through this node
- local gradInput = module:updateGradInput(input,gradOutput)
+ -- If no module is present, the node behaves like nn.Identity.
+ local gradInput
+ if not node.data.module then
+ gradInput = gradOutput
+ else
+ local input = node.data.input
+ if #input == 1 then
+ input = input[1]
+ end
+ local module = node.data.module
+ gradInput = module:updateGradInput(input,gradOutput)
+ end
-- propagate the output to children
for i,child in ipairs(node.children) do
child.data.gradOutput = child.data.gradOutput or {}
local mapindex = node.data.mapindex[child.data]
local gi
- if #node.children ~= 1 then --istable(gradInput) and istable(input) then
- gi = gradInput[mapindex]
- else
+ if #node.children == 1 then
gi = gradInput
+ else
+ gi = gradInput[mapindex]
end
table.insert(child.data.gradOutput,gi)
end
- else
- error('weird node: ' .. node.data)
end
if self.verbose then
print(' V : ' .. node:label())
end
end
local outnode = self.outnode
- outnode.data.data=gradOutput
- if #gradOutput ~= #outnode.children then
- print('#outputs =' .. #outnode.children)
- print('#gradients =' .. #gradOutput)
- error('Number of gradients do not match my graph')
+ if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+ error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
end
outnode:bfs(function(node)
local gradOutput = node.data.gradOutput
@@ -245,46 +215,22 @@ function gModule:updateGradInput(input,gradOutput)
table.remove(gradOutput)
end
end)
+ -- Set the starting gradOutput.
+ outnode.data.gradOutput = outnode.data.gradOutput or {}
+ outnode.data.gradOutput[1] = gradOutput
+
for i,node in ipairs(self.backwardnodes) do
neteval(node)
end
- -- now fix the order of gradInput
- -- The innode is used as the input by all the input nodes.
- -- The data.gradOutput holds the gradients to-be-summed.
- -- Here, instead of summing them, we reorder them based on the mapindex.
- self.gradInput = self.innode.data.gradOutput
- local gi = {}
- for i,child in ipairs(self.innode.children) do
- local mi = self.innode.data.mapindex[child.data]
- table.insert(gi,self.gradInput[mi])
- end
- while #self.gradInput > 0 do
- table.remove(self.gradInput)
- end
- for i,v in ipairs(gi) do
- table.insert(self.gradInput,v)
- end
-
- if #self.innode.children == 1 then
- self.gradInput = self.gradInput[1]
- end
-
+ assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
+ self.gradInput = self.innode.data.gradOutput[1]
return self.gradInput
end
function gModule:accGradParameters(input,gradOutput,lr)
- -- We see the gradOutput as a list of gradOutputs
- if #self.outnode.children <= 1 then
- gradOutput={gradOutput}
- elseif type(gradOutput) ~= "table" then
- error(string.format("expecting %s gradOutputs", #self.outnode.children))
- end
- local outputs = {}
local function neteval(node)
- if node.data.data then
- elseif not node.data.module and node.data.gradOutput then
- elseif node.data.module then
+ if node.data.module then
local module = node.data.module
local gradOutput = node.data.gradOutput[1]
if #node.data.gradOutput > 1 then
@@ -296,19 +242,14 @@ function gModule:accGradParameters(input,gradOutput,lr)
end
-- accGradParameters through this node
module:accGradParameters(input,gradOutput,lr)
- else
- error('weird node: ' .. node.data)
end
if self.verbose then
print(' V : ' .. node:label())
end
end
local outnode = self.outnode
- outnode.data.data=gradOutput
- if #gradOutput ~= #outnode.children then
- print('#outputs =' .. #outnode.children)
- print('#gradients =' .. #gradOutput)
- error('Number of gradients do not match my graph')
+ if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+ error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
end
for i,node in ipairs(self.backwardnodes) do
neteval(node)
diff --git a/node.lua b/node.lua
index f130b12..73dbc23 100644
--- a/node.lua
+++ b/node.lua
@@ -21,10 +21,9 @@ function nnNode:add(child,domap)
if domap then
local mapindex = self.data.mapindex
local data = child.data
- if not mapindex[data] then
- table.insert(mapindex,data)
- mapindex[data] = #mapindex
- end
+ assert(not mapindex[data], "Don't pass the same input twice.")
+ table.insert(mapindex,data)
+ mapindex[data] = #mapindex
end
end
@@ -32,6 +31,7 @@ end
-- that each take a single component of the output of this
-- node in the order they are returned.
function nnNode:split(noutput)
+ assert(noutput >= 2, "splitting to one output is not supported")
local mnode = self
local selectnodes = {}
for i=1,noutput do
@@ -79,7 +79,7 @@ function nnNode:label()
end
for k,v in pairs(self.data) do
- vstr = ''
+ local vstr = ''
if k=='mapindex' then
vstr = getmapindexstr(v)
else