github.com/torch/nngraph.git
author    Ivo Danihelka <ivo@danihelka.net>  2014-04-03 17:33:48 +0400
committer Ivo Danihelka <ivo@danihelka.net>  2014-04-03 17:33:48 +0400
commit    9ccefd70503a476e3227f4a6745e51c201cc93f0 (patch)
tree      2d4d13093f05421e814be4e449fc86cbba03e413
parent    05a9ca42259cb5657209fe2d6e4fa27542193844 (diff)

Added totem-based tests.

-rw-r--r--  test/test_nngraph.lua  306
1 file changed, 306 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
new file mode 100644
index 0000000..1193dc9
--- /dev/null
+++ b/test/test_nngraph.lua
@@ -0,0 +1,306 @@
+
+require 'totem'
+require 'nngraph'
+local test = {}
+local tester = totem.Tester()
+
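+-- Forward module and inputs to totem's numeric gradient checker, reporting through the shared tester.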
+local function checkGradients(...)
+ totem.nn.checkGradients(tester, ...)
+end
+
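+-- A single-input, single-output Identity graph: forward and backward must pass values through unchanged, including on a second call.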
+function test.test_oneOutput()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ tester:eq(module.output, torch.Tensor{1}, "output")
+ local gradInput = module:backward(input, torch.Tensor({-123}))
+ tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
+
+ local input2 = torch.Tensor({2})
+ module:forward(input2)
+ tester:eq(module.output, torch.Tensor{2}, "output for input2")
+ gradInput = module:backward(input2, torch.Tensor({-2}))
+ tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
+end
+
+
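+-- One input forked into two outputs: gradInput must accumulate both gradOutputs (-2 + -3 = -5).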
+function test.test_twoOutputs()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
+ tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
+ checkGradients(module, input)
+end
+
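+-- A Sigmoid input node split into two outputs; backward takes a table of two gradOutputs.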
+function test.test_twoGradOutputs()
+ local in1 = nn.Sigmoid()()
+ local splitTable = nn.SplitTable(1)({in1})
+ local out1, out2 = splitTable:split(2)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.randn(2, 3)
+ local output = module:forward(input)
+ assert(#output == 2, "wrong number of outputs")
+ module:backward(input, {torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
+end
+
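+-- Nested table input {tensor, {prevH, prevCell}}: in2:split(2) unpacks the pair, and gradInput must mirror the nesting.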
+function test.test_twoInputs()
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local prevH, prevCell = in2:split(2)
+
+ local out1 = nn.CMulTable()({in1, prevH, prevCell})
+ local module = nn.gModule({in1, in2}, {out1})
+
+ local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
+ module:forward(input)
+ local gradInput = module:backward(input, torch.randn(3))
+ assert(#gradInput == 2, "wrong number of gradInputs")
+ assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
+ checkGradients(module, input)
+end
+
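+-- Input nodes may be passed straight through as outputs, alongside a derived node.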
+function test.test_twoInputs2()
+ local in1 = nn.Sigmoid()()
+ local in2 = nn.Sigmoid()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
+end
+
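+-- Same wiring as test_twoInputs2, but with Identity modules throughout.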
+function test.test_identity()
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
+end
+
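+-- Gradients flowing back into a split table input must keep the type and size of the original tensors.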
+function test.test_gradInputType()
+ local xInput = torch.randn(3)
+ local h = torch.randn(3)
+
+ local x = nn.Identity()()
+ local prevRnnState = nn.Identity()()
+ local prevH, prevCell = prevRnnState:split(2)
+
+ local cellOut = nn.CAddTable()({
+ nn.CMulTable()({x, prevH}),
+ nn.CMulTable()({prevH, prevCell})})
+ local module = nn.gModule({x, prevRnnState}, {cellOut})
+
+ local c = torch.randn(h:size())
+ local prevStateValue = {h, c}
+ local input = {xInput, prevStateValue}
+ local output = module:forward(input)
+
+ local gradOutput = torch.randn(h:size())
+ local gradInput = module:backward(input, gradOutput)
+
+ local gradX, gradPrevState = unpack(gradInput)
+ local gradPrevH, gradPrevCell = unpack(gradPrevState)
+ tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
+ tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
+ checkGradients(module, input)
+end
+
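+-- The input node itself produces a table (SplitTable) that CAddTable consumes.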
+function test.test_tabularInput()
+ local in1 = nn.SplitTable(1)()
+ local out1 = nn.CAddTable()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.randn(2, 3)
+ checkGradients(module, input)
+end
+
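+-- forward must accept both a bare tensor and the tensor wrapped in a table, mirroring the structure on output.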
+function test.test_extraTable()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.Tensor({123})
+ tester:eq(module:forward(input), input, "simple output")
+ tester:eq(module:forward({input}), {input}, "tabular output")
+end
+
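+-- Parameter gradients (CMul) must be accumulated correctly when the module's output is forked.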
+function test.test_accGradParameters()
+ local input = torch.randn(10)
+
+ local in1 = nn.CMul(input:nElement())()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+ checkGradients(module, input)
+end
+
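+-- A small MLP written in nngraph's call syntax: Linear -> Tanh -> Linear -> Tanh -> Linear.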
+function test.test_example1()
+ local x1 = nn.Linear(20,10)()
+ local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
+ local mlp = nn.gModule({x1},{mout})
+
+ local x = torch.rand(20)
+ checkGradients(mlp, x)
+end
+
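+-- Two inputs and two outputs sharing one CAddTable node in the middle.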
+function test.test_example2()
+ local x1=nn.Linear(20,20)()
+ local x2=nn.Linear(10,10)()
+ local m0=nn.Linear(20,1)(nn.Tanh()(x1))
+ local m1=nn.Linear(10,1)(nn.Tanh()(x2))
+ local madd=nn.CAddTable()({m0,m1})
+ local m2=nn.Sigmoid()(madd)
+ local m3=nn.Tanh()(madd)
+ local gmod = nn.gModule({x1,x2},{m2,m3})
+
+ local x = torch.rand(20)
+ local y = torch.rand(10)
+ checkGradients(gmod, {x, y})
+end
+
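+-- A whole nn.Sequential used as one node, with its table output split into two graph nodes.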
+function test.test_example3()
+ local m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+ local input = nn.Identity()()
+ local input1,input2 = m(input):split(2)
+ local m3 = nn.JoinTable(1)({input1,input2})
+ local g = nn.gModule({input},{m3})
+
+ local indata = torch.rand(2,10)
+ checkGradients(g, indata)
+end
+
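+-- Skip connections: each layer's input is joined with earlier activations via JoinTable.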
+function test.test_example4()
+ local input = nn.Identity()()
+ local L1 = nn.Tanh()(nn.Linear(1,2)(input))
+ local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
+ local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
+ local g = nn.gModule({input},{L3})
+
+ local indata = torch.rand(1)
+ checkGradients(g, indata)
+end
+
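+-- :float() must convert the whole graph, and the output type must follow.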
+function test.test_type()
+ local in1 = nn.Linear(20,10)()
+ local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.rand(20)
+ local output = module:forward(input)
+ tester:eq(torch.typename(output), "torch.DoubleTensor")
+
+ module:float()
+ local output = module:forward(input:float())
+ tester:eq(torch.typename(output), "torch.FloatTensor")
+end
+
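+-- One table input feeding two Sequential branches; rerun with a different batch size to exercise gradient resizing.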
+function test.test_nestedGradInput()
+ local x = nn.Identity()()
+ local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
+ local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
+ local out = nn.CAddTable()({h1(x), h2(x)})
+
+ local model = nn.gModule({x}, {out})
+
+ local input = {}
+ input[1] = torch.randn(3, 3)
+ input[2] = torch.randn(3, 3)
+ input[3] = torch.randn(3, 3)
+
+ checkGradients(model, input)
+
+ local input = {}
+ input[1] = torch.randn(2, 3)
+ input[2] = torch.randn(2, 3)
+ input[3] = torch.randn(2, 3)
+
+ checkGradients(model, input)
+end
+
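+-- Constructing a gModule with a declared but unused input must fail.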
+function test.test_unusedInput()
+ local x = nn.Identity()()
+ local h = nn.Identity()()
+
+ local ok, result = pcall(nn.gModule, {x, h}, {x})
+ assert(not ok, "the unused input should be detected")
+end
+
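+-- A split child (cell) that no output depends on must be detected at construction time.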
+function test.test_unusedChild()
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
+
+ local ok, result = pcall(nn.gModule, {prevState}, {h})
+ assert(not ok, "the unused cell should be detected")
+end
+
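+-- Calling a node with a nil input must raise an error.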
+function test.test_nilInput()
+ local ok, result = pcall(function() nn.Sigmoid()(nil) end)
+ assert(not ok, "the nil input should be detected")
+end
+
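+-- An intermediate node (middleResult) that never reaches an output must be detected.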
+function test.test_unusedNode()
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local middleResult = nn.Sigmoid()(in2)
+ local out1 = nn.Sigmoid()(in1)
+
+ local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
+ assert(not ok, "the unused middleResult should be detected")
+end
+
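+-- A node may still be used as a whole (nextState) after being split into children.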
+function test.test_usageAfterSplit()
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
+ local nextState = nn.Identity()(prevState)
+ local transformed = nn.Sigmoid()(cell)
+
+ local model = nn.gModule({prevState}, {h, nextState, transformed})
+ local nHidden = 10
+ local input = {torch.randn(nHidden), torch.randn(nHidden)}
+ checkGradients(model, input)
+end
+
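+-- Backward against nested tables whose sizes change between calls: gradInput must be resized to match each input.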
+function test.test_resizeNestedAs()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+
+ local net = nn.gModule({in1}, {out1, out2})
+ local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ local gradInput = net:backward(input, net.output)
+ tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
+ checkGradients(net, input)
+end
+
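+-- Register the test table and run the whole suite.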
+tester:add(test):run()