diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-04-22 23:49:13 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-04-22 23:49:13 +0300 |
commit | 394554a8440725e5fd53664fbff675ee567d0fae (patch) | |
tree | 5426016aa2f3f4c798e845c77aed79bfb26eb639 | |
parent | dbfeab9b7d48c0c6cd7024dd36b8d2badd275759 (diff) | |
parent | 0063dbdefd77182c9235315d5574caf7bbf25147 (diff) |
Merge pull request #240 from torch/fixview
fix View (hopefuly)
-rw-r--r-- | View.lua | 69 | ||||
-rw-r--r-- | test.lua | 11 |
2 files changed, 48 insertions, 32 deletions
@@ -2,14 +2,23 @@ local View, parent = torch.class('nn.View', 'nn.Module') function View:__init(...) parent.__init(self) - self.size = ... - if select('#', ...) > 1 or type(self.size) == "number" then + if select('#', ...) == 1 and torch.typename(select(1, ...)) == 'torch.LongStorage' then + self.size = select(1, ...) + else self.size = torch.LongStorage({...}) end - assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage") + self.numElements = 1 + local inferdim = false for i = 1,#self.size do - self.numElements = self.numElements * self.size[i] + local szi = self.size[i] + if szi >= 0 then + self.numElements = self.numElements * self.size[i] + else + assert(szi == -1, 'size should be positive or -1') + assert(not inferdim, 'only one dimension can be at -1') + inferdim = true + end end self.output = nil @@ -23,42 +32,38 @@ function View:setNumInputDims(numInputDims) end local function batchsize(input, size, numInputDims, numElements) - - -- handle special vector case - if size:size() == 1 and size[1] == -1 then - if numInputDims then - numElements = 1 - local dim = input:nDimension() - for i=1,numInputDims do - numElements = numElements * input:size(dim-i+1) - end - else - numElements = input:nElement() - end - size = torch.LongStorage{numElements} - end - - -- find if number of elements is divisible with desired number - local ine = input:nElement() - local dim = 0 - local bsz = 1 - while ine > numElements do - dim = dim + 1 - local dimsz = input:size(dim) - if ine % numElements == 0 then - dimsz = math.min(ine/numElements, dimsz) - end - ine = ine / dimsz - bsz = bsz * dimsz + local ind = input:nDimension() + local isz = input:size() + local maxdim = numInputDims and numInputDims or ind + local ine = 1 + for i=ind,ind-maxdim+1,-1 do + ine = ine * isz[i] end - if ine ~= numElements then + if ine % numElements ~= 0 then error(string.format( 'input view (%s) and desired view (%s) do not match', table.concat(input:size():totable(), 'x'), table.concat(size:totable(), 'x'))) end + -- the remainder is either the batch... + local bsz = ine / numElements + + -- ... or the missing size dim + for i=1,size:size() do + if size[i] == -1 then + bsz = 1 + break + end + end + + -- for dim over maxdim, it is definitively the batch + for i=ind-maxdim,1,-1 do + bsz = bsz * isz[i] + end + + -- special card if bsz == 1 and (not numInputDims or input:nDimension() <= numInputDims) then return end @@ -2872,6 +2872,17 @@ function nntest.View() minibatch:size(1), "Error in minibatch dimension with size -1") + -- another setNumInputDims case + local minibatch = torch.rand(2,5,4,10) + local module = nn.View(4,-1):setNumInputDims(2) + local out = module:forward(minibatch) + mytester:assertTableEq(out:size(1), minibatch:size(1)*minibatch:size(2), + "Error in minibatch dimension with size -1") + mytester:assertTableEq(out:size(2), minibatch:size(3), + "Error in minibatch dimension with size -1") + mytester:assertTableEq(out:size(3), minibatch:size(4), + "Error in minibatch dimension with size -1") + -- Minibatch Generalization local minibatch = torch.rand(5,2,6) local module = nn.View(6) |