Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoost van Doorn <joost.van.doorn@gmail.com>2016-04-30 18:43:25 +0300
committerJoost van Doorn <joost.van.doorn@gmail.com>2016-04-30 18:43:25 +0300
commitf3b88eff27f523a6b01f3d0a159aaa1da6b49e75 (patch)
tree8f1e383d49b77a264f05066dd7e66c063db7d623 /Select.lua
parent5ed07c93985e7a0923892e9448abf1c385677f9f (diff)
nn.Select accepts negative indices
Diffstat (limited to 'Select.lua')
-rw-r--r--Select.lua6
1 files changed, 4 insertions, 2 deletions
diff --git a/Select.lua b/Select.lua
index acf8e06..fccdf32 100644
--- a/Select.lua
+++ b/Select.lua
@@ -7,14 +7,16 @@ function Select:__init(dimension,index)
end
function Select:updateOutput(input)
- local output = input:select(self.dimension,self.index);
+ local index = self.index < 0 and input:size(self.dimension) + self.index + 1 or self.index
+ local output = input:select(self.dimension, index);
self.output:resizeAs(output)
return self.output:copy(output)
end
function Select:updateGradInput(input, gradOutput)
+ local index = self.index < 0 and input:size(self.dimension) + self.index + 1 or self.index
self.gradInput:resizeAs(input)
self.gradInput:zero()
- self.gradInput:select(self.dimension,self.index):copy(gradOutput)
+ self.gradInput:select(self.dimension,index):copy(gradOutput)
return self.gradInput
end