Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-04-22 08:11:04 +0300
committerRonan Collobert <ronan@collobert.com>2015-04-22 08:11:04 +0300
commit0063dbdefd77182c9235315d5574caf7bbf25147 (patch)
tree13846859c46423ba96c9ec02364c3a73eab8ba64 /View.lua
parent418624f67da0c61dd2a7205373e3ebe816a94aae (diff)
fix View (hopefuly)
now handles better -1 cases
Diffstat (limited to 'View.lua')
-rw-r--r--View.lua69
1 files changed, 37 insertions, 32 deletions
diff --git a/View.lua b/View.lua
index a6785c8..766e149 100644
--- a/View.lua
+++ b/View.lua
@@ -2,14 +2,23 @@ local View, parent = torch.class('nn.View', 'nn.Module')
function View:__init(...)
parent.__init(self)
- self.size = ...
- if select('#', ...) > 1 or type(self.size) == "number" then
+ if select('#', ...) == 1 and torch.typename(select(1, ...)) == 'torch.LongStorage' then
+ self.size = select(1, ...)
+ else
self.size = torch.LongStorage({...})
end
- assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage")
+
self.numElements = 1
+ local inferdim = false
for i = 1,#self.size do
- self.numElements = self.numElements * self.size[i]
+ local szi = self.size[i]
+ if szi >= 0 then
+ self.numElements = self.numElements * self.size[i]
+ else
+ assert(szi == -1, 'size should be positive or -1')
+ assert(not inferdim, 'only one dimension can be at -1')
+ inferdim = true
+ end
end
self.output = nil
@@ -23,42 +32,38 @@ function View:setNumInputDims(numInputDims)
end
local function batchsize(input, size, numInputDims, numElements)
-
- -- handle special vector case
- if size:size() == 1 and size[1] == -1 then
- if numInputDims then
- numElements = 1
- local dim = input:nDimension()
- for i=1,numInputDims do
- numElements = numElements * input:size(dim-i+1)
- end
- else
- numElements = input:nElement()
- end
- size = torch.LongStorage{numElements}
- end
-
- -- find if number of elements is divisible with desired number
- local ine = input:nElement()
- local dim = 0
- local bsz = 1
- while ine > numElements do
- dim = dim + 1
- local dimsz = input:size(dim)
- if ine % numElements == 0 then
- dimsz = math.min(ine/numElements, dimsz)
- end
- ine = ine / dimsz
- bsz = bsz * dimsz
+ local ind = input:nDimension()
+ local isz = input:size()
+ local maxdim = numInputDims and numInputDims or ind
+ local ine = 1
+ for i=ind,ind-maxdim+1,-1 do
+ ine = ine * isz[i]
end
- if ine ~= numElements then
+ if ine % numElements ~= 0 then
error(string.format(
'input view (%s) and desired view (%s) do not match',
table.concat(input:size():totable(), 'x'),
table.concat(size:totable(), 'x')))
end
+ -- the remainder is either the batch...
+ local bsz = ine / numElements
+
+ -- ... or the missing size dim
+ for i=1,size:size() do
+ if size[i] == -1 then
+ bsz = 1
+ break
+ end
+ end
+
+ -- for dim over maxdim, it is definitively the batch
+ for i=ind-maxdim,1,-1 do
+ bsz = bsz * isz[i]
+ end
+
+ -- special card
if bsz == 1 and (not numInputDims or input:nDimension() <= numInputDims) then
return
end