diff options
author | Nicholas Leonard <nick@nikopia.org> | 2015-07-10 22:54:36 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2015-07-10 22:54:36 +0300 |
commit | 818a17d2b1065dd0f3c733b3eb8f944b7ab16cde (patch) | |
tree | 743e790e9b0dcfd30590c36997889755c90c4e03 /test.lua | |
parent | e0196922091ee55bf5e939b2863304641aec5ddc (diff) |
nn.SelectTable accepts negative indices
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 14 |
1 files changed, 13 insertions, 1 deletions
@@ -2964,8 +2964,20 @@ function nntest.SelectTable() equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) end - module:float() + + -- test negative index + local idx = -2 + module = nn.SelectTable(idx) + local output = module:forward(input) + equal(output, input[#input+idx+1], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[#input+idx+1]) + equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx) + + -- test typecast local idx = #input + module = nn.SelectTable(idx) + module:float() local output = module:forward(input) equal(output, input[idx], "type output") local gradInput = module:backward(input, gradOutputs[idx]) |