author     Frederic Besse <fbesse@google.com>  2016-02-16 18:02:56 +0300
committer  Frederic Besse <fbesse@google.com>  2016-02-16 18:02:56 +0300
commit     3c8bc6bbf7ce6c55ea96fbd7cc37af2ddc709dc9 (patch)
tree       c140f9b145cc2b3821e66ea86abc1b5c99d5cd4a
parent     ccc9627a95972eca32915100ceddddcfe6e87f43 (diff)
Added an optimisation to bypass the buffer allocation when all but one of the gradOutputs are zero one-element tensors.
-rw-r--r--  gmodule.lua            | 26
-rw-r--r--  test/test_nngraph.lua  | 46
2 files changed, 71 insertions(+), 1 deletion(-)
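
Before reading the diff, it helps to see the case being targeted. A module that has no real gradient to propagate can return a single zero element expanded to the input's shape; the resulting tensor looks full-sized but is backed by a one-element storage. A minimal sketch of that construction and of the detection predicate the commit adds (assuming Torch7; the variable names are illustrative, and the construction mirrors the new test's dummy module):

require 'torch'

local input = torch.Tensor(1, 2, 3):uniform()

-- View a single zero element at the input's full shape: the result prints
-- as a 1x2x3 tensor, but its underlying storage holds exactly one element.
local zeroGrad = torch.Tensor{0}
   :view(unpack(torch.ones(input:dim()):totable()))
   :expandAs(input)

-- The predicate added to getTotalGradOutput recognises such tensors.
local zero = torch.isTensor(zeroGrad) and
   zeroGrad:storage() ~= nil and
   zeroGrad:storage():size() == 1 and
   zeroGrad:storage()[1] == 0
print(zero) -- true

Summing such a tensor into a freshly cloned buffer is a no-op, so when at most one gradOutput is non-zero the summation can be skipped entirely.
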
diff --git a/gmodule.lua b/gmodule.lua
index 2556b15..c4a67ec 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -8,6 +8,30 @@ local function getTotalGradOutput(node)
    local gradOutput = node.data.gradOutput
    assert(istable(gradOutput), "expecting gradients to sum")
    if #gradOutput > 1 then
+      -- Check if we can bypass the allocation, for the special case where all
+      -- gradOutputs but one are zero tensors with an underlying one-element
+      -- storage. Note that when we cannot
+      -- bypass it, this check is only performed once.
+      if not node.data.gradOutputBuffer then
+         local count = 0
+         local idx = 1
+         -- Count how many gradOutputs are not one-element zero tensors
+         for i=1,#gradOutput do
+            local zero = torch.isTensor(gradOutput[i]) and
+                         gradOutput[i]:storage() ~= nil and
+                         gradOutput[i]:storage():size() == 1 and
+                         gradOutput[i]:storage()[1] == 0
+            if not zero then
+               idx = i
+               count = count + 1
+            end
+         end
+         if count < 2 then
+            -- Return the only non-zero one, or the first one
+            -- if they are all zero
+            return gradOutput[idx]
+         end
+      end
       node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
       local gobuff = node.data.gradOutputBuffer
       nesting.resizeNestedAs(gobuff, gradOutput[1])
@@ -117,7 +141,7 @@ function gModule:__init(inputs,outputs)
 
    -- computation on the graph is done through topsort of forward and backward graphs
    self.forwardnodes = self.fg:topsort()
    self.backwardnodes = self.bg:topsort()
-
+   -- iterate over all nodes: check, tag and add to container
    for i,node in ipairs(self.forwardnodes) do
       -- check for unused inputs or unused split() outputs
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index e3d8982..b919c84 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -381,4 +381,50 @@ function test.test_gradInputType()
    assert(not ok, "the missing input to split should be detected")
 end
 
+function test.test_gradOutputZeroOptim()
+   -- Make a module that produces an expanded zero gradInput tensor
+   local dummyModule = nn.Module()
+   dummyModule.updateOutput = function(self, input)
+      self.output = torch.Tensor(1, 2, 3):uniform()
+      return self.output
+   end
+   dummyModule.updateGradInput = function(self, input, gradOutput)
+      local zeroTensor = torch.Tensor{0}
+         :view(unpack(torch.ones(input:dim()):totable()))
+         :expandAs(input)
+      self.gradInput = zeroTensor
+      return self.gradInput
+   end
+
+   -- First input and final gradOutput
+   local input = torch.Tensor(1, 2, 3):uniform()
+   local gradOutput = torch.Tensor(1, 2, 3):uniform()
+
+   -- First case: one intermediary gradOutput is going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = nn.Identity()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+
+   -- Second case: all intermediary gradOutputs are going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = dummyModule:clone()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+end
+
 tester:add(test):run()
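
A note on the design: the scan runs only while node.data.gradOutputBuffer is nil, so on a graph that genuinely needs the summation the extra cost is a single pass over the gradOutputs on the first backward; once the buffer exists, the check is skipped. A hypothetical probe (not part of the commit; it rebuilds the new test's first case, and the names are mine) that makes the effect observable by asserting no buffer was allocated:

require 'nngraph'

-- Dummy module whose gradInput is an expanded zero (one-element storage).
local dummy = nn.Module()
dummy.updateOutput = function(self, input)
   self.output = torch.Tensor(1, 2, 3):uniform()
   return self.output
end
dummy.updateGradInput = function(self, input, gradOutput)
   self.gradInput = torch.Tensor{0}
      :view(unpack(torch.ones(input:dim()):totable()))
      :expandAs(input)
   return self.gradInput
end

-- One zero branch and one real branch feed the shared input node, so its
-- gradOutput table has two entries and the new bypass should trigger.
local x   = nn.Identity()()
local h1  = dummy(x)
local h2  = nn.Identity()(x)
local mod = nn.gModule({x}, {nn.CAddTable()({h1, h2})})

local input = torch.Tensor(1, 2, 3):uniform()
mod:forward(input)
mod:backward(input, torch.Tensor(1, 2, 3):uniform())

-- With the bypass in place, no node should have allocated a buffer.
for _, node in ipairs(mod.backwardnodes) do
   assert(node.data.gradOutputBuffer == nil,
      "expected the fast path: no gradOutputBuffer allocation")
end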