Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-09-18 22:48:00 +0400
committerNicholas Leonard <nick@nikopia.org>2014-09-18 22:48:00 +0400
commit41452d951d0e66360299ef42a3f22dc39519b113 (patch)
treeba77e4820b6e536316f88ff9f4223501c41b3bcc /test
parent144c2f0177dbca983be522cca6f940bfd02d8a78 (diff)
initial commit for Push/PullTable
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua28
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 4236fd9..b3d033c 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -676,6 +676,34 @@ function nnxtest.MultiSoftMax()
mytester:assertTensorEq(gradInput, gradInput2, 0.000001)
end
+function nnxtest.PushPullTable()
+ -- use for targets with SoftMaxTree
+ local input = torch.randn(5,50)
+ local target = torch.IntTensor{20,23,27,10,8}
+ local grad = torch.randn(5)
+ local root_id = 29
+ local hierarchy={
+ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
+ [2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11},
+ [4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17},
+ [6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23},
+ [8]=torch.IntTensor{24,25,26,27,28}
+ }
+ local smt = nn.SoftMaxTree(100, hierarchy, root_id)
+ -- create a network where inputs are fed through softmaxtree
+ -- and targets are teleported (pushed then pulled) to softmaxtree
+ local mlp = nn.Sequential()
+ local linear = nn.Linear(50,100)
+ local push = nn.Push(2)
+ local pull = push:pull(2)
+ mlp:add(push)
+ mlp:add(nn.SelectTable(1))
+ mlp:add(linear)
+ mlp:add(pull)
+ mlp:add(smt)
+ print(mlp:forward{input, target})
+end
+
function nnx.test(tests)
xlua.require('image',true)