diff options
author | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-30 19:39:16 +0400 |
---|---|---|
committer | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-30 19:42:18 +0400 |
commit | ddfe535a7900939136f40773847f5b1152643a2a (patch) | |
tree | b610aa66a7906d54c5e2daa76d0e20aebb81ae70 /View.lua | |
parent | 814fba19f5c2168e29a377941b4711bfa39665e8 (diff) |
Add setNumInputDims() method to nn.View
This allows to use minibatches when using -1 for one of the sizes.
Diffstat (limited to 'View.lua')
-rw-r--r-- | View.lua | 19 |
1 files changed, 14 insertions, 5 deletions
@@ -9,21 +9,30 @@ function View:__init(...) assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage") self.numElements = 1 for i = 1,#self.size do - assert(self.size[i]>0, "Only positive sizes are allowed") self.numElements = self.numElements * self.size[i] end self.output = nil self.gradInput = nil + self.numInputDims = nil end -local function isMinibatch(input, numElements) - return input:dim() > 1 and - input:nElement()/input:size(1) == numElements +function View:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +local function isMinibatch(input, numInputDims, numElements) + if numInputDims then + return input:dim() == numInputDims+1 + else + return input:dim() > 1 and + input:nElement()/input:size(1) == numElements + end end function View:updateOutput(input) - if isMinibatch(input, self.numElements) then + if isMinibatch(input, self.numInputDims, self.numElements) then self.output = input:view(input:size(1), unpack(self.size:totable())) else self.output = input:view(self.size) |