Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'gmodule.lua')
-rw-r--r--gmodule.lua26
1 files changed, 25 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 2556b15..c4a67ec 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -8,6 +8,30 @@ local function getTotalGradOutput(node)
local gradOutput = node.data.gradOutput
assert(istable(gradOutput), "expecting gradients to sum")
if #gradOutput > 1 then
+ -- Check if we can bypass the allocation, for the special case where all
+ -- gradOutputs but one are zero tensors with an underlying one-element
+ -- storage. Note that when the allocation cannot be bypassed, this
+ -- check is performed only once (the buffer exists afterwards).
+ if not node.data.gradOutputBuffer then
+ local count = 0
+ local idx = 1
+ -- Count how many gradOutput are tensors of 1 element filled with zero
+ for i=1,#gradOutput do
+ local zero = torch.isTensor(gradOutput[i]) and
+ gradOutput[i]:storage() ~= nil and
+ gradOutput[i]:storage():size() == 1 and
+ gradOutput[i]:storage()[1] == 0
+ if not zero then
+ idx = i
+ count = count + 1
+ end
+ end
+ if count < 2 then
+ -- Return the only non-zero one, or the first one
+ -- if they are all zero
+ return gradOutput[idx]
+ end
+ end
node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
local gobuff = node.data.gradOutputBuffer
nesting.resizeNestedAs(gobuff, gradOutput[1])
@@ -117,7 +141,7 @@ function gModule:__init(inputs,outputs)
-- computation on the graph is done through topsort of forward and backward graphs
self.forwardnodes = self.fg:topsort()
self.backwardnodes = self.bg:topsort()
-
+
-- iterate over all nodes: check, tag and add to container
for i,node in ipairs(self.forwardnodes) do
-- check for unused inputs or unused split() outputs