diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-07-09 10:48:11 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-07-09 10:48:11 +0300 |
commit | e0196922091ee55bf5e939b2863304641aec5ddc (patch) | |
tree | 64cbb67b03ee08b46c32877da598f25665a5d869 | |
parent | b7aa53d96fbb6c0f2eaa1976b28c5cf12edf1ced (diff) | |
parent | 350de823fcd43dcdc5948ef869d9ea135e2467a1 (diff) |
Merge pull request #285 from sergomezcol/master
Add support for negative indices in nn.SplitTable
-rw-r--r-- | SplitTable.lua | 21 | ||||
-rwxr-xr-x | doc/table.md | 47 | ||||
-rw-r--r-- | test.lua | 7 |
3 files changed, 67 insertions, 8 deletions
diff --git a/SplitTable.lua b/SplitTable.lua index bd46b71..a47c580 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -6,13 +6,21 @@ function SplitTable:__init(dimension, nInputDims) self.nInputDims = nInputDims end -function SplitTable:updateOutput(input) +function SplitTable:_getPositiveDimension(input) local dimension = self.dimension - if self.nInputDims and input:dim()==(self.nInputDims+1) then - dimension = dimension + 1 + if dimension < 0 then + dimension = input:dim() + dimension + 1 + elseif self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 end - local currentOutput= {} + return dimension +end + +function SplitTable:updateOutput(input) + local dimension = self:_getPositiveDimension(input) local slices = input:size(dimension) + + local currentOutput= {} for i=1,slices do currentOutput[#currentOutput+1] = input:select(dimension,i) end @@ -21,10 +29,7 @@ function SplitTable:updateOutput(input) end function SplitTable:updateGradInput(input, gradOutput) - local dimension = self.dimension - if self.nInputDims and input:dim()==(self.nInputDims+1) then - dimension = dimension + 1 - end + local dimension = self:_getPositiveDimension(input) local slices = input:size(dimension) self.gradInput:resizeAs(input) diff --git a/doc/table.md b/doc/table.md index b8cb3a9..57c222d 100755 --- a/doc/table.md +++ b/doc/table.md @@ -315,6 +315,53 @@ gives the output: [torch.DoubleTensor of dimension 3] ``` +The module also supports indexing from the end using negative dimensions. This allows to use this module when the number of dimensions of the input is unknown. + +### Example + +```lua +m = nn.SplitTable(-2) +out = m:forward(torch.randn(3, 2)) +for i, k in ipairs(out) do print(i, k) end +out = m:forward(torch.randn(1, 3, 2)) +for i, k in ipairs(out) do print(i, k) end +``` + +gives the output: + +``` +1 + 0.1420 +-0.5698 +[torch.DoubleTensor of size 2] + +2 + 0.1663 + 0.1197 +[torch.DoubleTensor of size 2] + +3 + 0.4198 +-1.1394 +[torch.DoubleTensor of size 2] + + +1 +-2.4941 +-1.4541 +[torch.DoubleTensor of size 1x2] + +2 + 0.4594 + 1.1946 +[torch.DoubleTensor of size 1x2] + +3 +-2.3322 +-0.7383 +[torch.DoubleTensor of size 1x2] +``` + ### A more complicated example ```lua @@ -2929,6 +2929,13 @@ function nntest.SplitTable() module = nn.SplitTable(d, 2) mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d) end + + -- Negative indices + local module = nn.SplitTable(-3) + local input = torch.randn(3,4,5) + mytester:asserteq(#module:forward(input), 3, "negative index") + local input = torch.randn(2,3,4,5) + mytester:asserteq(#module:forward(input), 3, "negative index (minibatch)") end function nntest.SelectTable() |