diff options
author | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:23:17 +0400 |
---|---|---|
committer | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:29:28 +0400 |
commit | d85c2ce7ab24a699855f41b1919a74f81def47cd (patch) | |
tree | 27fcdcfcf13c63b06610943e82a8798ef78d260c | |
parent | ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff) |
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and
non-minibatch tensors through the same module.
If this method is not used, the behaviour of these modules is the
same as before.
-rw-r--r-- | JoinTable.lua | 34 | ||||
-rw-r--r-- | SplitTable.lua | 29 | ||||
-rw-r--r-- | doc/table.md | 106 | ||||
-rw-r--r-- | test/test.lua | 39 |
4 files changed, 190 insertions, 18 deletions
diff --git a/JoinTable.lua b/JoinTable.lua index dc20246..04e6d31 100644 --- a/JoinTable.lua +++ b/JoinTable.lua @@ -5,16 +5,29 @@ function JoinTable:__init(dimension) self.size = torch.LongStorage() self.dimension = dimension self.gradInput = {} + self.nInputDims = nil end +-- Sets the expected number of dimensions +-- in a non-minibatch input. +function JoinTable:setNumInputDims(nInputDims) + self.nInputDims = nInputDims + return self +end + function JoinTable:updateOutput(input) + local dimension = self.dimension + if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + for i=1,#input do local currentOutput = input[i] if i == 1 then self.size:resize(currentOutput:dim()):copy(currentOutput:size()) else - self.size[self.dimension] = self.size[self.dimension] - + currentOutput:size(self.dimension) + self.size[dimension] = self.size[dimension] + + currentOutput:size(dimension) end end self.output:resize(self.size) @@ -22,15 +35,20 @@ function JoinTable:updateOutput(input) local offset = 1 for i=1,#input do local currentOutput = input[i] - self.output:narrow(self.dimension, offset, - currentOutput:size(self.dimension)):copy(currentOutput) - offset = offset + currentOutput:size(self.dimension) + self.output:narrow(dimension, offset, + currentOutput:size(dimension)):copy(currentOutput) + offset = offset + currentOutput:size(dimension) end return self.output end function JoinTable:updateGradInput(input, gradOutput) + local dimension = self.dimension + if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + for i=1,#input do if self.gradInput[i] == nil then self.gradInput[i] = input[i].new() @@ -41,10 +59,10 @@ function JoinTable:updateGradInput(input, gradOutput) local offset = 1 for i=1,#input do local currentOutput = input[i] - local currentGradInput = gradOutput:narrow(self.dimension, offset, - currentOutput:size(self.dimension)) + local currentGradInput = gradOutput:narrow(dimension, offset, + currentOutput:size(dimension)) self.gradInput[i]:copy(currentGradInput) - offset = offset + currentOutput:size(self.dimension) + offset = offset + currentOutput:size(dimension) end return self.gradInput end diff --git a/SplitTable.lua b/SplitTable.lua index d2c690e..b69e9ee 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -2,29 +2,44 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module') function SplitTable:__init(dimension) parent.__init(self) - self.modules = {} + self.modules = {} self.dimension = dimension + self.nInputDims = nil +end + +-- Sets the expected number of dimensions +-- in a non-minibatch input. +function SplitTable:setNumInputDims(nInputDims) + self.nInputDims = nInputDims + return self end function SplitTable:updateOutput(input) - local currentOutput= {}; - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local currentOutput= {} + local slices = input:size(dimension) for i=1,slices do - currentOutput[#currentOutput+1] = input:select(self.dimension,i) + currentOutput[#currentOutput+1] = input:select(dimension,i) end self.output = currentOutput return self.output end - function SplitTable:updateGradInput(input, gradOutput) - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local slices = input:size(dimension) self.gradInput:resizeAs(input) local offset = 1 for i=1,slices do local currentGradInput = gradOutput[i]; - self.gradInput:select(self.dimension,i):copy(currentGradInput) + self.gradInput:select(dimension,i):copy(currentGradInput) end return self.gradInput end diff --git a/doc/table.md b/doc/table.md index c55804a..97c2741 100644 --- a/doc/table.md +++ b/doc/table.md @@ -98,6 +98,10 @@ which gives the output: Creates a module that takes a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor) as input and outputs several tables, splitting the Tensor along dimension `dimension`. +The method `setNumInputDims` allows to specify the number of dimensions that +this module will receive. This makes it possible to forward both minibatch and +non-minibatch tensors through the same module. + Example 1: ```lua mlp=nn.SplitTable(2) @@ -132,7 +136,7 @@ gives the output: Example 2: ```lua mlp=nn.SplitTable(1) -pred=mlp:forward(torch.randn(10,3)) +pred=mlp:forward(torch.randn(4,3)) for i,k in pairs(pred) do print(i,k); end ``` gives the output: @@ -162,6 +166,63 @@ gives the output: [torch.Tensor of dimension 3] ``` +Example 3: +```lua +mlp=nn.SplitTable(1) +mlp:setNumInputDims(2) +pred=mlp:forward(torch.randn(2,4,3)) +for i,k in pairs(pred) do print(i,k); end +pred=mlp:forward(torch.randn(4,3)) +for i,k in pairs(pred) do print(i,k); end +``` +gives the output: +```lua +1 +-1.3533 0.7448 -0.8818 +-0.4521 -1.2463 0.0316 +[torch.DoubleTensor of dimension 2x3] + +2 + 0.1130 -1.3904 1.4620 + 0.6722 2.0910 -0.2466 +[torch.DoubleTensor of dimension 2x3] + +3 + 0.4672 -1.2738 1.1559 + 0.4664 0.0768 0.6243 +[torch.DoubleTensor of dimension 2x3] + +4 + 0.4194 1.2991 0.2241 + 2.9786 -0.6715 0.0393 +[torch.DoubleTensor of dimension 2x3] + + +1 +-1.8932 + 0.0516 +-0.6316 +[torch.DoubleTensor of dimension 3] + +2 +-0.3397 +-1.8881 +-0.0977 +[torch.DoubleTensor of dimension 3] + +3 + 0.0135 + 1.2089 + 0.5785 +[torch.DoubleTensor of dimension 3] + +4 +-0.1758 +-0.0776 +-1.1013 +[torch.DoubleTensor of dimension 3] +``` + A more complicated example: ```lua @@ -205,7 +266,11 @@ Creates a module that takes a list of Tensors as input and outputs a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor) by joining them together along dimension `dimension`. -Example: +The method `setNumInputDims` allows to specify the number of dimensions that +this module will receive. This makes it possible to forward both minibatch and +non-minibatch tensors through the same module. + +Example 1: ```lua x=torch.randn(5,1) y=torch.randn(5,1) @@ -227,12 +292,14 @@ gives the output: 0.6580 0.1784 -1.7362 - +[torch.DoubleTensor of dimension 10x1] + 1.3965 0.1575 0.5146 0.4491 -1.5244 0.6580 -0.9540 0.1784 0.4256 -1.7362 +[torch.DoubleTensor of dimension 5x2] 1.3965 0.5146 @@ -244,6 +311,39 @@ gives the output: [torch.Tensor of dimension 7x1] ``` +Example 2: +```lua +module = nn.JoinTable(2) +module:setNumInputDims(2) + +x=torch.randn(3,1) +y=torch.randn(3,1) + +mx=torch.randn(2,3,1) +my=torch.randn(2,3,1) + +print(module:forward{x,y}) +print(module:forward{mx,my}) +``` +gives the output: +```lua + 0.4288 1.2002 +-1.4084 -0.7960 +-0.2091 0.1852 +[torch.DoubleTensor of dimension 3x2] + +(1,.,.) = + 0.5561 0.1228 + -0.6792 0.1153 + 0.0687 0.2955 + +(2,.,.) = + 2.5787 1.8185 + -0.9860 0.6756 + 0.1989 -0.4327 +[torch.DoubleTensor of dimension 2x3x2] +``` + A more complicated example: ```lua diff --git a/test/test.lua b/test/test.lua index be17fd7..2db6f2d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1790,6 +1790,45 @@ function nntest.LookupTable() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.JoinTable() + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim() do + module = nn.JoinTable(d) + mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d) + end + + -- Minibatch + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim()-1 do + module = nn.JoinTable(d) + module:setNumInputDims(2) + mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d) + end +end + +function nntest.SplitTable() + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim() do + module = nn.SplitTable(d) + mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d) + end + + -- Minibatch + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim()-1 do + module = nn.SplitTable(d) + module:setNumInputDims(2) + mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d) + end +end + + mytester:add(nntest) if not nn then |