Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-09 10:48:11 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-09 10:48:11 +0300
commite0196922091ee55bf5e939b2863304641aec5ddc (patch)
tree64cbb67b03ee08b46c32877da598f25665a5d869
parentb7aa53d96fbb6c0f2eaa1976b28c5cf12edf1ced (diff)
parent350de823fcd43dcdc5948ef869d9ea135e2467a1 (diff)
Merge pull request #285 from sergomezcol/master
Add support for negative indices in nn.SplitTable
-rw-r--r--SplitTable.lua21
-rwxr-xr-xdoc/table.md47
-rw-r--r--test.lua7
3 files changed, 67 insertions, 8 deletions
diff --git a/SplitTable.lua b/SplitTable.lua
index bd46b71..a47c580 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -6,13 +6,21 @@ function SplitTable:__init(dimension, nInputDims)
self.nInputDims = nInputDims
end
-function SplitTable:updateOutput(input)
+function SplitTable:_getPositiveDimension(input)
local dimension = self.dimension
- if self.nInputDims and input:dim()==(self.nInputDims+1) then
- dimension = dimension + 1
+ if dimension < 0 then
+ dimension = input:dim() + dimension + 1
+ elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
end
- local currentOutput= {}
+ return dimension
+end
+
+function SplitTable:updateOutput(input)
+ local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)
+
+ local currentOutput= {}
for i=1,slices do
currentOutput[#currentOutput+1] = input:select(dimension,i)
end
@@ -21,10 +29,7 @@ function SplitTable:updateOutput(input)
end
function SplitTable:updateGradInput(input, gradOutput)
- local dimension = self.dimension
- if self.nInputDims and input:dim()==(self.nInputDims+1) then
- dimension = dimension + 1
- end
+ local dimension = self:_getPositiveDimension(input)
local slices = input:size(dimension)
self.gradInput:resizeAs(input)
diff --git a/doc/table.md b/doc/table.md
index b8cb3a9..57c222d 100755
--- a/doc/table.md
+++ b/doc/table.md
@@ -315,6 +315,53 @@ gives the output:
[torch.DoubleTensor of dimension 3]
```
+The module also supports indexing from the end using negative dimensions. This allows to use this module when the number of dimensions of the input is unknown.
+
+### Example
+
+```lua
+m = nn.SplitTable(-2)
+out = m:forward(torch.randn(3, 2))
+for i, k in ipairs(out) do print(i, k) end
+out = m:forward(torch.randn(1, 3, 2))
+for i, k in ipairs(out) do print(i, k) end
+```
+
+gives the output:
+
+```
+1
+ 0.1420
+-0.5698
+[torch.DoubleTensor of size 2]
+
+2
+ 0.1663
+ 0.1197
+[torch.DoubleTensor of size 2]
+
+3
+ 0.4198
+-1.1394
+[torch.DoubleTensor of size 2]
+
+
+1
+-2.4941
+-1.4541
+[torch.DoubleTensor of size 1x2]
+
+2
+ 0.4594
+ 1.1946
+[torch.DoubleTensor of size 1x2]
+
+3
+-2.3322
+-0.7383
+[torch.DoubleTensor of size 1x2]
+```
+
### A more complicated example
```lua
diff --git a/test.lua b/test.lua
index 55818e1..82aefc8 100644
--- a/test.lua
+++ b/test.lua
@@ -2929,6 +2929,13 @@ function nntest.SplitTable()
module = nn.SplitTable(d, 2)
mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
end
+
+ -- Negative indices
+ local module = nn.SplitTable(-3)
+ local input = torch.randn(3,4,5)
+ mytester:asserteq(#module:forward(input), 3, "negative index")
+ local input = torch.randn(2,3,4,5)
+ mytester:asserteq(#module:forward(input), 3, "negative index (minibatch)")
end
function nntest.SelectTable()