diff options
author | Sergio Gomez <sergomez@google.com> | 2015-06-03 12:50:55 +0300 |
---|---|---|
committer | Sergio Gomez <sergomez@google.com> | 2015-06-03 19:09:14 +0300 |
commit | 350de823fcd43dcdc5948ef869d9ea135e2467a1 (patch) | |
tree | af9ae1daf60f23d298d3b9aae55c17469b0eca56 /SplitTable.lua | |
parent | fdd6659d31ab160216701cd51b38cb1320b26fa7 (diff) |
Add support for negative indices in nn.SplitTable
This module can now be used when the total number of dimensions is unknown.
Diffstat (limited to 'SplitTable.lua')
-rw-r--r-- | SplitTable.lua | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/SplitTable.lua b/SplitTable.lua index bd46b71..a47c580 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -6,13 +6,21 @@ function SplitTable:__init(dimension, nInputDims) self.nInputDims = nInputDims end -function SplitTable:updateOutput(input) +function SplitTable:_getPositiveDimension(input) local dimension = self.dimension - if self.nInputDims and input:dim()==(self.nInputDims+1) then - dimension = dimension + 1 + if dimension < 0 then + dimension = input:dim() + dimension + 1 + elseif self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 end - local currentOutput= {} + return dimension +end + +function SplitTable:updateOutput(input) + local dimension = self:_getPositiveDimension(input) local slices = input:size(dimension) + + local currentOutput= {} for i=1,slices do currentOutput[#currentOutput+1] = input:select(dimension,i) end @@ -21,10 +29,7 @@ function SplitTable:updateOutput(input) end function SplitTable:updateGradInput(input, gradOutput) - local dimension = self.dimension - if self.nInputDims and input:dim()==(self.nInputDims+1) then - dimension = dimension + 1 - end + local dimension = self:_getPositiveDimension(input) local slices = input:size(dimension) self.gradInput:resizeAs(input) |