Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Zhuang <iamalbert@users.noreply.github.com>2017-01-01 20:21:44 +0300
committerSoumith Chintala <soumith@gmail.com>2017-01-01 20:21:44 +0300
commit8251438690b9b9d90efe8ecef3c4a8cbe3f13653 (patch)
tree7fe218465c5bc85998344e17826993c9879304d6
parentbf83c63728f38a2c336a5427ec7f32c79a423804 (diff)
SelectTable accept string as key (#951)
* SelectTable accept any index types (not just integers)
-rw-r--r--SelectTable.lua11
-rw-r--r--doc/table.md30
2 files changed, 37 insertions, 4 deletions
diff --git a/SelectTable.lua b/SelectTable.lua
index 8eba85e..f383a10 100644
--- a/SelectTable.lua
+++ b/SelectTable.lua
@@ -7,8 +7,12 @@ function SelectTable:__init(index)
end
function SelectTable:updateOutput(input)
+
-- handle negative indices
- local index = self.index < 0 and #input + self.index + 1 or self.index
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
assert(input[index], "index does not exist in the input table")
self.output = input[index]
@@ -41,7 +45,10 @@ function SelectTable:updateGradInput(input, gradOutput)
-- make gradInput a zeroed copy of input
zeroTableCopy(self.gradInput, input)
-- handle negative indices
- local index = self.index < 0 and #input + self.index + 1 or self.index
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
-- copy into gradInput[index] (necessary for variable sized inputs)
assert(self.gradInput[index])
nn.utils.recursiveCopy(self.gradInput[index], gradOutput)
diff --git a/doc/table.md b/doc/table.md
index ee61719..d5174a7 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -692,7 +692,7 @@ Forwarding a batch of 2 examples gives us something like this:
`module` = `SelectTable(index)`
-Creates a module that takes a `table` as input and outputs the element at index `index` (positive or negative).
+Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative).
This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the
@@ -719,10 +719,36 @@ Example 1:
0
0
[torch.DoubleTensor of dimension 2x1]
+```
+
+Exmaple 2:
+```lua
+> input = { A=torch.randn(2, 3), B=torch.randn(2, 1) }
+> =nn.SelectTable("A"):forward(input)
+-0.3060 0.1398 0.2707
+ 0.0576 1.5455 0.0610
+[torch.DoubleTensor of dimension 2x3]
+
+> gradInput = nn.SelectTable("A"):backward(input, torch.randn(2, 3))
+
+> gradInput
+{
+ A : DoubleTensor - size: 2x3
+ B : DoubleTensor - size: 2x1
+}
+
+> gradInput["A"]
+-0.4891 -0.3495 -0.3182
+-2.0999 0.7381 -0.5312
+[torch.DoubleTensor of dimension 2x3]
+> gradInput["B"]
+0
+0
+[torch.DoubleTensor of dimension 2x1]
```
-Example 2:
+Example 3:
```lua
> input = {torch.randn(2, 3), {torch.randn(2, 1), {torch.randn(2, 2)}}}