diff options
-rw-r--r-- | test/test_nngraph.lua | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index a993283..8e5aadc 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -221,6 +221,7 @@ function test.test_gradInputType() tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor") module:float() local output = module:forward(input:float()) @@ -230,7 +231,8 @@ function test.test_gradInputType() tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") - end + tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor") + end function test.test_nestedGradInput() local x = nn.Identity()() |