Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_nngraph.lua')
-rw-r--r--test/test_nngraph.lua46
1 files changed, 46 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index e3d8982..b919c84 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -381,4 +381,50 @@ function test.test_gradInputType()
assert(not ok, "the missing input to split should be detected")
end
+ function test.test_gradOutputZeroOptim()
+ -- Make module that produces an expanded zero gradInput tensor
+ local dummyModule = nn.Module()
+ dummyModule.updateOutput = function(self, input)
+ self.output = torch.Tensor(1, 2, 3):uniform()
+ return self.output
+ end
+ dummyModule.updateGradInput = function(self, input, gradOutput)
+ local zeroTensor = torch.Tensor{0}
+ :view(unpack(torch.ones(input:dim()):totable()))
+ :expandAs(input)
+ self.gradInput = zeroTensor
+ return self.gradInput
+ end
+
+ -- First input and final gradOutput
+ local input = torch.Tensor(1, 2, 3):uniform()
+ local gradOutput = torch.Tensor(1, 2, 3):uniform()
+
+ -- First case: one intermediary gradOutput is going to be zero
+ local x = nn.Identity()()
+ local h1 = dummyModule:clone()(x)
+ local h2 = nn.Identity()(x)
+ local y = nn.CAddTable()({h1, h2})
+ local mod = nn.gModule({x}, {y})
+
+ local ok, result = pcall(nn.Module.forward, mod, input)
+ assert(ok, "forward should succeed")
+
+ ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+ assert(ok, "backward should succeed")
+
+ -- Second case: all intermediary gradOutputs are going to be zero
+ local x = nn.Identity()()
+ local h1 = dummyModule:clone()(x)
+ local h2 = dummyModule:clone()(x)
+ local y = nn.CAddTable()({h1, h2})
+ local mod = nn.gModule({x}, {y})
+
+ local ok, result = pcall(nn.Module.forward, mod, input)
+ assert(ok, "forward should succeed")
+
+ ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
+ assert(ok, "backward should succeed")
+ end
+
tester:add(test):run()