diff options
author | Albert Zhuang <iamalbert@users.noreply.github.com> | 2017-01-01 20:21:44 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-01-01 20:21:44 +0300 |
commit | 8251438690b9b9d90efe8ecef3c4a8cbe3f13653 (patch) | |
tree | 7fe218465c5bc85998344e17826993c9879304d6 /SelectTable.lua | |
parent | bf83c63728f38a2c336a5427ec7f32c79a423804 (diff) |
SelectTable accept string as key (#951)
* SelectTable accept any index types (not just integers)
Diffstat (limited to 'SelectTable.lua')
-rw-r--r-- | SelectTable.lua | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/SelectTable.lua b/SelectTable.lua index 8eba85e..f383a10 100644 --- a/SelectTable.lua +++ b/SelectTable.lua @@ -7,8 +7,12 @@ function SelectTable:__init(index) end function SelectTable:updateOutput(input) + -- handle negative indices - local index = self.index < 0 and #input + self.index + 1 or self.index + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end assert(input[index], "index does not exist in the input table") self.output = input[index] @@ -41,7 +45,10 @@ function SelectTable:updateGradInput(input, gradOutput) -- make gradInput a zeroed copy of input zeroTableCopy(self.gradInput, input) -- handle negative indices - local index = self.index < 0 and #input + self.index + 1 or self.index + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end -- copy into gradInput[index] (necessary for variable sized inputs) assert(self.gradInput[index]) nn.utils.recursiveCopy(self.gradInput[index], gradOutput) |