Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2013-07-18 14:12:02 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2013-07-18 14:12:02 +0400
commit694d539dfcc1ab6d60cfb57d54f9526f608b6b8b (patch)
tree38f0e4ea7fbbf3738db4425ede94bf89dcaab1db
parent894e7976301aab34b125eb9244fc5d0a4040daa8 (diff)
parent879cdc80c3811a632ae0140a4b4e4105c8cfd899 (diff)
Merge pull request #8 from fidlej/topic_data_oblivious
Using data.gradOutput to hold the to-be-summed gradients
-rw-r--r--gmodule.lua103
1 files changed, 46 insertions, 57 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 44e1f39..1633c12 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -4,6 +4,24 @@ local istensor = utils.istensor
local istable = utils.istable
local istorchclass = utils.istorchclass
+local function getTotalGradOutput(node)
+ local module = node.data.module
+ local gradOutput = node.data.gradOutput
+ assert(istable(gradOutput), "expecting gradients to sum")
+ if #gradOutput > 1 then
+ node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new()
+ local gobuff = node.data.gradOutputBuffer
+ gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1])
+ for i=2,#gradOutput do
+ gobuff:add(gradOutput[i])
+ end
+ gradOutput = gobuff
+ else
+ gradOutput = gradOutput[1]
+ end
+ return gradOutput
+end
+
local gModule, parent = torch.class('nn.gModule','nn.Module')
function gModule:__init(inputs,outputs)
@@ -61,10 +79,11 @@ function gModule:runForwardFunction(func,input)
local func_name = func
func = function(module,input) return module[func_name](module,input) end
end
- -- we will assume that the input is either a table of stuff
- -- if not we will put it in a table of stuff
- if torch.typename(input) or type(input) ~= 'table' then
+ -- We see the input as a list of inputs.
+ if #self.innode.data.mapindex <= 1 then
input={input}
+ elseif type(input) ~= "table" then
+ error(string.format("expecting %s inputs", #self.innode.data.mapindex))
end
local function neteval(node)
local function propagate(node,x)
@@ -150,27 +169,21 @@ function gModule:runForwardFunction(func,input)
end
self.output = self.outnode.data.input
- if #self.outnode.children == 1 and self.output == self.outnode.data.input then
+ if #self.outnode.children == 1 then
self.output = self.output[1]
end
return self.output
end
function gModule:updateGradInput(input,gradOutput)
- -- we will assume that the input is either a table of stuff
- -- if not we will put it in a table of stuff
- if torch.typename(gradOutput) or type(gradOutput) ~= 'table' then
+ -- We see the gradOutput as a list of gradOutputs
+ if #self.outnode.children <= 1 then
gradOutput={gradOutput}
+ elseif type(gradOutput) ~= "table" then
+ error(string.format("expecting %s gradOutputs", #self.outnode.children))
end
local outputs = {}
local function neteval(node)
- local function propagate(node,x)
- for i,child in ipairs(node.children) do
- child.data.gradOutput = child.data.gradOutput or {}
- local mapindex = node.data.mapindex[child.data]
- table.insert(child.data.gradOutput,x[mapindex])
- end
- end
if node.data.data then
-- then this is a data node, just propagate into
-- its children
@@ -191,39 +204,24 @@ function gModule:updateGradInput(input,gradOutput)
-- its children
for i,child in ipairs(node.children) do
child.data.gradOutput = child.data.gradOutput or {}
- local go = node.data.gradOutput
- if istable(go) and #go == 1 then
- go = go[1]
- end
+ local go = getTotalGradOutput(node)
if node.data.selectindex then
- child.data.gradOutput[node.data.selectindex] = go
+ assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
+ -- The data.gradOutput holds the to-be-summed gradients.
+ child.data.gradOutput[1] = child.data.gradOutput[1] or {}
+ child.data.gradOutput[1][node.data.selectindex] = go
else
table.insert(child.data.gradOutput,go)
end
end
elseif node.data.module then
local module = node.data.module
- local gradOutput = node.data.gradOutput
+ local gradOutput = getTotalGradOutput(node)
local input = node.data.input
if #input == 1 then
input = input[1]
end
-- updateGradInput through this node
- if istable(gradOutput) and not istable(module.output) then
- if #gradOutput > 1 then
- node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new()
- local gobuff = node.data.gradOutputBuffer
- gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1])
- for i=2,#gradOutput do
- gobuff:add(gradOutput[i])
- end
- gradOutput = gobuff
- else
- gradOutput = gradOutput[1]
- end
- elseif istable(gradOutput) and istable(module.output) and #gradOutput ~= #module.output then
- gradOutput = gradOutput[1]
- end
local gradInput = module:updateGradInput(input,gradOutput)
-- propagate the output to children
for i,child in ipairs(node.children) do
@@ -265,23 +263,23 @@ function gModule:updateGradInput(input,gradOutput)
end
-- now fix the order of gradInput
+ -- The innode is used as the input by all the input nodes.
+ -- The data.gradOutput holds the gradients to-be-summed.
+ -- Here, instead of summing them, we reorder them based on the mapindex.
self.gradInput = self.innode.data.gradOutput
- if not istable(self.gradInput) then
- return self.gradInput
- end
local gi = {}
for i,child in ipairs(self.innode.children) do
local mi = self.innode.data.mapindex[child.data]
table.insert(gi,self.gradInput[mi])
end
- while istable(self.gradInput) and #self.gradInput > 0 do
+ while #self.gradInput > 0 do
table.remove(self.gradInput)
end
for i,v in ipairs(gi) do
table.insert(self.gradInput,v)
end
- if #self.innode.children == 1 and self.gradInput == self.innode.data.gradOutput then
+ if #self.innode.children == 1 then
self.gradInput = self.gradInput[1]
end
@@ -289,10 +287,11 @@ function gModule:updateGradInput(input,gradOutput)
end
function gModule:accGradParameters(input,gradOutput,lr)
- -- we will assume that the input is either a table of stuff
- -- if not we will put it in a table of stuff
- if torch.typename(gradOutput) or type(gradOutput) ~= 'table' then
+ -- We see the gradOutput as a list of gradOutputs
+ if #self.outnode.children <= 1 then
gradOutput={gradOutput}
+ elseif type(gradOutput) ~= "table" then
+ error(string.format("expecting %s gradOutputs", #self.outnode.children))
end
local outputs = {}
local function neteval(node)
@@ -300,25 +299,15 @@ function gModule:accGradParameters(input,gradOutput,lr)
elseif not node.data.module and node.data.gradOutput then
elseif node.data.module then
local module = node.data.module
- local gradOutput = node.data.gradOutput
+ local gradOutput = node.data.gradOutput[1]
+ if #node.data.gradOutput > 1 then
+ gradOutput = node.data.gradOutputBuffer
+ end
local input = node.data.input
if #input == 1 then
input = input[1]
end
-- accGradParameters through this node
- if istable(gradOutput) and not istable(module.output) then
- if #gradOutput > 1 then
- node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new()
- local gobuff = node.data.gradOutputBuffer
- gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1])
- for i=2,#gradOutput do
- gobuff:add(gradOutput[i])
- end
- gradOutput = gobuff
- else
- gradOutput = gradOutput[1]
- end
- end
module:accGradParameters(input,gradOutput,lr)
else
if self.verbose then