github.com/torch/nn.git
-rw-r--r--  JoinTable.lua  |  32
-rw-r--r--  SplitTable.lua |  24
-rw-r--r--  doc/table.md   | 108
-rw-r--r--  test/test.lua  |  37
4 files changed, 177 insertions(+), 24 deletions(-)
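The patch adds an optional `nInputDims` argument to `nn.JoinTable` and `nn.SplitTable`, so that a single module instance can handle both minibatch and non-minibatch input. A minimal sketch of the resulting behaviour, inferred from the patch and the documentation changes below:

```lua
require 'nn'

-- nInputDims = 2: a non-batch input to this module is a 2D tensor.
local split = nn.SplitTable(1, 2)

-- Non-batch: a 4x3 tensor splits along dim 1 into 4 tensors of size 3.
print(#split:forward(torch.randn(4, 3)))     -- 4

-- Minibatch: a 2x4x3 tensor has nInputDims+1 dimensions, so the split
-- axis shifts to dim 2, yielding 4 tensors of size 2x3.
print(#split:forward(torch.randn(2, 4, 3)))  -- 4
```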
diff --git a/JoinTable.lua b/JoinTable.lua
index dc20246..d445bd2 100644
--- a/JoinTable.lua
+++ b/JoinTable.lua
@@ -1,20 +1,26 @@
 local JoinTable, parent = torch.class('nn.JoinTable', 'nn.Module')
 
-function JoinTable:__init(dimension)
+function JoinTable:__init(dimension, nInputDims)
    parent.__init(self)
    self.size = torch.LongStorage()
    self.dimension = dimension
    self.gradInput = {}
-end
+   self.nInputDims = nInputDims
+end
 
 function JoinTable:updateOutput(input)
+   local dimension = self.dimension
+   if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+
    for i=1,#input do
       local currentOutput = input[i]
       if i == 1 then
          self.size:resize(currentOutput:dim()):copy(currentOutput:size())
       else
-         self.size[self.dimension] = self.size[self.dimension]
-            + currentOutput:size(self.dimension)
+         self.size[dimension] = self.size[dimension]
+            + currentOutput:size(dimension)
       end
    end
    self.output:resize(self.size)
@@ -22,15 +28,19 @@ function JoinTable:updateOutput(input)
    local offset = 1
    for i=1,#input do
       local currentOutput = input[i]
-      self.output:narrow(self.dimension, offset,
-         currentOutput:size(self.dimension)):copy(currentOutput)
-      offset = offset + currentOutput:size(self.dimension)
+      self.output:narrow(dimension, offset,
+         currentOutput:size(dimension)):copy(currentOutput)
+      offset = offset + currentOutput:size(dimension)
    end
    return self.output
-
 end
 
 function JoinTable:updateGradInput(input, gradOutput)
+   local dimension = self.dimension
+   if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+
    for i=1,#input do
       if self.gradInput[i] == nil then
          self.gradInput[i] = input[i].new()
@@ -41,10 +51,10 @@ function JoinTable:updateGradInput(input, gradOutput)
    local offset = 1
    for i=1,#input do
       local currentOutput = input[i]
-      local currentGradInput = gradOutput:narrow(self.dimension, offset,
-         currentOutput:size(self.dimension))
+      local currentGradInput = gradOutput:narrow(dimension, offset,
+         currentOutput:size(dimension))
       self.gradInput[i]:copy(currentGradInput)
-      offset = offset + currentOutput:size(self.dimension)
+      offset = offset + currentOutput:size(dimension)
    end
    return self.gradInput
 end
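The only moving part in the patch above is the axis shift: when the first input tensor carries one more dimension than `nInputDims`, the join axis moves right by one to skip the leading batch dimension. A standalone sketch of that check (`effectiveDim` is a hypothetical helper, not part of the patch):

```lua
require 'torch'

-- Hypothetical rewrite of the axis-shift check used in JoinTable.
local function effectiveDim(dimension, nInputDims, first)
   if nInputDims and first:dim() == (nInputDims + 1) then
      return dimension + 1  -- leading batch dimension detected
   end
   return dimension
end

-- As nn.JoinTable(1, 1) would see it: join 1D vectors along dim 1,
-- or a batch of vectors along dim 2.
print(effectiveDim(1, 1, torch.randn(5)))     -- 1
print(effectiveDim(1, 1, torch.randn(2, 5)))  -- 2
```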
diff --git a/SplitTable.lua b/SplitTable.lua
index d2c690e..0148c4e 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -1,30 +1,38 @@
 local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')
 
-function SplitTable:__init(dimension)
+function SplitTable:__init(dimension, nInputDims)
    parent.__init(self)
-   self.modules = {}
+   self.modules = {}
    self.dimension = dimension
+   self.nInputDims = nInputDims
 end
 
 function SplitTable:updateOutput(input)
-   local currentOutput= {};
-   local slices = input:size(self.dimension)
+   local dimension = self.dimension
+   if self.nInputDims and input:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+   local currentOutput= {}
+   local slices = input:size(dimension)
    for i=1,slices do
-      currentOutput[#currentOutput+1] = input:select(self.dimension,i)
+      currentOutput[#currentOutput+1] = input:select(dimension,i)
    end
    self.output = currentOutput
    return self.output
 end
-
 
 function SplitTable:updateGradInput(input, gradOutput)
-   local slices = input:size(self.dimension)
+   local dimension = self.dimension
+   if self.nInputDims and input:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+   local slices = input:size(dimension)
    self.gradInput:resizeAs(input)
    local offset = 1
    for i=1,slices do
       local currentGradInput = gradOutput[i];
-      self.gradInput:select(self.dimension,i):copy(currentGradInput)
+      self.gradInput:select(dimension,i):copy(currentGradInput)
    end
    return self.gradInput
 end
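`updateGradInput` recomputes the same shifted axis, so each `gradOutput[i]` is copied back into the slice it came from. A quick forward/backward round trip under the patched behaviour (a shape-check sketch, not a test from this commit):

```lua
require 'nn'

local split = nn.SplitTable(1, 2)   -- split dim 1 of 2D samples
local input = torch.randn(2, 4, 3)  -- minibatch of two 4x3 samples

local out = split:forward(input)    -- table of 4 tensors, each 2x3

-- Feed back an all-ones gradient for every slice.
local gradOut = {}
for i = 1, #out do gradOut[i] = torch.ones(out[i]:size()) end

local gradIn = split:backward(input, gradOut)
print(gradIn:size())                -- 2x4x3, matching the input
```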
diff --git a/doc/table.md b/doc/table.md
index c55804a..4117117 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -93,11 +93,15 @@ which gives the output:
 <a name="nn.SplitTable"/>
 ## SplitTable ##
 
-`module` = `SplitTable(dimension)`
+`module` = `SplitTable(dimension, nInputDims)`
 
 Creates a module that takes a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
 as input and outputs several tables, splitting the Tensor along dimension `dimension`.
 
+The optional parameter `nInputDims` allows you to specify the number of dimensions that
+this module will receive. This makes it possible to forward both minibatch and
+non-minibatch tensors through the same module.
+
 Example 1:
 ```lua
 mlp=nn.SplitTable(2)
@@ -132,7 +136,7 @@ gives the output:
 Example 2:
 ```lua
 mlp=nn.SplitTable(1)
-pred=mlp:forward(torch.randn(10,3))
+pred=mlp:forward(torch.randn(4,3))
 for i,k in pairs(pred) do print(i,k); end
 ```
 gives the output:
@@ -162,6 +166,62 @@ gives the output:
 [torch.Tensor of dimension 3]
 ```
 
+Example 3:
+```lua
+mlp=nn.SplitTable(1,2)
+pred=mlp:forward(torch.randn(2,4,3))
+for i,k in pairs(pred) do print(i,k); end
+pred=mlp:forward(torch.randn(4,3))
+for i,k in pairs(pred) do print(i,k); end
+```
+gives the output:
+```lua
+1
+-1.3533  0.7448 -0.8818
+-0.4521 -1.2463  0.0316
+[torch.DoubleTensor of dimension 2x3]
+
+2
+ 0.1130 -1.3904  1.4620
+ 0.6722  2.0910 -0.2466
+[torch.DoubleTensor of dimension 2x3]
+
+3
+ 0.4672 -1.2738  1.1559
+ 0.4664  0.0768  0.6243
+[torch.DoubleTensor of dimension 2x3]
+
+4
+ 0.4194  1.2991  0.2241
+ 2.9786 -0.6715  0.0393
+[torch.DoubleTensor of dimension 2x3]
+
+
+1
+-1.8932
+ 0.0516
+-0.6316
+[torch.DoubleTensor of dimension 3]
+
+2
+-0.3397
+-1.8881
+-0.0977
+[torch.DoubleTensor of dimension 3]
+
+3
+ 0.0135
+ 1.2089
+ 0.5785
+[torch.DoubleTensor of dimension 3]
+
+4
+-0.1758
+-0.0776
+-1.1013
+[torch.DoubleTensor of dimension 3]
+```
+
 A more complicated example:
 ```lua
@@ -199,13 +259,17 @@ end
 <a name="nn.JoinTable"/>
 ## JoinTable ##
 
-`module` = `JoinTable(dimension)`
+`module` = `JoinTable(dimension, nInputDims)`
 
 Creates a module that takes a list of Tensors as input and outputs a
 [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
 by joining them together along dimension `dimension`.
 
-Example:
+The optional parameter `nInputDims` allows you to specify the number of dimensions that
+this module will receive. This makes it possible to forward both minibatch and
+non-minibatch tensors through the same module.
+
+Example 1:
 ```lua
 x=torch.randn(5,1)
 y=torch.randn(5,1)
@@ -227,12 +291,14 @@ gives the output:
  0.6580
  0.1784
 -1.7362
-
+[torch.DoubleTensor of dimension 10x1]
+
  1.3965  0.1575
  0.5146  0.4491
 -1.5244  0.6580
 -0.9540  0.1784
  0.4256 -1.7362
+[torch.DoubleTensor of dimension 5x2]
 
  1.3965
  0.5146
@@ -244,6 +310,38 @@ gives the output:
 [torch.Tensor of dimension 7x1]
 ```
 
+Example 2:
+```lua
+module = nn.JoinTable(2,2)
+
+x=torch.randn(3,1)
+y=torch.randn(3,1)
+
+mx=torch.randn(2,3,1)
+my=torch.randn(2,3,1)
+
+print(module:forward{x,y})
+print(module:forward{mx,my})
+```
+gives the output:
+```lua
+ 0.4288  1.2002
+-1.4084 -0.7960
+-0.2091  0.1852
+[torch.DoubleTensor of dimension 3x2]
+
+(1,.,.) =
+  0.5561  0.1228
+ -0.6792  0.1153
+  0.0687  0.2955
+
+(2,.,.) =
+  2.5787  1.8185
+ -0.9860  0.6756
+  0.1989 -0.4327
+[torch.DoubleTensor of dimension 2x3x2]
+```
+
 A more complicated example:
 ```lua
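Since both modules interpret `nInputDims` the same way, a split-and-rejoin pipeline written once runs unchanged on single samples and on batches. A hedged sketch combining the two patched modules (the pipeline itself is illustrative, not taken from the docs):

```lua
require 'nn'

-- Split a 2D sample into its rows, then join the rows end to end.
local pipeline = nn.Sequential()
pipeline:add(nn.SplitTable(1, 2))  -- rows of a 2D sample
pipeline:add(nn.JoinTable(1, 1))   -- each row arrives as a 1D tensor

print(pipeline:forward(torch.randn(4, 3)):size())     -- 12
print(pipeline:forward(torch.randn(2, 4, 3)):size())  -- 2x12
```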
diff --git a/test/test.lua b/test/test.lua
index 775dded..3c4ec12 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1830,6 +1830,43 @@ function nntest.MulConstant()
    mytester:assertlt(err, precision, 'bprop error ')
 end
 
+function nntest.JoinTable()
+   local tensor = torch.rand(3,4,5)
+   local input = {tensor, tensor}
+   local module
+   for d = 1,tensor:dim() do
+      module = nn.JoinTable(d)
+      mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d)
+   end
+
+   -- Minibatch
+   local tensor = torch.rand(3,4,5)
+   local input = {tensor, tensor}
+   local module
+   for d = 1,tensor:dim()-1 do
+      module = nn.JoinTable(d, 2)
+      mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d)
+   end
+end
+
+function nntest.SplitTable()
+   local input = torch.randn(3,4,5)
+   local module
+   for d = 1,input:dim() do
+      module = nn.SplitTable(d)
+      mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d)
+   end
+
+   -- Minibatch
+   local input = torch.randn(3,4,5)
+   local module
+   for d = 1,input:dim()-1 do
+      module = nn.SplitTable(d, 2)
+      mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
+   end
+end
+
+
 mytester:add(nntest)
 
 if not nn then