
github.com/torch/nngraph.git
author     Frederic Besse <fbesse@google.com>   2016-02-16 18:02:56 +0300
committer  Frederic Besse <fbesse@google.com>   2016-02-16 18:02:56 +0300
commit     3c8bc6bbf7ce6c55ea96fbd7cc37af2ddc709dc9 (patch)
tree       c140f9b145cc2b3821e66ea86abc1b5c99d5cd4a
parent     ccc9627a95972eca32915100ceddddcfe6e87f43 (diff)
Added optimisation to bypass the buffer allocation when all but one gradOutput are zero one-element tensors.
-rw-r--r--  gmodule.lua            26
-rw-r--r--  test/test_nngraph.lua  46
2 files changed, 71 insertions(+), 1 deletion(-)
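
Editor's note: the commit message describes the bypass informally; the short Lua sketch below illustrates what the patch treats as a "zero" gradient. It is a standalone illustration, not part of the patch, and the helper name isZeroOneElement is made up here for clarity; in the actual change the condition lives inline in getTotalGradOutput, as shown in the diff that follows.

-- Sketch: a gradient can be skipped when it is a tensor backed by a
-- one-element storage holding zero, e.g. torch.Tensor{0} expanded to the
-- input's shape (this is what the new test's dummy module produces).
require 'torch'

local function isZeroOneElement(t)  -- hypothetical helper, for illustration only
   return torch.isTensor(t)
      and t:storage() ~= nil
      and t:storage():size() == 1
      and t:storage()[1] == 0
end

local zero  = torch.Tensor{0}:view(1, 1, 1):expand(1, 2, 3)  -- one shared element
local dense = torch.Tensor(1, 2, 3):uniform()

print(isZeroOneElement(zero))   -- true: the summing buffer is not needed
print(isZeroOneElement(dense))  -- false: must be accumulated as before
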
diff --git a/gmodule.lua b/gmodule.lua
index 2556b15..c4a67ec 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -8,6 +8,30 @@ local function getTotalGradOutput(node)
    local gradOutput = node.data.gradOutput
    assert(istable(gradOutput), "expecting gradients to sum")
    if #gradOutput > 1 then
+      -- Check if we can bypass the allocation, for the special case where all
+      -- gradOutputs but one are zero tensors with an underlying one-element
+      -- storage. Note that if we cannot bypass it, this check is only
+      -- performed once, since the buffer is then allocated below and reused.
+      if not node.data.gradOutputBuffer then
+         local count = 0
+         local idx = 1
+         -- Count the gradOutputs that are not one-element zero tensors
+         -- (idx remembers the last such gradient)
+         for i=1,#gradOutput do
+            local zero = torch.isTensor(gradOutput[i]) and
+                         gradOutput[i]:storage() ~= nil and
+                         gradOutput[i]:storage():size() == 1 and
+                         gradOutput[i]:storage()[1] == 0
+            if not zero then
+               idx = i
+               count = count + 1
+            end
+         end
+         if count < 2 then
+            -- Return the only non-zero one, or the first one
+            -- if they are all zero
+            return gradOutput[idx]
+         end
+      end
       node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
       local gobuff = node.data.gradOutputBuffer
       nesting.resizeNestedAs(gobuff, gradOutput[1])
@@ -117,7 +141,7 @@ function gModule:__init(inputs,outputs)
    -- computation on the graph is done through topsort of forward and backward graphs
    self.forwardnodes = self.fg:topsort()
    self.backwardnodes = self.bg:topsort()
-
+
    -- iterate over all nodes: check, tag and add to container
    for i,node in ipairs(self.forwardnodes) do
       -- check for unused inputs or unused split() outputs
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index e3d8982..b919c84 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -381,4 +381,50 @@ function test.test_gradInputType()
    assert(not ok, "the missing input to split should be detected")
 end
 
+function test.test_gradOutputZeroOptim()
+   -- Make a module that produces an expanded zero gradInput tensor
+   local dummyModule = nn.Module()
+   dummyModule.updateOutput = function(self, input)
+      self.output = torch.Tensor(1, 2, 3):uniform()
+      return self.output
+   end
+   dummyModule.updateGradInput = function(self, input, gradOutput)
+      local zeroTensor = torch.Tensor{0}
+         :view(unpack(torch.ones(input:dim()):totable()))
+         :expandAs(input)
+      self.gradInput = zeroTensor
+      return self.gradInput
+   end
+
+   -- First input and final gradOutput
+   local input = torch.Tensor(1, 2, 3):uniform()
+   local gradOutput = torch.Tensor(1, 2, 3):uniform()
+
+   -- First case: one intermediary gradOutput is going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = nn.Identity()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+
+   -- Second case: all intermediary gradOutputs are going to be zero
+   local x = nn.Identity()()
+   local h1 = dummyModule:clone()(x)
+   local h2 = dummyModule:clone()(x)
+   local y = nn.CAddTable()({h1, h2})
+   local mod = nn.gModule({x}, {y})
+
+   local ok, result = pcall(nn.Module.forward, mod, input)
+   assert(ok, "forward should succeed")
+
+   ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+   assert(ok, "backward should succeed")
+end
+
 tester:add(test):run()
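
Editor's note: the new test only asserts that forward and backward succeed. As a sketch only, and assuming the node data tables reachable via mod.backwardnodes are the same ones getTotalGradOutput writes to (as the diff above suggests), one could additionally verify that the summing buffer was never allocated when the bypass applies:

-- Hypothetical extra assertion for the test above: with at most one non-zero
-- gradOutput per node, getTotalGradOutput should return early and leave
-- data.gradOutputBuffer unset on every node.
for _, node in ipairs(mod.backwardnodes) do
   assert(node.data.gradOutputBuffer == nil,
      "expected the gradOutput buffer allocation to be bypassed")
end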