diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-09-18 22:48:00 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-09-18 22:48:00 +0400 |
commit | 41452d951d0e66360299ef42a3f22dc39519b113 (patch) | |
tree | ba77e4820b6e536316f88ff9f4223501c41b3bcc /test | |
parent | 144c2f0177dbca983be522cca6f940bfd02d8a78 (diff) |
initial commit for Push/PullTable
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 4236fd9..b3d033c 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -676,6 +676,34 @@ function nnxtest.MultiSoftMax() mytester:assertTensorEq(gradInput, gradInput2, 0.000001) end +function nnxtest.PushPullTable() + -- use for targets with SoftMaxTree + local input = torch.randn(5,50) + local target = torch.IntTensor{20,23,27,10,8} + local grad = torch.randn(5) + local root_id = 29 + local hierarchy={ + [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, + [2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11}, + [4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17}, + [6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23}, + [8]=torch.IntTensor{24,25,26,27,28} + } + local smt = nn.SoftMaxTree(100, hierarchy, root_id) + -- create a network where inputs are fed through softmaxtree + -- and targets are teleported (pushed then pulled) to softmaxtree + local mlp = nn.Sequential() + local linear = nn.Linear(50,100) + local push = nn.Push(2) + local pull = push:pull(2) + mlp:add(push) + mlp:add(nn.SelectTable(1)) + mlp:add(linear) + mlp:add(pull) + mlp:add(smt) + print(mlp:forward{input, target}) +end + function nnx.test(tests) xlua.require('image',true) |