diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-03-08 23:13:48 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-03-08 23:13:48 +0300 |
commit | 2c66c0d9ceed8fdfd4f0230a4d0d38f6d9fbcee0 (patch) | |
tree | f265b217fa3c9325970fa99b125b3a85b06ce9c3 | |
parent | db81121991c7b41d927f103425c500070fb9c708 (diff) | |
parent | 3c8bc6bbf7ce6c55ea96fbd7cc37af2ddc709dc9 (diff) |
Merge pull request #103 from fbesse/gradoutput_zero_optim
Added optimisation to bypass the buffer allocation when all but one gradOutput are zero tensors
-rw-r--r-- | gmodule.lua | 24 | ||||
-rw-r--r-- | test/test_nngraph.lua | 46 |
2 files changed, 70 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua index 98bc196..3501f0c 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -8,6 +8,30 @@ local function getTotalGradOutput(node) local gradOutput = node.data.gradOutput assert(istable(gradOutput), "expecting gradients to sum") if #gradOutput > 1 then + -- Check if we can bypass the allocation, for the special case where all + -- gradOutputs but one are zero tensors with an underlying one-element + -- storage. Note that for the case that we + -- cannot bypass it, this check will only be performed once + if not node.data.gradOutputBuffer then + local count = 0 + local idx = 1 + -- Count how many gradOutput are tensors of 1 element filled with zero + for i=1,#gradOutput do + local zero = torch.isTensor(gradOutput[i]) and + gradOutput[i]:storage() ~= nil and + gradOutput[i]:storage():size() == 1 and + gradOutput[i]:storage()[1] == 0 + if not zero then + idx = i + count = count + 1 + end + end + if count < 2 then + -- Return the only non-zero one, or the first one + -- if they are all zero + return gradOutput[idx] + end + end node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) local gobuff = node.data.gradOutputBuffer nesting.resizeNestedAs(gobuff, gradOutput[1]) diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index e3d8982..b919c84 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -381,4 +381,50 @@ function test.test_gradInputType() assert(not ok, "the missing input to split should be detected") end + function test.test_gradOutputZeroOptim() + -- Make module that produces an expanded zero gradInput tensor + local dummyModule = nn.Module() + dummyModule.updateOutput = function(self, input) + self.output = torch.Tensor(1, 2, 3):uniform() + return self.output + end + dummyModule.updateGradInput = function(self, input, gradOutput) + local zeroTensor = torch.Tensor{0} + :view(unpack(torch.ones(input:dim()):totable())) + :expandAs(input) + self.gradInput = zeroTensor + 
return self.gradInput + end + + -- First input and final gradOutput + local input = torch.Tensor(1, 2, 3):uniform() + local gradOutput = torch.Tensor(1, 2, 3):uniform() + + -- First case: one intermediary gradOutput is going to be zero + local x = nn.Identity()() + local h1 = dummyModule:clone()(x) + local h2 = nn.Identity()(x) + local y = nn.CAddTable()({h1, h2}) + local mod = nn.gModule({x}, {y}) + + local ok, result = pcall(nn.Module.forward, mod, input) + assert(ok, "forward should succeed") + + ok, result = pcall(nn.Module.backward, mod, input, gradOutput) + assert(ok, "backward should succeed") + + -- Second case: all intermediary gradOutputs are going to be zero + local x = nn.Identity()() + local h1 = dummyModule:clone()(x) + local h2 = dummyModule:clone()(x) + local y = nn.CAddTable()({h1, h2}) + local mod = nn.gModule({x}, {y}) + + local ok, result = pcall(nn.Module.forward, mod, input) + assert(ok, "forward should succeed") + + ok, result = pcall(nn.Module.backward, mod, input, gradOutput) + assert(ok, "backward should succeed") + end + tester:add(test):run() |