Diffstat (limited to 'test/test_nngraph.lua')
-rw-r--r--   test/test_nngraph.lua   18
1 file changed, 13 insertions, 5 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 8e5aadc..340ab24 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -214,6 +214,9 @@ function test.test_gradInputType()
    local module = nn.gModule({in1}, {out1})
    local input = torch.rand(20)
    local output = module:forward(input)
+   local gradOutput = output:clone():normal()
+   local gradInput = module:backward(input, gradOutput)
+   module:backward(input, output)
    tester:eq(torch.typename(output), "torch.DoubleTensor")
    tester:eq(torch.typename(module.output), "torch.DoubleTensor")
@@ -222,17 +225,22 @@ function test.test_gradInputType()
    tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
    tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
    tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor")

    module:float()
-   local output = module:forward(input:float())
-   tester:eq(torch.typename(output), "torch.FloatTensor")
    tester:eq(torch.typename(module.output), "torch.FloatTensor")
    tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
    tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
    tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor")
-end
+   tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor")
+   local output = module:forward(input:float())
+   tester:eq(torch.typename(output), "torch.FloatTensor")
+   local gradInput = module:backward(input:float(), gradOutput:float())
+   tester:eq(torch.typename(gradInput), "torch.FloatTensor")
+
+end

 function test.test_nestedGradInput()
    local x = nn.Identity()()
@@ -384,8 +392,8 @@ function test.test_gradInputType()
 end

 function test.test_gradOutputZeroOptim()
-   local unpack = function(...)
-      if _G[unpack] then return _G[unpack](...)
+   local unpack = function(...)
+      if _G[unpack] then return _G[unpack](...)
       else return table.unpack(...) end
    end
    -- Make module that produces an expanded zero gradInput tensor
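For context, the hunks above extend test.test_gradInputType so that it also runs a backward pass and checks that module:float() converts the gradOutput buffers stored on the graph's backwardnodes, not just the forward state. A minimal standalone sketch of the same forward/backward/type-conversion pattern; the Identity/Linear graph below is illustrative only, not the exact graph built earlier in the test:

    require 'nngraph'

    -- build a small graph module (hypothetical example graph)
    local in1 = nn.Identity()()
    local out1 = nn.Linear(20, 10)(in1)
    local module = nn.gModule({in1}, {out1})

    -- forward and backward in the default DoubleTensor type
    local input = torch.rand(20)
    local output = module:forward(input)
    local gradOutput = output:clone():normal()
    module:backward(input, gradOutput)

    -- convert the whole graph, including stored backward buffers, to float
    module:float()
    local floatOutput = module:forward(input:float())
    local floatGradInput = module:backward(input:float(), gradOutput:float())
    print(torch.typename(floatOutput))     -- torch.FloatTensor
    print(torch.typename(floatGradInput))  -- torch.FloatTensor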