diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-09-19 00:21:42 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-09-19 00:21:42 +0400 |
commit | e68d15b0414c059a8ec747ee825266404d167238 (patch) | |
tree | 2d3db182a8ffc02e594f52b6c69ca19279202e10 /test | |
parent | f87985eb0b2596dbb82e32cb9eaf0a102708d93c (diff) |
PushTable/PullTable unit tested
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index c784073..5da80bf 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -680,7 +680,7 @@ function nnxtest.PushPullTable() -- use for targets with SoftMaxTree local input = torch.randn(5,50) local target = torch.IntTensor{20,23,27,10,8} - local grad = torch.randn(5) + local gradOutput = torch.randn(5) local root_id = 29 local hierarchy={ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, @@ -701,7 +701,21 @@ function nnxtest.PushPullTable() mlp:add(linear) mlp:add(pull) mlp:add(smt) - print(mlp:forward{input, target}) + -- compare to simpler alternative + local mlp2 = nn.Sequential() + local para = nn.ParallelTable() + para:add(linear:clone()) + para:add(nn.Identity()) + mlp2:add(para) + mlp2:add(smt:clone()) + local inputTable = {input, target} + local output = mlp:forward(inputTable) + local output2 = mlp2:forward(inputTable) + local gradInput = mlp:backward(inputTable, gradOutput) + local gradInput2 = mlp2:backward(inputTable, gradOutput) + mytester:assertTensorEq(output, output2, 0.00001, "push/pull forward error") + mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull backward error") + mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull backward error") end |