From 818a17d2b1065dd0f3c733b3eb8f944b7ab16cde Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Fri, 10 Jul 2015 15:54:36 -0400 Subject: nn.SelectTable accepts negative indices --- test.lua | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 82aefc8..bfece0e 100644 --- a/test.lua +++ b/test.lua @@ -2964,8 +2964,20 @@ function nntest.SelectTable() equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) end - module:float() + + -- test negative index + local idx = -2 + module = nn.SelectTable(idx) + local output = module:forward(input) + equal(output, input[#input+idx+1], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[#input+idx+1]) + equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx) + + -- test typecast local idx = #input + module = nn.SelectTable(idx) + module:float() local output = module:forward(input) equal(output, input[idx], "type output") local gradInput = module:backward(input, gradOutputs[idx]) -- cgit v1.2.3