diff options
author | Albert Zhuang <iamalbert@users.noreply.github.com> | 2017-01-01 20:21:44 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-01-01 20:21:44 +0300 |
commit | 8251438690b9b9d90efe8ecef3c4a8cbe3f13653 (patch) | |
tree | 7fe218465c5bc85998344e17826993c9879304d6 | |
parent | bf83c63728f38a2c336a5427ec7f32c79a423804 (diff) |
SelectTable accept string as key (#951)
* SelectTable accept any index types (not just integers)
-rw-r--r-- | SelectTable.lua | 11 | ||||
-rw-r--r-- | doc/table.md | 30 |
2 files changed, 37 insertions, 4 deletions
diff --git a/SelectTable.lua b/SelectTable.lua index 8eba85e..f383a10 100644 --- a/SelectTable.lua +++ b/SelectTable.lua @@ -7,8 +7,12 @@ function SelectTable:__init(index) end function SelectTable:updateOutput(input) + -- handle negative indices - local index = self.index < 0 and #input + self.index + 1 or self.index + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end assert(input[index], "index does not exist in the input table") self.output = input[index] @@ -41,7 +45,10 @@ function SelectTable:updateGradInput(input, gradOutput) -- make gradInput a zeroed copy of input zeroTableCopy(self.gradInput, input) -- handle negative indices - local index = self.index < 0 and #input + self.index + 1 or self.index + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end -- copy into gradInput[index] (necessary for variable sized inputs) assert(self.gradInput[index]) nn.utils.recursiveCopy(self.gradInput[index], gradOutput) diff --git a/doc/table.md b/doc/table.md index ee61719..d5174a7 100644 --- a/doc/table.md +++ b/doc/table.md @@ -692,7 +692,7 @@ Forwarding a batch of 2 examples gives us something like this: `module` = `SelectTable(index)` -Creates a module that takes a `table` as input and outputs the element at index `index` (positive or negative). +Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative). This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the @@ -719,10 +719,36 @@ Example 1: 0 0 [torch.DoubleTensor of dimension 2x1] +``` + +Exmaple 2: +```lua +> input = { A=torch.randn(2, 3), B=torch.randn(2, 1) } +> =nn.SelectTable("A"):forward(input) +-0.3060 0.1398 0.2707 + 0.0576 1.5455 0.0610 +[torch.DoubleTensor of dimension 2x3] + +> gradInput = nn.SelectTable("A"):backward(input, torch.randn(2, 3)) + +> gradInput +{ + A : DoubleTensor - size: 2x3 + B : DoubleTensor - size: 2x1 +} + +> gradInput["A"] +-0.4891 -0.3495 -0.3182 +-2.0999 0.7381 -0.5312 +[torch.DoubleTensor of dimension 2x3] +> gradInput["B"] +0 +0 +[torch.DoubleTensor of dimension 2x1] ``` -Example 2: +Example 3: ```lua > input = {torch.randn(2, 3), {torch.randn(2, 1), {torch.randn(2, 2)}}} |