diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-07-27 08:45:10 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-07-27 08:45:10 +0300 |
commit | 7ff1f86ecffe6816e92049ef219e8922db0c2b67 (patch) | |
tree | a6eab4e234e6591ee33ada29bfbe7cd61980481e | |
parent | 08d0b5db1b56bbb52dfb81a233d288d3654c07d8 (diff) | |
parent | 4972614b3bb851a2234e832942d29fd7e89d3c77 (diff) |
Merge pull request #126 from torch/typefix
fix for :type() to typecast the children as well
-rw-r--r-- | gmodule.lua | 6 | ||||
-rw-r--r-- | test/test_nngraph.lua | 18 |
2 files changed, 19 insertions, 5 deletions
diff --git a/gmodule.lua b/gmodule.lua index 0a81da9..6e118d8 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -263,12 +263,18 @@ function gModule:type(type, tensorCache) node.data.gradOutputBuffer = recursiveType(node.data.gradOutputBuffer, type) end + for k, child in ipairs(node.children) do + applyTypeToTable(child.data) + end end for i,node in ipairs(self.forwardnodes) do if node.data.input ~= nil then node.data.input = recursiveType(node.data.input, type) end + for k, child in ipairs(node.children) do + applyTypeToTable(child.data) + end end self._type = type diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 8e5aadc..340ab24 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -214,6 +214,9 @@ function test.test_gradInputType() local module = nn.gModule({in1}, {out1}) local input = torch.rand(20) local output = module:forward(input) + local gradOutput = output:clone():normal() + local gradInput = module:backward(input, gradOutput) + module:backward(input, output) tester:eq(torch.typename(output), "torch.DoubleTensor") tester:eq(torch.typename(module.output), "torch.DoubleTensor") @@ -222,17 +225,22 @@ function test.test_gradInputType() tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor") module:float() - local output = module:forward(input:float()) - tester:eq(torch.typename(output), "torch.FloatTensor") tester:eq(torch.typename(module.output), "torch.FloatTensor") tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor") - end + tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor") + local output = module:forward(input:float()) + tester:eq(torch.typename(output), "torch.FloatTensor") + local gradInput = module:backward(input:float(), gradOutput:float()) + tester:eq(torch.typename(gradInput), "torch.FloatTensor") + + end function test.test_nestedGradInput() local x = nn.Identity()() @@ -384,8 +392,8 @@ function test.test_gradInputType() end function test.test_gradOutputZeroOptim() - local unpack = function(...) - if _G[unpack] then return _G[unpack](...) + local unpack = function(...) + if _G[unpack] then return _G[unpack](...) else return table.unpack(...) end end -- Make module that produces an expanded zero gradInput tensor |