Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--SelectTable.lua14
-rwxr-xr-xdoc/table.md6
-rw-r--r--test.lua14
3 files changed, 27 insertions, 7 deletions
diff --git a/SelectTable.lua b/SelectTable.lua
index 64cd105..61918f7 100644
--- a/SelectTable.lua
+++ b/SelectTable.lua
@@ -7,7 +7,13 @@ function SelectTable:__init(index)
end
function SelectTable:updateOutput(input)
- self.output = input[self.index]
+ assert(math.abs(self.index) <= #input, "arg 1 table idx out of range")
+ if self.index < 0 then
+ self.output = input[#input + self.index + 1]
+ else
+ self.output = input[self.index]
+ end
+
return self.output
end
@@ -31,7 +37,11 @@ local function zeroTableCopy(t1, t2)
end
function SelectTable:updateGradInput(input, gradOutput)
- self.gradInput[self.index] = gradOutput
+ if self.index < 0 then
+ self.gradInput[#input + self.index + 1] = gradOutput
+ else
+ self.gradInput[self.index] = gradOutput
+ end
zeroTableCopy(self.gradInput, input)
return self.gradInput
end
diff --git a/doc/table.md b/doc/table.md
index 57c222d..c2aeb83 100755
--- a/doc/table.md
+++ b/doc/table.md
@@ -636,7 +636,7 @@ Forwarding a batch of 2 examples gives us something like this:
`module` = `SelectTable(index)`
-Creates a module that takes a `table` as input and outputs the element at index `index`.
+Creates a module that takes a `table` as input and outputs the element at index `index` (positive or negative).
This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the
@@ -645,14 +645,12 @@ dept of the encapsulated `Tensor` as the function used internally to do so is re
Example 1:
```lua
> input = {torch.randn(2, 3), torch.randn(2, 1)}
- [0.0002s]
> =nn.SelectTable(1):forward(input)
-0.3060 0.1398 0.2707
0.0576 1.5455 0.0610
[torch.DoubleTensor of dimension 2x3]
- [0.0002s]
-> =nn.SelectTable(2):forward(input)
+> =nn.SelectTable(-1):forward(input)
2.3080
-0.2955
[torch.DoubleTensor of dimension 2x1]
diff --git a/test.lua b/test.lua
index 82aefc8..bfece0e 100644
--- a/test.lua
+++ b/test.lua
@@ -2964,8 +2964,20 @@ function nntest.SelectTable()
equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
end
- module:float()
+
+ -- test negative index
+ local idx = -2
+ module = nn.SelectTable(idx)
+ local output = module:forward(input)
+ equal(output, input[#input+idx+1], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[#input+idx+1])
+ equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx)
+
+ -- test typecast
local idx = #input
+ module = nn.SelectTable(idx)
+ module:float()
local output = module:forward(input)
equal(output, input[idx], "type output")
local gradInput = module:backward(input, gradOutputs[idx])