diff options
Diffstat (limited to 'test/test_nngraph.lua')
-rw-r--r-- | test/test_nngraph.lua | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index e3d8982..b919c84 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -381,4 +381,50 @@ function test.test_gradInputType() assert(not ok, "the missing input to split should be detected") end + function test.test_gradOutputZeroOptim() + -- Make module that produces an expanded zero gradInput tensor + local dummyModule = nn.Module() + dummyModule.updateOutput = function(self, input) + self.output = torch.Tensor(1, 2, 3):uniform() + return self.output + end + dummyModule.updateGradInput = function(self, input, gradOutput) + local zeroTensor = torch.Tensor{0} + :view(unpack(torch.ones(input:dim()):totable())) + :expandAs(input) + self.gradInput = zeroTensor + return self.gradInput + end + + -- First input and final gradOutput + local input = torch.Tensor(1, 2, 3):uniform() + local gradOutput = torch.Tensor(1, 2, 3):uniform() + + -- First case: one intermediary gradOutput is going to be zero + local x = nn.Identity()() + local h1 = dummyModule:clone()(x) + local h2 = nn.Identity()(x) + local y = nn.CAddTable()({h1, h2}) + local mod = nn.gModule({x}, {y}) + + local ok, result = pcall(nn.Module.forward, mod, input) + assert(ok, "forward should succeed") + + ok, result = pcall(nn.Module.backward, mod, input, gradOutput) + assert(ok, "backward should succeed") + + -- Second case: all intermediary gradOutputs are going to be zero + local x = nn.Identity()() + local h1 = dummyModule:clone()(x) + local h2 = dummyModule:clone()(x) + local y = nn.CAddTable()({h1, h2}) + local mod = nn.gModule({x}, {y}) + + local ok, result = pcall(nn.Module.forward, mod, input) + assert(ok, "forward should succeed") + + ok, result = pcall(nn.Module.backward, mod, input, gradOutput) + assert(ok, "backward should succeed") + end + tester:add(test):run() |