Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergio Gomez <sergomez@google.com>2015-06-03 12:50:55 +0300
committerSergio Gomez <sergomez@google.com>2015-06-03 19:09:14 +0300
commit350de823fcd43dcdc5948ef869d9ea135e2467a1 (patch)
treeaf9ae1daf60f23d298d3b9aae55c17469b0eca56 /SplitTable.lua
parentfdd6659d31ab160216701cd51b38cb1320b26fa7 (diff)
Add support for negative indices in nn.SplitTable
This module can now be used when the total number of dimensions is unknown.
Diffstat (limited to 'SplitTable.lua')
-rw-r--r--SplitTable.lua21
1 files changed, 13 insertions, 8 deletions
diff --git a/SplitTable.lua b/SplitTable.lua
index bd46b71..a47c580 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -6,13 +6,21 @@ function SplitTable:__init(dimension, nInputDims)
self.nInputDims = nInputDims
end
-function SplitTable:updateOutput(input)
+function SplitTable:_getPositiveDimension(input)
local dimension = self.dimension
- if self.nInputDims and input:dim()==(self.nInputDims+1) then
- dimension = dimension + 1
+ if dimension < 0 then
+ dimension = input:dim() + dimension + 1
+ elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
end
- local currentOutput= {}
+ return dimension
+end
+
+function SplitTable:updateOutput(input)
+ local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)
+
+ local currentOutput= {}
for i=1,slices do
currentOutput[#currentOutput+1] = input:select(dimension,i)
end
@@ -21,10 +29,7 @@ function SplitTable:updateOutput(input)
end
function SplitTable:updateGradInput(input, gradOutput)
- local dimension = self.dimension
- if self.nInputDims and input:dim()==(self.nInputDims+1) then
- dimension = dimension + 1
- end
+ local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)
self.gradInput:resizeAs(input)