github.com/torch/nn.git
author    Sergio Gomez <sergomezcol@gmail.com>  2014-06-25 22:23:17 +0400
committer Sergio Gomez <sergomezcol@gmail.com>  2014-06-25 22:29:28 +0400
commit    d85c2ce7ab24a699855f41b1919a74f81def47cd
tree      27fcdcfcf13c63b06610943e82a8798ef78d260c
parent    ea9cc1df751ddb144c08a13aab3add1ab0ce90a1
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and non-minibatch tensors through the same module. If this method is not used, the behaviour of these modules is the same as before.
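
For reference, a minimal sketch of the new behaviour (assuming a working torch/nn install; sizes are chosen to match the documentation examples in the diff below):

```lua
require 'nn'

-- With setNumInputDims(2), SplitTable treats a 2D tensor as a single
-- sample and a 3D tensor as a minibatch of such samples, so the split
-- happens along the same per-sample dimension in both cases.
local split = nn.SplitTable(1):setNumInputDims(2)
print(#split:forward(torch.randn(4, 3)))      -- 4 slices of size 3
print(#split:forward(torch.randn(2, 4, 3)))   -- 4 slices of size 2x3

-- JoinTable behaves analogously: joining along dimension 2 of each
-- sample means dimension 3 of a minibatch.
local join = nn.JoinTable(2):setNumInputDims(2)
print(join:forward{torch.randn(3, 1), torch.randn(3, 1)}:size())        -- 3x2
print(join:forward{torch.randn(2, 3, 1), torch.randn(2, 3, 1)}:size())  -- 2x3x2
```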
-rw-r--r--  JoinTable.lua  |  34
-rw-r--r--  SplitTable.lua |  29
-rw-r--r--  doc/table.md   | 106
-rw-r--r--  test/test.lua  |  39
4 files changed, 190 insertions(+), 18 deletions(-)
diff --git a/JoinTable.lua b/JoinTable.lua
index dc20246..04e6d31 100644
--- a/JoinTable.lua
+++ b/JoinTable.lua
@@ -5,16 +5,29 @@ function JoinTable:__init(dimension)
self.size = torch.LongStorage()
self.dimension = dimension
self.gradInput = {}
+ self.nInputDims = nil
end
+-- Sets the expected number of dimensions
+-- in a non-minibatch input.
+function JoinTable:setNumInputDims(nInputDims)
+ self.nInputDims = nInputDims
+ return self
+end
+
function JoinTable:updateOutput(input)
+ local dimension = self.dimension
+ if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+
for i=1,#input do
local currentOutput = input[i]
if i == 1 then
self.size:resize(currentOutput:dim()):copy(currentOutput:size())
else
- self.size[self.dimension] = self.size[self.dimension]
- + currentOutput:size(self.dimension)
+ self.size[dimension] = self.size[dimension]
+ + currentOutput:size(dimension)
end
end
self.output:resize(self.size)
@@ -22,15 +35,20 @@ function JoinTable:updateOutput(input)
local offset = 1
for i=1,#input do
local currentOutput = input[i]
- self.output:narrow(self.dimension, offset,
- currentOutput:size(self.dimension)):copy(currentOutput)
- offset = offset + currentOutput:size(self.dimension)
+ self.output:narrow(dimension, offset,
+ currentOutput:size(dimension)):copy(currentOutput)
+ offset = offset + currentOutput:size(dimension)
end
return self.output
end
function JoinTable:updateGradInput(input, gradOutput)
+ local dimension = self.dimension
+ if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+
for i=1,#input do
if self.gradInput[i] == nil then
self.gradInput[i] = input[i].new()
@@ -41,10 +59,10 @@ function JoinTable:updateGradInput(input, gradOutput)
local offset = 1
for i=1,#input do
local currentOutput = input[i]
- local currentGradInput = gradOutput:narrow(self.dimension, offset,
- currentOutput:size(self.dimension))
+ local currentGradInput = gradOutput:narrow(dimension, offset,
+ currentOutput:size(dimension))
self.gradInput[i]:copy(currentGradInput)
- offset = offset + currentOutput:size(self.dimension)
+ offset = offset + currentOutput:size(dimension)
end
return self.gradInput
end
diff --git a/SplitTable.lua b/SplitTable.lua
index d2c690e..b69e9ee 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -2,29 +2,44 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')
function SplitTable:__init(dimension)
parent.__init(self)
- self.modules = {}
+ self.modules = {}
self.dimension = dimension
+ self.nInputDims = nil
+end
+
+-- Sets the expected number of dimensions
+-- in a non-minibatch input.
+function SplitTable:setNumInputDims(nInputDims)
+ self.nInputDims = nInputDims
+ return self
end
function SplitTable:updateOutput(input)
- local currentOutput= {};
- local slices = input:size(self.dimension)
+ local dimension = self.dimension
+ if self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+ local currentOutput= {}
+ local slices = input:size(dimension)
for i=1,slices do
- currentOutput[#currentOutput+1] = input:select(self.dimension,i)
+ currentOutput[#currentOutput+1] = input:select(dimension,i)
end
self.output = currentOutput
return self.output
end
-
function SplitTable:updateGradInput(input, gradOutput)
- local slices = input:size(self.dimension)
+ local dimension = self.dimension
+ if self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+ local slices = input:size(dimension)
self.gradInput:resizeAs(input)
local offset = 1
for i=1,slices do
local currentGradInput = gradOutput[i];
- self.gradInput:select(self.dimension,i):copy(currentGradInput)
+ self.gradInput:select(dimension,i):copy(currentGradInput)
end
return self.gradInput
end
diff --git a/doc/table.md b/doc/table.md
index c55804a..97c2741 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -98,6 +98,10 @@ which gives the output:
Creates a module that takes a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
as input and outputs a table of Tensors, splitting the Tensor along dimension `dimension`.
+The method `setNumInputDims` allows you to specify the number of dimensions that
+this module will receive. This makes it possible to forward both minibatch and
+non-minibatch tensors through the same module.
+
Example 1:
```lua
mlp=nn.SplitTable(2)
@@ -132,7 +136,7 @@ gives the output:
Example 2:
```lua
mlp=nn.SplitTable(1)
-pred=mlp:forward(torch.randn(10,3))
+pred=mlp:forward(torch.randn(4,3))
for i,k in pairs(pred) do print(i,k); end
```
gives the output:
@@ -162,6 +166,63 @@ gives the output:
[torch.Tensor of dimension 3]
```
+Example 3:
+```lua
+mlp=nn.SplitTable(1)
+mlp:setNumInputDims(2)
+pred=mlp:forward(torch.randn(2,4,3))
+for i,k in pairs(pred) do print(i,k); end
+pred=mlp:forward(torch.randn(4,3))
+for i,k in pairs(pred) do print(i,k); end
+```
+gives the output:
+```lua
+1
+-1.3533 0.7448 -0.8818
+-0.4521 -1.2463 0.0316
+[torch.DoubleTensor of dimension 2x3]
+
+2
+ 0.1130 -1.3904 1.4620
+ 0.6722 2.0910 -0.2466
+[torch.DoubleTensor of dimension 2x3]
+
+3
+ 0.4672 -1.2738 1.1559
+ 0.4664 0.0768 0.6243
+[torch.DoubleTensor of dimension 2x3]
+
+4
+ 0.4194 1.2991 0.2241
+ 2.9786 -0.6715 0.0393
+[torch.DoubleTensor of dimension 2x3]
+
+
+1
+-1.8932
+ 0.0516
+-0.6316
+[torch.DoubleTensor of dimension 3]
+
+2
+-0.3397
+-1.8881
+-0.0977
+[torch.DoubleTensor of dimension 3]
+
+3
+ 0.0135
+ 1.2089
+ 0.5785
+[torch.DoubleTensor of dimension 3]
+
+4
+-0.1758
+-0.0776
+-1.1013
+[torch.DoubleTensor of dimension 3]
+```
+
A more complicated example:
```lua
@@ -205,7 +266,11 @@ Creates a module that takes a list of Tensors as input and outputs a
[Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
by joining them together along dimension `dimension`.
-Example:
+The method `setNumInputDims` allows you to specify the number of dimensions that
+this module will receive. This makes it possible to forward both minibatch and
+non-minibatch tensors through the same module.
+
+Example 1:
```lua
x=torch.randn(5,1)
y=torch.randn(5,1)
@@ -227,12 +292,14 @@ gives the output:
0.6580
0.1784
-1.7362
-
+[torch.DoubleTensor of dimension 10x1]
+
1.3965 0.1575
0.5146 0.4491
-1.5244 0.6580
-0.9540 0.1784
0.4256 -1.7362
+[torch.DoubleTensor of dimension 5x2]
1.3965
0.5146
@@ -244,6 +311,39 @@ gives the output:
[torch.Tensor of dimension 7x1]
```
+Example 2:
+```lua
+module = nn.JoinTable(2)
+module:setNumInputDims(2)
+
+x=torch.randn(3,1)
+y=torch.randn(3,1)
+
+mx=torch.randn(2,3,1)
+my=torch.randn(2,3,1)
+
+print(module:forward{x,y})
+print(module:forward{mx,my})
+```
+gives the output:
+```lua
+ 0.4288 1.2002
+-1.4084 -0.7960
+-0.2091 0.1852
+[torch.DoubleTensor of dimension 3x2]
+
+(1,.,.) =
+ 0.5561 0.1228
+ -0.6792 0.1153
+ 0.0687 0.2955
+
+(2,.,.) =
+ 2.5787 1.8185
+ -0.9860 0.6756
+ 0.1989 -0.4327
+[torch.DoubleTensor of dimension 2x3x2]
+```
+
A more complicated example:
```lua
diff --git a/test/test.lua b/test/test.lua
index be17fd7..2db6f2d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1790,6 +1790,45 @@ function nntest.LookupTable()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.JoinTable()
+ local tensor = torch.rand(3,4,5)
+ local input = {tensor, tensor}
+ local module
+ for d = 1,tensor:dim() do
+ module = nn.JoinTable(d)
+ mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d)
+ end
+
+ -- Minibatch
+ local tensor = torch.rand(3,4,5)
+ local input = {tensor, tensor}
+ local module
+ for d = 1,tensor:dim()-1 do
+ module = nn.JoinTable(d)
+ module:setNumInputDims(2)
+ mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d)
+ end
+end
+
+function nntest.SplitTable()
+ local input = torch.randn(3,4,5)
+ local module
+ for d = 1,input:dim() do
+ module = nn.SplitTable(d)
+ mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d)
+ end
+
+ -- Minibatch
+ local input = torch.randn(3,4,5)
+ local module
+ for d = 1,input:dim()-1 do
+ module = nn.SplitTable(d)
+ module:setNumInputDims(2)
+ mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
+ end
+end
+
+
mytester:add(nntest)
if not nn then