diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-30 23:22:55 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-30 23:22:55 +0400 |
commit | f32f33b49680a6bda36e13c4a8ada2e11ef61496 (patch) | |
tree | 7c9d011b3a35e799d58e802c2e5e77a6eb351f73 | |
parent | 54a9476ca762ce647eb05437661a5a9415949e11 (diff) | |
parent | ddfe535a7900939136f40773847f5b1152643a2a (diff) |
Merge pull request #24 from sergomezcol/master
Add mini-batch support for nn.View
-rw-r--r-- | View.lua | 18 | ||||
-rw-r--r-- | doc/simple.md | 31 | ||||
-rw-r--r-- | test/test.lua | 7 |
3 files changed, 43 insertions, 13 deletions
@@ -14,15 +14,25 @@ function View:__init(...) self.output = nil self.gradInput = nil + self.numInputDims = nil end -local function isMinibatch(input, numElements) - return input:dim() > 1 and - input:nElement()/input:size(1) == numElements +function View:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +local function isMinibatch(input, numInputDims, numElements) + if numInputDims then + return input:dim() == numInputDims+1 + else + return input:dim() > 1 and + input:nElement()/input:size(1) == numElements + end end function View:updateOutput(input) - if isMinibatch(input, self.numElements) then + if isMinibatch(input, self.numInputDims, self.numElements) then self.output = input:view(input:size(1), unpack(self.size:totable())) else self.output = input:view(self.size) diff --git a/doc/simple.md b/doc/simple.md index 6f5b13d..ad883b7 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -478,7 +478,11 @@ Example: This module creates a new view of the input tensor using the `sizes` passed to the constructor. The parameter `sizes` can either be a `LongStorage` or numbers. -Example: +The method `setNumInputDims()` allows to specify the expected number of dimensions +of the inputs of the modules. This makes it possible to use minibatch inputs when +using a size -1 for one of the dimensions. + +Example 1: ```lua > x=torch.Tensor(4,4) > for i=1,4 do @@ -488,14 +492,6 @@ Example: > end > print(x) -x=torch.Tensor(4,4) -for i=1,4 do - for j=1,4 do - x[i][j]=(i-1)*4+j; - end -end -print(x) - 1 2 3 4 5 6 7 8 9 10 11 12 @@ -541,6 +537,23 @@ print(x) [torch.DoubleTensor of dimension 16] ``` +Example 2: +```lua +> input = torch.Tensor(2,3) +> minibatch = torch.Tensor(5,2,3) +> m = nn.View(-1):setNumInputDims(2) +> print(#m:forward(input)) + + 6 +[torch.LongStorage of size 2] + +> print(#m:forward(minibatch)) + + 5 + 6 +[torch.LongStorage of size 2] + +``` <a name="nn.Select"/> ## Select ## diff --git a/test/test.lua b/test/test.lua index 5e4bce7..88adf96 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1884,6 +1884,13 @@ function nntest.View() mytester:assertTableEq(module:forward(minibatch):nElement(), minibatch:nElement(), "Error in minibatch nElement") + local module = nn.View(-1):setNumInputDims(1) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension with size -1") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement with size -1") end -- Define a test for SpatialUpSamplingCuda |