diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2016-01-25 19:20:46 +0300 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2016-01-25 19:20:46 +0300 |
commit | b43cc8247fc789fed63f86e5aaa91582f1e15450 (patch) | |
tree | 31174e83afca20eedd09355ac70fff36f846b333 | |
parent | 6c71a7736bdff47bef80c632f7632abb5ac429be (diff) | |
parent | a25e293345db8b23acb1411e381b392bcc819586 (diff) |
Merge pull request #98 from malcolmreynolds/remove_zeroing_optimisation
Don't bother filling a Tensor with zero right before we copy into it
-rw-r--r-- | gmodule.lua | 4 | ||||
-rw-r--r-- | nesting.lua | 17 |
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/gmodule.lua b/gmodule.lua index 910e2c4..9b2c0d7 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -11,8 +11,8 @@ local function getTotalGradOutput(node) node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) local gobuff = node.data.gradOutputBuffer nesting.resizeNestedAs(gobuff, gradOutput[1]) - nesting.fillNested(gobuff, 0) - for i=1,#gradOutput do + nesting.copyNested(gobuff, gradOutput[1]) + for i=2,#gradOutput do nesting.addNestedTo(gobuff, gradOutput[i]) end gradOutput = gobuff diff --git a/nesting.lua b/nesting.lua index 18899c1..e1ddd7b 100644 --- a/nesting.lua +++ b/nesting.lua @@ -51,6 +51,23 @@ function nesting.resizeNestedAs(output, input) end end +-- Copies all tensors in the output. +function nesting.copyNested(output, input) + if torch.isTensor(output) then + output:copy(input) + else + for key, child in pairs(input) do + nesting.copyNested(output[key], child) + end + -- Extra elements are removed from the output. + for key, child in pairs(output) do + if not input[key] then + output[key] = nil + end + end + end +end + -- Adds the input to the output. -- The input can contain nested tables. -- The output will contain the same nesting of tables. |