diff options
-rw-r--r-- | View.lua | 18 | ||||
-rw-r--r-- | doc/containers.md | 2 | ||||
-rw-r--r-- | doc/simple.md | 31 | ||||
-rw-r--r-- | test/test.lua | 7 |
4 files changed, 44 insertions, 14 deletions
@@ -14,15 +14,25 @@ function View:__init(...) self.output = nil self.gradInput = nil + self.numInputDims = nil end -local function isMinibatch(input, numElements) - return input:dim() > 1 and - input:nElement()/input:size(1) == numElements +function View:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +local function isMinibatch(input, numInputDims, numElements) + if numInputDims then + return input:dim() == numInputDims+1 + else + return input:dim() > 1 and + input:nElement()/input:size(1) == numElements + end end function View:updateOutput(input) - if isMinibatch(input, self.numElements) then + if isMinibatch(input, self.numInputDims, self.numElements) then self.output = input:view(input:size(1), unpack(self.size:totable())) else self.output = input:view(self.size) diff --git a/doc/containers.md b/doc/containers.md index 484b781..e078676 100644 --- a/doc/containers.md +++ b/doc/containers.md @@ -121,4 +121,4 @@ While the above containers are used for manipulating input [Tensors](https://git * [ConcatTable](table.md#nn.ConcatTable) * [ParallelTable](table.md#nn.ParallelTable) -These, along with all other modules for manipulating tables can be found [here](doc/table.md). +These, along with all other modules for manipulating tables can be found [here](table.md). diff --git a/doc/simple.md b/doc/simple.md index 6f5b13d..ad883b7 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -478,7 +478,11 @@ Example: This module creates a new view of the input tensor using the `sizes` passed to the constructor. The parameter `sizes` can either be a `LongStorage` or numbers. -Example: +The method `setNumInputDims()` allows to specify the expected number of dimensions +of the inputs of the modules. This makes it possible to use minibatch inputs when +using a size -1 for one of the dimensions. + +Example 1: ```lua > x=torch.Tensor(4,4) > for i=1,4 do @@ -488,14 +492,6 @@ Example: > end > print(x) -x=torch.Tensor(4,4) -for i=1,4 do - for j=1,4 do - x[i][j]=(i-1)*4+j; - end -end -print(x) - 1 2 3 4 5 6 7 8 9 10 11 12 @@ -541,6 +537,23 @@ print(x) [torch.DoubleTensor of dimension 16] ``` +Example 2: +```lua +> input = torch.Tensor(2,3) +> minibatch = torch.Tensor(5,2,3) +> m = nn.View(-1):setNumInputDims(2) +> print(#m:forward(input)) + + 6 +[torch.LongStorage of size 2] + +> print(#m:forward(minibatch)) + + 5 + 6 +[torch.LongStorage of size 2] + +``` <a name="nn.Select"/> ## Select ## diff --git a/test/test.lua b/test/test.lua index 6832885..04b1dd8 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1887,6 +1887,13 @@ function nntest.View() mytester:assertTableEq(module:forward(minibatch):nElement(), minibatch:nElement(), "Error in minibatch nElement") + local module = nn.View(-1):setNumInputDims(1) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension with size -1") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement with size -1") end -- Define a test for SpatialUpSamplingCuda |