Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvo Danihelka <ivo@danihelka.net>2013-09-01 15:22:32 +0400
committerIvo Danihelka <ivo@danihelka.net>2013-09-01 15:22:32 +0400
commit7d75756a12c4d93bc8dbdcde2690d2969edf1778 (patch)
treed8482449acce5552fada24b183e779b82ee30c97 /gmodule.lua
parenta1b777c6cb1798eafaf9b4dc37808154ed509e11 (diff)
Allowed the gradInputs to be tables with tensors.
Diffstat (limited to 'gmodule.lua')
-rw-r--r--gmodule.lua9
1 files changed, 5 insertions, 4 deletions
diff --git a/gmodule.lua b/gmodule.lua
index c3ca88c..06737f1 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -1,4 +1,5 @@
+local nesting = paths.dofile('nesting.lua')
local utils = paths.dofile('utils.lua')
local istensor = utils.istensor
local istable = utils.istable
@@ -8,11 +9,11 @@ local function getTotalGradOutput(node)
local gradOutput = node.data.gradOutput
assert(istable(gradOutput), "expecting gradients to sum")
if #gradOutput > 1 then
- node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new()
+ node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
local gobuff = node.data.gradOutputBuffer
- gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1])
- for i=2,#gradOutput do
- gobuff:add(gradOutput[i])
+ nesting.fillNested(gobuff, 0)
+ for i=1,#gradOutput do
+ nesting.addNestedTo(gobuff, gradOutput[i])
end
gradOutput = gobuff
else