diff options
-rw-r--r-- | SelectTable.lua | 14 | ||||
-rwxr-xr-x | doc/table.md | 6 | ||||
-rw-r--r-- | test.lua | 14 |
3 files changed, 27 insertions, 7 deletions
diff --git a/SelectTable.lua b/SelectTable.lua index 64cd105..61918f7 100644 --- a/SelectTable.lua +++ b/SelectTable.lua @@ -7,7 +7,13 @@ function SelectTable:__init(index) end function SelectTable:updateOutput(input) - self.output = input[self.index] + assert(math.abs(self.index) <= #input, "arg 1 table idx out of range") + if self.index < 0 then + self.output = input[#input + self.index + 1] + else + self.output = input[self.index] + end + return self.output end @@ -31,7 +37,11 @@ local function zeroTableCopy(t1, t2) end function SelectTable:updateGradInput(input, gradOutput) - self.gradInput[self.index] = gradOutput + if self.index < 0 then + self.gradInput[#input + self.index + 1] = gradOutput + else + self.gradInput[self.index] = gradOutput + end zeroTableCopy(self.gradInput, input) return self.gradInput end diff --git a/doc/table.md b/doc/table.md index 57c222d..c2aeb83 100755 --- a/doc/table.md +++ b/doc/table.md @@ -636,7 +636,7 @@ Forwarding a batch of 2 examples gives us something like this: `module` = `SelectTable(index)` -Creates a module that takes a `table` as input and outputs the element at index `index`. +Creates a module that takes a `table` as input and outputs the element at index `index` (positive or negative). This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the @@ -645,14 +645,12 @@ dept of the encapsulated `Tensor` as the function used internally to do so is re Example 1: ```lua > input = {torch.randn(2, 3), torch.randn(2, 1)} - [0.0002s] > =nn.SelectTable(1):forward(input) -0.3060 0.1398 0.2707 0.0576 1.5455 0.0610 [torch.DoubleTensor of dimension 2x3] - [0.0002s] -> =nn.SelectTable(2):forward(input) +> =nn.SelectTable(-1):forward(input) 2.3080 -0.2955 [torch.DoubleTensor of dimension 2x1] @@ -2964,8 +2964,20 @@ function nntest.SelectTable() equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) end - module:float() + + -- test negative index + local idx = -2 + module = nn.SelectTable(idx) + local output = module:forward(input) + equal(output, input[#input+idx+1], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[#input+idx+1]) + equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx) + + -- test typecast local idx = #input + module = nn.SelectTable(idx) + module:float() local output = module:forward(input) equal(output, input[idx], "type output") local gradInput = module:backward(input, gradOutputs[idx]) |