diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-04-22 08:11:04 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-04-22 08:11:04 +0300 |
commit | 0063dbdefd77182c9235315d5574caf7bbf25147 (patch) | |
tree | 13846859c46423ba96c9ec02364c3a73eab8ba64 /View.lua | |
parent | 418624f67da0c61dd2a7205373e3ebe816a94aae (diff) |
fix View (hopefuly)
now handles better -1 cases
Diffstat (limited to 'View.lua')
-rw-r--r-- | View.lua | 69 |
1 files changed, 37 insertions, 32 deletions
@@ -2,14 +2,23 @@ local View, parent = torch.class('nn.View', 'nn.Module') function View:__init(...) parent.__init(self) - self.size = ... - if select('#', ...) > 1 or type(self.size) == "number" then + if select('#', ...) == 1 and torch.typename(select(1, ...)) == 'torch.LongStorage' then + self.size = select(1, ...) + else self.size = torch.LongStorage({...}) end - assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage") + self.numElements = 1 + local inferdim = false for i = 1,#self.size do - self.numElements = self.numElements * self.size[i] + local szi = self.size[i] + if szi >= 0 then + self.numElements = self.numElements * self.size[i] + else + assert(szi == -1, 'size should be positive or -1') + assert(not inferdim, 'only one dimension can be at -1') + inferdim = true + end end self.output = nil @@ -23,42 +32,38 @@ function View:setNumInputDims(numInputDims) end local function batchsize(input, size, numInputDims, numElements) - - -- handle special vector case - if size:size() == 1 and size[1] == -1 then - if numInputDims then - numElements = 1 - local dim = input:nDimension() - for i=1,numInputDims do - numElements = numElements * input:size(dim-i+1) - end - else - numElements = input:nElement() - end - size = torch.LongStorage{numElements} - end - - -- find if number of elements is divisible with desired number - local ine = input:nElement() - local dim = 0 - local bsz = 1 - while ine > numElements do - dim = dim + 1 - local dimsz = input:size(dim) - if ine % numElements == 0 then - dimsz = math.min(ine/numElements, dimsz) - end - ine = ine / dimsz - bsz = bsz * dimsz + local ind = input:nDimension() + local isz = input:size() + local maxdim = numInputDims and numInputDims or ind + local ine = 1 + for i=ind,ind-maxdim+1,-1 do + ine = ine * isz[i] end - if ine ~= numElements then + if ine % numElements ~= 0 then error(string.format( 'input view (%s) and desired view (%s) do not match', table.concat(input:size():totable(), 'x'), table.concat(size:totable(), 'x'))) end + -- the remainder is either the batch... + local bsz = ine / numElements + + -- ... or the missing size dim + for i=1,size:size() do + if size[i] == -1 then + bsz = 1 + break + end + end + + -- for dim over maxdim, it is definitively the batch + for i=ind-maxdim,1,-1 do + bsz = bsz * isz[i] + end + + -- special card if bsz == 1 and (not numInputDims or input:nDimension() <= numInputDims) then return end |