
github.com/torch/nngraph.git
author     Soumith Chintala <soumith@gmail.com>    2016-07-27 08:45:10 +0300
committer  GitHub <noreply@github.com>             2016-07-27 08:45:10 +0300
commit     7ff1f86ecffe6816e92049ef219e8922db0c2b67 (patch)
tree       a6eab4e234e6591ee33ada29bfbe7cd61980481e
parent     08d0b5db1b56bbb52dfb81a233d288d3654c07d8 (diff)
parent     4972614b3bb851a2234e832942d29fd7e89d3c77 (diff)
Merge pull request #126 from torch/typefix
fix for :type() to typecast the children as well
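In short, gModule:type() converted the tensors stored on each graph node's own data table but did not walk node.children, so buffers cached on child nodes (for example the gradOutput tables exercised in the new test) could keep their old type after a cast such as :float(). A minimal sketch of the behaviour being fixed (a hypothetical usage example, not part of the commit; nn.Linear(20, 10) is an assumption):

require 'nngraph'

-- Build a tiny graph and run forward/backward so nodes cache tensors.
local in1 = nn.Identity()()
local out1 = nn.Linear(20, 10)(in1)   -- hypothetical layer; any module works
local module = nn.gModule({in1}, {out1})

local input = torch.rand(20)
local output = module:forward(input)
module:backward(input, output:clone():normal())

-- Casting the module should also cast buffers cached on each node's
-- children; before this fix the line below could still report
-- torch.DoubleTensor instead of torch.FloatTensor.
module:float()
print(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]))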
-rw-r--r--  gmodule.lua             6
-rw-r--r--  test/test_nngraph.lua  18
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/gmodule.lua b/gmodule.lua
index 0a81da9..6e118d8 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -263,12 +263,18 @@ function gModule:type(type, tensorCache)
          node.data.gradOutputBuffer =
             recursiveType(node.data.gradOutputBuffer, type)
       end
+      for k, child in ipairs(node.children) do
+         applyTypeToTable(child.data)
+      end
    end

    for i,node in ipairs(self.forwardnodes) do
       if node.data.input ~= nil then
          node.data.input = recursiveType(node.data.input, type)
       end
+      for k, child in ipairs(node.children) do
+         applyTypeToTable(child.data)
+      end
    end

    self._type = type
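applyTypeToTable is a local helper defined earlier inside gModule:type(type, tensorCache), outside this hunk. A plausible sketch of its shape (an assumption based on the surrounding code, not the verbatim helper):

-- Assumed shape of the applyTypeToTable helper referenced above: a local
-- closure inside gModule:type() that casts every entry of a node's data
-- table via the same recursiveType used elsewhere in the function.
local function applyTypeToTable(dataTable)
   for key, value in pairs(dataTable) do
      -- `recursiveType` and `type` are closed over from gModule:type()
      dataTable[key] = recursiveType(value, type)
   end
end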
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 8e5aadc..340ab24 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -214,6 +214,9 @@ function test.test_gradInputType()
    local module = nn.gModule({in1}, {out1})
    local input = torch.rand(20)
    local output = module:forward(input)
+   local gradOutput = output:clone():normal()
+   local gradInput = module:backward(input, gradOutput)
+
    module:backward(input, output)
    tester:eq(torch.typename(output), "torch.DoubleTensor")
    tester:eq(torch.typename(module.output), "torch.DoubleTensor")
@@ -222,17 +225,22 @@ function test.test_gradInputType()
    tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
    tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
    tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor")

    module:float()
-   local output = module:forward(input:float())
-   tester:eq(torch.typename(output), "torch.FloatTensor")
    tester:eq(torch.typename(module.output), "torch.FloatTensor")
    tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
    tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor")
-end
+   tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor")
+   local output = module:forward(input:float())
+   tester:eq(torch.typename(output), "torch.FloatTensor")
+   local gradInput = module:backward(input:float(), gradOutput:float())
+   tester:eq(torch.typename(gradInput), "torch.FloatTensor")
+
+end

 function test.test_nestedGradInput()
    local x = nn.Identity()()
@@ -384,8 +392,8 @@ function test.test_gradInputType()
 end

 function test.test_gradOutputZeroOptim()
-   local unpack = function(...)
-      if _G[unpack] then return _G[unpack](...)
+   local unpack = function(...)
+      if _G[unpack] then return _G[unpack](...)
       else return table.unpack(...) end
    end
    -- Make module that produces an expanded zero gradInput tensor
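For reference, the unpack wrapper in this hunk is a Lua 5.1/5.2 portability shim: Lua 5.1 exposes a global unpack, while Lua 5.2 moved it to table.unpack. A more conventional one-line form of the same idea (a sketch, not this repository's code):

-- Resolve unpack portably across Lua 5.1 (global) and Lua 5.2+ (table.unpack).
local unpack = table.unpack or _G.unpack
local a, b = unpack({1, 2})  -- a == 1, b == 2 under either runtime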