github.com/torch/nngraph.git
author    Soumith Chintala <soumith@gmail.com>  2016-03-08 23:13:48 +0300
committer Soumith Chintala <soumith@gmail.com>  2016-03-08 23:13:48 +0300
commit    2c66c0d9ceed8fdfd4f0230a4d0d38f6d9fbcee0 (patch)
tree      f265b217fa3c9325970fa99b125b3a85b06ce9c3
parent    db81121991c7b41d927f103425c500070fb9c708 (diff)
parent    3c8bc6bbf7ce6c55ea96fbd7cc37af2ddc709dc9 (diff)
Merge pull request #103 from fbesse/gradoutput_zero_optim
Added optimisation to bypass the buffer allocation when all but one g…
-rw-r--r--   gmodule.lua           | 24
-rw-r--r--   test/test_nngraph.lua | 46
2 files changed, 70 insertions(+), 0 deletions(-)
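
The merged optimisation targets a common pattern in nngraph: when a node's output fans out to several children, each child hands back a gradOutput, and getTotalGradOutput normally allocates a buffer to sum them. Many of these contributions are "zero" gradients produced by expanding a single scalar zero, in which case the sum (and the buffer) is unnecessary. A minimal sketch of that tensor pattern, assuming a standard Torch7 install (illustration only, not part of the patch):

   local torch = require 'torch'

   -- A conceptually large all-zero gradient backed by a one-element
   -- storage: view a single zero as 1x1x1, then expand it to the target
   -- shape. expand() sets strides to zero, so no memory is copied.
   local input = torch.Tensor(1, 2, 3):uniform()
   local zeroGrad = torch.Tensor{0}
      :view(1, 1, 1)
      :expandAs(input)

   print(zeroGrad:size())            -- 1x2x3, same shape as input
   print(zeroGrad:storage():size())  -- 1: the single shared element
   print(zeroGrad:storage()[1])      -- 0: what the new check tests for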
diff --git a/gmodule.lua b/gmodule.lua
index 98bc196..3501f0c 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -8,6 +8,30 @@ local function getTotalGradOutput(node)
    local gradOutput = node.data.gradOutput
    assert(istable(gradOutput), "expecting gradients to sum")
    if #gradOutput > 1 then
+      -- Check whether we can bypass the buffer allocation, for the special
+      -- case where all gradOutputs but one are zero tensors backed by a
+      -- one-element storage. When we cannot bypass it, this check is only
+      -- performed once, because the buffer is allocated just below.
+      if not node.data.gradOutputBuffer then
+         local count = 0
+         local idx = 1
+         -- Count the gradOutputs that are NOT one-element zero tensors
+         for i=1,#gradOutput do
+            local zero = torch.isTensor(gradOutput[i]) and
+               gradOutput[i]:storage() ~= nil and
+               gradOutput[i]:storage():size() == 1 and
+               gradOutput[i]:storage()[1] == 0
+            if not zero then
+               idx = i
+               count = count + 1
+            end
+         end
+         if count < 2 then
+            -- Return the only non-zero one, or the first one
+            -- if they are all zero
+            return gradOutput[idx]
+         end
+      end
       node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
       local gobuff = node.data.gradOutputBuffer
       nesting.resizeNestedAs(gobuff, gradOutput[1])
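
For clarity, the predicate applied to each gradOutput above can be restated on its own. This is an illustrative sketch, not code from the patch; the isExpandedZero name is a hypothetical helper of ours. Note that a dense torch.zeros tensor does not pass the check, because the check keys on the one-element backing storage rather than on the values alone:

   local torch = require 'torch'

   -- Restatement of the "zero" test from getTotalGradOutput
   -- (hypothetical helper name, for illustration only).
   local function isExpandedZero(t)
      return torch.isTensor(t)
         and t:storage() ~= nil
         and t:storage():size() == 1
         and t:storage()[1] == 0
   end

   local big = torch.Tensor(4, 5)
   print(isExpandedZero(torch.Tensor{0}:view(1, 1):expandAs(big)))  -- true
   print(isExpandedZero(torch.zeros(4, 5)))  -- false: storage holds 20 zeros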
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index e3d8982..b919c84 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -381,4 +381,50 @@ function test.test_gradInputType()
    assert(not ok, "the missing input to split should be detected")
 end
 
+function test.test_gradOutputZeroOptim()
+   -- Make a module that produces an expanded zero gradInput tensor
+   local dummyModule = nn.Module()
+   dummyModule.updateOutput = function(self, input)
+      self.output = torch.Tensor(1, 2, 3):uniform()
+      return self.output
+   end
+   dummyModule.updateGradInput = function(self, input, gradOutput)
+      local zeroTensor = torch.Tensor{0}
+         :view(unpack(torch.ones(input:dim()):totable()))
+         :expandAs(input)
+      self.gradInput = zeroTensor
+      return self.gradInput
+   end
+
+   -- Shared input and final gradOutput for both cases
+   local input = torch.Tensor(1, 2, 3):uniform()
+   local gradOutput = torch.Tensor(1, 2, 3):uniform()
+
+   -- First case: one intermediary gradOutput is going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = nn.Identity()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+
+   -- Second case: all intermediary gradOutputs are going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = dummyModule:clone()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+end
+
 tester:add(test):run()
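
Since the file ends by registering and running the suite via tester:add(test):run(), executing it directly should exercise the new test along with the rest. A hedged example, assuming Torch7 and nngraph are installed:

   th test/test_nngraph.lua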