diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2013-09-01 15:22:32 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2013-09-01 15:22:32 +0400 |
commit | 7d75756a12c4d93bc8dbdcde2690d2969edf1778 (patch) | |
tree | d8482449acce5552fada24b183e779b82ee30c97 /gmodule.lua | |
parent | a1b777c6cb1798eafaf9b4dc37808154ed509e11 (diff) |
Allowed the gradInputs to be tables with tensors.
Diffstat (limited to 'gmodule.lua')
-rw-r--r-- | gmodule.lua | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/gmodule.lua b/gmodule.lua index c3ca88c..06737f1 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -1,4 +1,5 @@ +local nesting = paths.dofile('nesting.lua') local utils = paths.dofile('utils.lua') local istensor = utils.istensor local istable = utils.istable @@ -8,11 +9,11 @@ local function getTotalGradOutput(node) local gradOutput = node.data.gradOutput assert(istable(gradOutput), "expecting gradients to sum") if #gradOutput > 1 then - node.data.gradOutputBuffer = node.data.gradOutputBuffer or gradOutput[1].new() + node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) local gobuff = node.data.gradOutputBuffer - gobuff:resizeAs(gradOutput[1]):copy(gradOutput[1]) - for i=2,#gradOutput do - gobuff:add(gradOutput[i]) + nesting.fillNested(gobuff, 0) + for i=1,#gradOutput do + nesting.addNestedTo(gobuff, gradOutput[i]) end gradOutput = gobuff else |