Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-09-19 00:21:42 +0400
committerNicholas Leonard <nick@nikopia.org>2014-09-19 00:21:42 +0400
commite68d15b0414c059a8ec747ee825266404d167238 (patch)
tree2d3db182a8ffc02e594f52b6c69ca19279202e10 /test
parentf87985eb0b2596dbb82e32cb9eaf0a102708d93c (diff)
PushTable/PullTable unit tested
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua18
1 files changed, 16 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index c784073..5da80bf 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -680,7 +680,7 @@ function nnxtest.PushPullTable()
-- use for targets with SoftMaxTree
local input = torch.randn(5,50)
local target = torch.IntTensor{20,23,27,10,8}
- local grad = torch.randn(5)
+ local gradOutput = torch.randn(5)
local root_id = 29
local hierarchy={
[29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
@@ -701,7 +701,21 @@ function nnxtest.PushPullTable()
mlp:add(linear)
mlp:add(pull)
mlp:add(smt)
- print(mlp:forward{input, target})
+ -- compare to simpler alternative
+ local mlp2 = nn.Sequential()
+ local para = nn.ParallelTable()
+ para:add(linear:clone())
+ para:add(nn.Identity())
+ mlp2:add(para)
+ mlp2:add(smt:clone())
+ local inputTable = {input, target}
+ local output = mlp:forward(inputTable)
+ local output2 = mlp2:forward(inputTable)
+ local gradInput = mlp:backward(inputTable, gradOutput)
+ local gradInput2 = mlp2:backward(inputTable, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.00001, "push/pull forward error")
+ mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull backward error")
+ mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull backward error")
end