diff options
-rw-r--r-- | View.lua | 19 | ||||
-rw-r--r-- | doc/simple.md | 31 | ||||
-rw-r--r-- | test/test.lua | 7 |
3 files changed, 43 insertions, 14 deletions
@@ -9,21 +9,30 @@ function View:__init(...) assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage") self.numElements = 1 for i = 1,#self.size do - assert(self.size[i]>0, "Only positive sizes are allowed") self.numElements = self.numElements * self.size[i] end self.output = nil self.gradInput = nil + self.numInputDims = nil end -local function isMinibatch(input, numElements) - return input:dim() > 1 and - input:nElement()/input:size(1) == numElements +function View:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +local function isMinibatch(input, numInputDims, numElements) + if numInputDims then + return input:dim() == numInputDims+1 + else + return input:dim() > 1 and + input:nElement()/input:size(1) == numElements + end end function View:updateOutput(input) - if isMinibatch(input, self.numElements) then + if isMinibatch(input, self.numInputDims, self.numElements) then self.output = input:view(input:size(1), unpack(self.size:totable())) else self.output = input:view(self.size) diff --git a/doc/simple.md b/doc/simple.md index 6f5b13d..ad883b7 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -478,7 +478,11 @@ Example: This module creates a new view of the input tensor using the `sizes` passed to the constructor. The parameter `sizes` can either be a `LongStorage` or numbers. -Example: +The method `setNumInputDims()` allows to specify the expected number of dimensions +of the inputs of the modules. This makes it possible to use minibatch inputs when +using a size -1 for one of the dimensions. + +Example 1: ```lua > x=torch.Tensor(4,4) > for i=1,4 do @@ -488,14 +492,6 @@ Example: > end > print(x) -x=torch.Tensor(4,4) -for i=1,4 do - for j=1,4 do - x[i][j]=(i-1)*4+j; - end -end -print(x) - 1 2 3 4 5 6 7 8 9 10 11 12 @@ -541,6 +537,23 @@ print(x) [torch.DoubleTensor of dimension 16] ``` +Example 2: +```lua +> input = torch.Tensor(2,3) +> minibatch = torch.Tensor(5,2,3) +> m = nn.View(-1):setNumInputDims(2) +> print(#m:forward(input)) + + 6 +[torch.LongStorage of size 2] + +> print(#m:forward(minibatch)) + + 5 + 6 +[torch.LongStorage of size 2] + +``` <a name="nn.Select"/> ## Select ## diff --git a/test/test.lua b/test/test.lua index 5e4bce7..88adf96 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1884,6 +1884,13 @@ function nntest.View() mytester:assertTableEq(module:forward(minibatch):nElement(), minibatch:nElement(), "Error in minibatch nElement") + local module = nn.View(-1):setNumInputDims(1) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension with size -1") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement with size -1") end -- Define a test for SpatialUpSamplingCuda |