diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-26 17:19:53 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-26 17:19:53 +0400 |
commit | 896ad1c1bf5588b2944c79fb24a0aee1ae7db726 (patch) | |
tree | 2d23facf5bcc965dda0b6b02c31d5e7f0d7d6bfe | |
parent | 4a368c6616da3698fd419fc5c6a2adffcffe5ae0 (diff) | |
parent | 8fd02c336bc2d03aa1fe21fced69615fd1e7b99b (diff) |
Merge pull request #17 from sergomezcol/master
Add minibatch support for nn.JoinTable and nn.SplitTable
-rw-r--r-- | JoinTable.lua | 32 | ||||
-rw-r--r-- | SplitTable.lua | 24 | ||||
-rw-r--r-- | doc/table.md | 108 | ||||
-rw-r--r-- | test/test.lua | 37 |
4 files changed, 177 insertions, 24 deletions
diff --git a/JoinTable.lua b/JoinTable.lua index dc20246..d445bd2 100644 --- a/JoinTable.lua +++ b/JoinTable.lua @@ -1,20 +1,26 @@ local JoinTable, parent = torch.class('nn.JoinTable', 'nn.Module') -function JoinTable:__init(dimension) +function JoinTable:__init(dimension, nInputDims) parent.__init(self) self.size = torch.LongStorage() self.dimension = dimension self.gradInput = {} -end + self.nInputDims = nInputDims +end function JoinTable:updateOutput(input) + local dimension = self.dimension + if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + for i=1,#input do local currentOutput = input[i] if i == 1 then self.size:resize(currentOutput:dim()):copy(currentOutput:size()) else - self.size[self.dimension] = self.size[self.dimension] - + currentOutput:size(self.dimension) + self.size[dimension] = self.size[dimension] + + currentOutput:size(dimension) end end self.output:resize(self.size) @@ -22,15 +28,19 @@ function JoinTable:updateOutput(input) local offset = 1 for i=1,#input do local currentOutput = input[i] - self.output:narrow(self.dimension, offset, - currentOutput:size(self.dimension)):copy(currentOutput) - offset = offset + currentOutput:size(self.dimension) + self.output:narrow(dimension, offset, + currentOutput:size(dimension)):copy(currentOutput) + offset = offset + currentOutput:size(dimension) end return self.output - end function JoinTable:updateGradInput(input, gradOutput) + local dimension = self.dimension + if self.nInputDims and input[1]:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + for i=1,#input do if self.gradInput[i] == nil then self.gradInput[i] = input[i].new() @@ -41,10 +51,10 @@ function JoinTable:updateGradInput(input, gradOutput) local offset = 1 for i=1,#input do local currentOutput = input[i] - local currentGradInput = gradOutput:narrow(self.dimension, offset, - currentOutput:size(self.dimension)) + local currentGradInput = gradOutput:narrow(dimension, offset, + currentOutput:size(dimension)) self.gradInput[i]:copy(currentGradInput) - offset = offset + currentOutput:size(self.dimension) + offset = offset + currentOutput:size(dimension) end return self.gradInput end diff --git a/SplitTable.lua b/SplitTable.lua index d2c690e..0148c4e 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -1,30 +1,38 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module') -function SplitTable:__init(dimension) +function SplitTable:__init(dimension, nInputDims) parent.__init(self) - self.modules = {} + self.modules = {} self.dimension = dimension + self.nInputDims = nInputDims end function SplitTable:updateOutput(input) - local currentOutput= {}; - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local currentOutput= {} + local slices = input:size(dimension) for i=1,slices do - currentOutput[#currentOutput+1] = input:select(self.dimension,i) + currentOutput[#currentOutput+1] = input:select(dimension,i) end self.output = currentOutput return self.output end - function SplitTable:updateGradInput(input, gradOutput) - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local slices = input:size(dimension) self.gradInput:resizeAs(input) local offset = 1 for i=1,slices do local currentGradInput = gradOutput[i]; - self.gradInput:select(self.dimension,i):copy(currentGradInput) + self.gradInput:select(dimension,i):copy(currentGradInput) end return self.gradInput end diff --git a/doc/table.md b/doc/table.md index c55804a..4117117 100644 --- a/doc/table.md +++ b/doc/table.md @@ -93,11 +93,15 @@ which gives the output: <a name="nn.SplitTable"/> ## SplitTable ## -`module` = `SplitTable(dimension)` +`module` = `SplitTable(dimension, nInputDims)` Creates a module that takes a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor) as input and outputs several tables, splitting the Tensor along dimension `dimension`. +The optional parameter `nInputDims` allows to specify the number of dimensions that +this module will receive. This makes it possible to forward both minibatch and +non-minibatch tensors through the same module. + Example 1: ```lua mlp=nn.SplitTable(2) @@ -132,7 +136,7 @@ gives the output: Example 2: ```lua mlp=nn.SplitTable(1) -pred=mlp:forward(torch.randn(10,3)) +pred=mlp:forward(torch.randn(4,3)) for i,k in pairs(pred) do print(i,k); end ``` gives the output: @@ -162,6 +166,62 @@ gives the output: [torch.Tensor of dimension 3] ``` +Example 3: +```lua +mlp=nn.SplitTable(1,2) +pred=mlp:forward(torch.randn(2,4,3)) +for i,k in pairs(pred) do print(i,k); end +pred=mlp:forward(torch.randn(4,3)) +for i,k in pairs(pred) do print(i,k); end +``` +gives the output: +```lua +1 +-1.3533 0.7448 -0.8818 +-0.4521 -1.2463 0.0316 +[torch.DoubleTensor of dimension 2x3] + +2 + 0.1130 -1.3904 1.4620 + 0.6722 2.0910 -0.2466 +[torch.DoubleTensor of dimension 2x3] + +3 + 0.4672 -1.2738 1.1559 + 0.4664 0.0768 0.6243 +[torch.DoubleTensor of dimension 2x3] + +4 + 0.4194 1.2991 0.2241 + 2.9786 -0.6715 0.0393 +[torch.DoubleTensor of dimension 2x3] + + +1 +-1.8932 + 0.0516 +-0.6316 +[torch.DoubleTensor of dimension 3] + +2 +-0.3397 +-1.8881 +-0.0977 +[torch.DoubleTensor of dimension 3] + +3 + 0.0135 + 1.2089 + 0.5785 +[torch.DoubleTensor of dimension 3] + +4 +-0.1758 +-0.0776 +-1.1013 +[torch.DoubleTensor of dimension 3] +``` + A more complicated example: ```lua @@ -199,13 +259,17 @@ end <a name="nn.JoinTable"/> ## JoinTable ## -`module` = `JoinTable(dimension)` +`module` = `JoinTable(dimension, nInputDims)` Creates a module that takes a list of Tensors as input and outputs a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor) by joining them together along dimension `dimension`. -Example: +The optional parameter `nInputDims` allows to specify the number of dimensions that +this module will receive. This makes it possible to forward both minibatch and +non-minibatch tensors through the same module. + +Example 1: ```lua x=torch.randn(5,1) y=torch.randn(5,1) @@ -227,12 +291,14 @@ gives the output: 0.6580 0.1784 -1.7362 - +[torch.DoubleTensor of dimension 10x1] + 1.3965 0.1575 0.5146 0.4491 -1.5244 0.6580 -0.9540 0.1784 0.4256 -1.7362 +[torch.DoubleTensor of dimension 5x2] 1.3965 0.5146 @@ -244,6 +310,38 @@ gives the output: [torch.Tensor of dimension 7x1] ``` +Example 2: +```lua +module = nn.JoinTable(2,2) + +x=torch.randn(3,1) +y=torch.randn(3,1) + +mx=torch.randn(2,3,1) +my=torch.randn(2,3,1) + +print(module:forward{x,y}) +print(module:forward{mx,my}) +``` +gives the output: +```lua + 0.4288 1.2002 +-1.4084 -0.7960 +-0.2091 0.1852 +[torch.DoubleTensor of dimension 3x2] + +(1,.,.) = + 0.5561 0.1228 + -0.6792 0.1153 + 0.0687 0.2955 + +(2,.,.) = + 2.5787 1.8185 + -0.9860 0.6756 + 0.1989 -0.4327 +[torch.DoubleTensor of dimension 2x3x2] +``` + A more complicated example: ```lua diff --git a/test/test.lua b/test/test.lua index 775dded..3c4ec12 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1830,6 +1830,43 @@ function nntest.MulConstant() mytester:assertlt(err, precision, 'bprop error ') end +function nntest.JoinTable() + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim() do + module = nn.JoinTable(d) + mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d) + end + + -- Minibatch + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim()-1 do + module = nn.JoinTable(d, 2) + mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d) + end +end + +function nntest.SplitTable() + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim() do + module = nn.SplitTable(d) + mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d) + end + + -- Minibatch + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim()-1 do + module = nn.SplitTable(d, 2) + mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d) + end +end + + mytester:add(nntest) if not nn then |