diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2014-04-03 17:33:48 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2014-04-03 17:33:48 +0400 |
commit | 9ccefd70503a476e3227f4a6745e51c201cc93f0 (patch) | |
tree | 2d4d13093f05421e814be4e449fc86cbba03e413 | |
parent | 05a9ca42259cb5657209fe2d6e4fa27542193844 (diff) |
Added totem-based tests.
-rw-r--r-- | test/test_nngraph.lua | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua new file mode 100644 index 0000000..1193dc9 --- /dev/null +++ b/test/test_nngraph.lua @@ -0,0 +1,306 @@ + +require 'totem' +require 'nngraph' +local test = {} +local tester = totem.Tester() + +local function checkGradients(...) + totem.nn.checkGradients(tester, ...) +end + +function test.test_oneOutput() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.Tensor({1}) + module:forward(input) + tester:eq(module.output, torch.Tensor{1}, "output") + local gradInput = module:backward(input, torch.Tensor({-123})) + tester:eq(gradInput, torch.Tensor{-123}, "gradInput") + + local input2 = torch.Tensor({2}) + module:forward(input2) + tester:eq(module.output, torch.Tensor{2}, "output for input2") + gradInput = module:backward(input2, torch.Tensor({-2})) + tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") +end + + +function test.test_twoOutputs() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.Tensor({1}) + module:forward(input) + local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) + tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") + checkGradients(module, input) +end + +function test.test_twoGradOutputs() + local in1 = nn.Sigmoid()() + local splitTable = nn.SplitTable(1)({in1}) + local out1, out2 = splitTable:split(2) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.randn(2, 3) + local output = module:forward(input) + assert(#output == 2, "wrong number of outputs") + module:backward(input, {torch.randn(3), torch.randn(3)}) + checkGradients(module, input) +end + +function test.test_twoInputs() + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local prevH, prevCell = in2:split(2) + + local out1 = nn.CMulTable()({in1, prevH, prevCell}) + local module = nn.gModule({in1, in2}, {out1}) + + local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} + module:forward(input) + local gradInput = module:backward(input, torch.randn(3)) + assert(#gradInput == 2, "wrong number of gradInputs") + assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") + checkGradients(module, input) +end + +function test.test_twoInputs2() + local in1 = nn.Sigmoid()() + local in2 = nn.Sigmoid()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) +end + +function test.test_identity() + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) +end + +function test.test_gradInputType() + local xInput = torch.randn(3) + local h = torch.randn(3) + + local x = nn.Identity()() + local prevRnnState = nn.Identity()() + local prevH1, prevCell = prevRnnState:split(2) + local prevH = prevH1 + + local cellOut = nn.CAddTable()({ + nn.CMulTable()({x, prevH}), + nn.CMulTable()({prevH, prevCell})}) + local module = nn.gModule({x, prevRnnState}, {cellOut}) + + local c = torch.randn(h:size()) + local prevRnnState = {h, c} + local input = {xInput, prevRnnState} + local output = module:forward(input) + + local gradOutput = torch.randn(h:size()) + local gradInput = module:backward(input, gradOutput) + + local gradX, gradPrevState = unpack(gradInput) + local gradPrevH, gradPrevCell = unpack(gradPrevState) + assert(type(gradPrevH) == type(h), "wrong gradPrevH type") + + tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") + tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") + checkGradients(module, input) +end + +function test.test_tabularInput() + local in1 = nn.SplitTable(1)() + local out1 = nn.CAddTable()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.randn(2, 3) + checkGradients(module, input) +end + +function test.test_extraTable() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.Tensor({123}) + tester:eq(module:forward(input), input, "simple output") + tester:eq(module:forward({input}), {input}, "tabular output") +end + +function test.test_accGradParameters() + local input = torch.randn(10) + + local in1 = nn.CMul(input:nElement())() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + checkGradients(module, input) +end + +function test.test_example1() + local x1 = nn.Linear(20,10)() + local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) + local mlp = nn.gModule({x1},{mout}) + + local x = torch.rand(20) + checkGradients(mlp, x) +end + +function test.test_example2() + local x1=nn.Linear(20,20)() + local x2=nn.Linear(10,10)() + local m0=nn.Linear(20,1)(nn.Tanh()(x1)) + local m1=nn.Linear(10,1)(nn.Tanh()(x2)) + local madd=nn.CAddTable()({m0,m1}) + local m2=nn.Sigmoid()(madd) + local m3=nn.Tanh()(madd) + local gmod = nn.gModule({x1,x2},{m2,m3}) + + local x = torch.rand(20) + local y = torch.rand(10) + checkGradients(gmod, {x, y}) +end + +function test.test_example3() + local m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) + local input = nn.Identity()() + local input1,input2 = m(input):split(2) + local m3 = nn.JoinTable(1)({input1,input2}) + local g = nn.gModule({input},{m3}) + + local indata = torch.rand(2,10) + checkGradients(g, indata) +end + +function test.test_example4() + local input = nn.Identity()() + local L1 = nn.Tanh()(nn.Linear(1,2)(input)) + local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) + local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) + local g = nn.gModule({input},{L3}) + + local indata = torch.rand(1) + checkGradients(g, indata) +end + +function test.test_type() + local in1 = nn.Linear(20,10)() + local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) + local module = nn.gModule({in1}, {out1}) + + local input = torch.rand(20) + local output = module:forward(input) + tester:eq(torch.typename(output), "torch.DoubleTensor") + + module:float() + local output = module:forward(input:float()) + tester:eq(torch.typename(output), "torch.FloatTensor") +end + +function test.test_nestedGradInput() + local x = nn.Identity()() + local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) + local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) + local out = nn.CAddTable()({h1(x), h2(x)}) + + local model = nn.gModule({x}, {out}) + + local input = {} + input[1] = torch.randn(3, 3) + input[2] = torch.randn(3, 3) + input[3] = torch.randn(3, 3) + + checkGradients(model, input) + + local input = {} + input[1] = torch.randn(2, 3) + input[2] = torch.randn(2, 3) + input[3] = torch.randn(2, 3) + + checkGradients(model, input) +end + +function test.test_unusedInput() + local x = nn.Identity()() + local h = nn.Identity()() + local h2 = nn.Identity()() + + local ok, result = pcall(nn.gModule, {x, h}, {x}) + assert(not ok, "the unused input should be detected") +end + +function test.test_unusedChild() + local prevState = nn.Identity()() + local h, cell = prevState:split(2) + + local ok, result = pcall(nn.gModule, {prevState}, {h}) + assert(not ok, "the unused cell should be detected") +end + +function test.test_nilInput() + local ok, result = pcall(function() nn.Sigmoid()(nil) end) + assert(not ok, "the nil input should be detected") +end + +function test.test_unusedNode() + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local middleResult = nn.Sigmoid()(in2) + local out1 = nn.Sigmoid()(in1) + + local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) + assert(not ok, "the unused middleResult should be detected") +end + +function test.test_usageAfterSplit() + local prevState = nn.Identity()() + local h, cell = prevState:split(2) + local nextState = nn.Identity()(prevState) + local transformed = nn.Sigmoid()(cell) + + local model = nn.gModule({prevState}, {h, nextState, transformed}) + local nHidden = 10 + local input = {torch.randn(nHidden), torch.randn(nHidden)} + checkGradients(model, input) +end + +function test.test_resizeNestedAs() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + + local net = nn.gModule({in1}, {out1, out2}) + local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + local gradInput = net:backward(input, net.output) + tester:eq(#(gradInput[2]), 2, "gradInput[2] size") + checkGradients(net, input) +end + +tester:add(test):run() |