Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-07-10 22:54:36 +0300
committerNicholas Leonard <nick@nikopia.org>2015-07-10 22:54:36 +0300
commit818a17d2b1065dd0f3c733b3eb8f944b7ab16cde (patch)
tree743e790e9b0dcfd30590c36997889755c90c4e03 /test.lua
parente0196922091ee55bf5e939b2863304641aec5ddc (diff)
nn.SelectTable accepts negative indices
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua14
1 files changed, 13 insertions, 1 deletions
diff --git a/test.lua b/test.lua
index 82aefc8..bfece0e 100644
--- a/test.lua
+++ b/test.lua
@@ -2964,8 +2964,20 @@ function nntest.SelectTable()
equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
end
- module:float()
+
+ -- test negative index
+ local idx = -2
+ module = nn.SelectTable(idx)
+ local output = module:forward(input)
+ equal(output, input[#input+idx+1], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[#input+idx+1])
+ equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx)
+
+ -- test typecast
local idx = #input
+ module = nn.SelectTable(idx)
+ module:float()
local output = module:forward(input)
equal(output, input[idx], "type output")
local gradInput = module:backward(input, gradOutputs[idx])