diff options
Diffstat (limited to 'gmodule.lua')
-rw-r--r-- | gmodule.lua | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua index 2556b15..c4a67ec 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -8,6 +8,30 @@ local function getTotalGradOutput(node) local gradOutput = node.data.gradOutput assert(istable(gradOutput), "expecting gradients to sum") if #gradOutput > 1 then + -- Check if we can bypass the allocation, for the special case where all + -- gradOutputs but one are zero tensors with an underlying one-element + -- storage. Note that for the case that we + -- cannot bypass it, this check will only be performed once + if not node.data.gradOutputBuffer then + local count = 0 + local idx = 1 + -- Count the gradOutputs that are NOT one-element zero tensors; idx tracks the last such entry + for i=1,#gradOutput do + local zero = torch.isTensor(gradOutput[i]) and + gradOutput[i]:storage() ~= nil and + gradOutput[i]:storage():size() == 1 and + gradOutput[i]:storage()[1] == 0 + if not zero then + idx = i + count = count + 1 + end + end + if count < 2 then + -- Return the only non-zero one, or the first one + -- if they are all zero + return gradOutput[idx] + end + end node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) local gobuff = node.data.gradOutputBuffer nesting.resizeNestedAs(gobuff, gradOutput[1]) @@ -117,7 +141,7 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() - + -- iteratare over all nodes: check, tag and add to container for i,node in ipairs(self.forwardnodes) do -- check for unused inputs or unused split() outputs |