diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2013-07-18 14:12:02 +0400 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2013-07-18 14:12:02 +0400 |
commit | 694d539dfcc1ab6d60cfb57d54f9526f608b6b8b (patch) | |
tree | 38f0e4ea7fbbf3738db4425ede94bf89dcaab1db | |
parent | 894e7976301aab34b125eb9244fc5d0a4040daa8 (diff) | |
parent | 879cdc80c3811a632ae0140a4b4e4105c8cfd899 (diff) |
Merge pull request #8 from fidlej/topic_data_oblivious
Using data.gradOutput to hold the to-be-summed gradients
-rw-r--r-- | gmodule.lua | 103 |
1 file changed, 46 insertions, 57 deletions
diff --git a/gmodule.lua b/gmodule.lua index 44e1f39..1633c12 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -4,6 +4,24 @@ local istensor = utils.istensor local istable = utils.istable local istorchclass = utils.istorchclass +local function getTotalGradOutput(node) + local module = node.data.module + local gradOutput = node.data.gradOutput + assert(istable(gradOutput), "expecting gradients to sum") + if #gradOutput > 1 then + node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new() + local gobuff = node.data.gradOutputBuffer + gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1]) + for i=2,#gradOutput do + gobuff:add(gradOutput[i]) + end + gradOutput = gobuff + else + gradOutput = gradOutput[1] + end + return gradOutput +end + local gModule, parent = torch.class('nn.gModule','nn.Module') function gModule:__init(inputs,outputs) @@ -61,10 +79,11 @@ function gModule:runForwardFunction(func,input) local func_name = func func = function(module,input) return module[func_name](module,input) end end - -- we will assume that the input is either a table of stuff - -- if not we will put it in a table of stuff - if torch.typename(input) or type(input) ~= 'table' then + -- We see the input as a list of inputs. 
+ if #self.innode.data.mapindex <= 1 then input={input} + elseif type(input) ~= "table" then + error(string.format("expecting %s inputs", #self.innode.data.mapindex)) end local function neteval(node) local function propagate(node,x) @@ -150,27 +169,21 @@ function gModule:runForwardFunction(func,input) end self.output = self.outnode.data.input - if #self.outnode.children == 1 and self.output == self.outnode.data.input then + if #self.outnode.children == 1 then self.output = self.output[1] end return self.output end function gModule:updateGradInput(input,gradOutput) - -- we will assume that the input is either a table of stuff - -- if not we will put it in a table of stuff - if torch.typename(gradOutput) or type(gradOutput) ~= 'table' then + -- We see the gradOutput as a list of gradOutputs + if #self.outnode.children <= 1 then gradOutput={gradOutput} + elseif type(gradOutput) ~= "table" then + error(string.format("expecting %s gradOutputs", #self.outnode.children)) end local outputs = {} local function neteval(node) - local function propagate(node,x) - for i,child in ipairs(node.children) do - child.data.gradOutput = child.data.gradOutput or {} - local mapindex = node.data.mapindex[child.data] - table.insert(child.data.gradOutput,x[mapindex]) - end - end if node.data.data then -- then this is a data node, just propagate into -- its children @@ -191,39 +204,24 @@ function gModule:updateGradInput(input,gradOutput) -- its children for i,child in ipairs(node.children) do child.data.gradOutput = child.data.gradOutput or {} - local go = node.data.gradOutput - if istable(go) and #go == 1 then - go = go[1] - end + local go = getTotalGradOutput(node) if node.data.selectindex then - child.data.gradOutput[node.data.selectindex] = go + assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") + -- The data.gradOutput holds the to-be-summed gradients. 
+ child.data.gradOutput[1] = child.data.gradOutput[1] or {} + child.data.gradOutput[1][node.data.selectindex] = go else table.insert(child.data.gradOutput,go) end end elseif node.data.module then local module = node.data.module - local gradOutput = node.data.gradOutput + local gradOutput = getTotalGradOutput(node) local input = node.data.input if #input == 1 then input = input[1] end -- updateGradInput through this node - if istable(gradOutput) and not istable(module.output) then - if #gradOutput > 1 then - node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new() - local gobuff = node.data.gradOutputBuffer - gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1]) - for i=2,#gradOutput do - gobuff:add(gradOutput[i]) - end - gradOutput = gobuff - else - gradOutput = gradOutput[1] - end - elseif istable(gradOutput) and istable(module.output) and #gradOutput ~= #module.output then - gradOutput = gradOutput[1] - end local gradInput = module:updateGradInput(input,gradOutput) -- propagate the output to children for i,child in ipairs(node.children) do @@ -265,23 +263,23 @@ function gModule:updateGradInput(input,gradOutput) end -- now fix the order of gradInput + -- The innode is used as the input by all the input nodes. + -- The data.gradOutput holds the gradients to-be-summed. + -- Here, instead of summing them, we reorder them based on the mapindex. 
self.gradInput = self.innode.data.gradOutput - if not istable(self.gradInput) then - return self.gradInput - end local gi = {} for i,child in ipairs(self.innode.children) do local mi = self.innode.data.mapindex[child.data] table.insert(gi,self.gradInput[mi]) end - while istable(self.gradInput) and #self.gradInput > 0 do + while #self.gradInput > 0 do table.remove(self.gradInput) end for i,v in ipairs(gi) do table.insert(self.gradInput,v) end - if #self.innode.children == 1 and self.gradInput == self.innode.data.gradOutput then + if #self.innode.children == 1 then self.gradInput = self.gradInput[1] end @@ -289,10 +287,11 @@ function gModule:updateGradInput(input,gradOutput) end function gModule:accGradParameters(input,gradOutput,lr) - -- we will assume that the input is either a table of stuff - -- if not we will put it in a table of stuff - if torch.typename(gradOutput) or type(gradOutput) ~= 'table' then + -- We see the gradOutput as a list of gradOutputs + if #self.outnode.children <= 1 then gradOutput={gradOutput} + elseif type(gradOutput) ~= "table" then + error(string.format("expecting %s gradOutputs", #self.outnode.children)) end local outputs = {} local function neteval(node) @@ -300,25 +299,15 @@ function gModule:accGradParameters(input,gradOutput,lr) elseif not node.data.module and node.data.gradOutput then elseif node.data.module then local module = node.data.module - local gradOutput = node.data.gradOutput + local gradOutput = node.data.gradOutput[1] + if #node.data.gradOutput > 1 then + gradOutput = node.data.gradOutputBuffer + end local input = node.data.input if #input == 1 then input = input[1] end -- accGradParameters through this node - if istable(gradOutput) and not istable(module.output) then - if #gradOutput > 1 then - node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new() - local gobuff = node.data.gradOutputBuffer - gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1]) - for i=2,#gradOutput do - gobuff:add(gradOutput[i]) - 
end - gradOutput = gobuff - else - gradOutput = gradOutput[1] - end - end module:accGradParameters(input,gradOutput,lr) else if self.verbose then |