Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Zhuang <iamalbert@users.noreply.github.com>2017-01-01 20:21:44 +0300
committerSoumith Chintala <soumith@gmail.com>2017-01-01 20:21:44 +0300
commit8251438690b9b9d90efe8ecef3c4a8cbe3f13653 (patch)
tree7fe218465c5bc85998344e17826993c9879304d6 /SelectTable.lua
parentbf83c63728f38a2c336a5427ec7f32c79a423804 (diff)
SelectTable accept string as key (#951)
* SelectTable accept any index types (not just integers)
Diffstat (limited to 'SelectTable.lua')
-rw-r--r--SelectTable.lua11
1 files changed, 9 insertions, 2 deletions
diff --git a/SelectTable.lua b/SelectTable.lua
index 8eba85e..f383a10 100644
--- a/SelectTable.lua
+++ b/SelectTable.lua
@@ -7,8 +7,12 @@ function SelectTable:__init(index)
end
function SelectTable:updateOutput(input)
+
-- handle negative indices
- local index = self.index < 0 and #input + self.index + 1 or self.index
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
assert(input[index], "index does not exist in the input table")
self.output = input[index]
@@ -41,7 +45,10 @@ function SelectTable:updateGradInput(input, gradOutput)
-- make gradInput a zeroed copy of input
zeroTableCopy(self.gradInput, input)
-- handle negative indices
- local index = self.index < 0 and #input + self.index + 1 or self.index
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
-- copy into gradInput[index] (necessary for variable sized inputs)
assert(self.gradInput[index])
nn.utils.recursiveCopy(self.gradInput[index], gradOutput)