diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 14 |
1 files changed, 13 insertions, 1 deletions
@@ -2964,8 +2964,20 @@ function nntest.SelectTable() equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) end - module:float() + + -- test negative index + local idx = -2 + module = nn.SelectTable(idx) + local output = module:forward(input) + equal(output, input[#input+idx+1], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[#input+idx+1]) + equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx) + + -- test typecast local idx = #input + module = nn.SelectTable(idx) + module:float() local output = module:forward(input) equal(output, input[idx], "type output") local gradInput = module:backward(input, gradOutputs[idx]) |