Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_nngraph.lua4
1 files changed, 3 insertions, 1 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index a993283..8e5aadc 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -221,6 +221,7 @@ function test.test_gradInputType()
tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor")
module:float()
local output = module:forward(input:float())
@@ -230,7 +231,8 @@ function test.test_gradInputType()
tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
- end
+ tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor")
+ end
function test.test_nestedGradInput()
local x = nn.Identity()()