Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergio Gomez <sergomezcol@gmail.com>2014-06-30 19:39:16 +0400
committerSergio Gomez <sergomezcol@gmail.com>2014-06-30 19:42:18 +0400
commitddfe535a7900939136f40773847f5b1152643a2a (patch)
treeb610aa66a7906d54c5e2daa76d0e20aebb81ae70 /View.lua
parent814fba19f5c2168e29a377941b4711bfa39665e8 (diff)
Add setNumInputDims() method to nn.View
This allows to use minibatches when using -1 for one of the sizes.
Diffstat (limited to 'View.lua')
-rw-r--r--View.lua19
1 files changed, 14 insertions, 5 deletions
diff --git a/View.lua b/View.lua
index 351adc3..92842ff 100644
--- a/View.lua
+++ b/View.lua
@@ -9,21 +9,30 @@ function View:__init(...)
assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage")
self.numElements = 1
for i = 1,#self.size do
- assert(self.size[i]>0, "Only positive sizes are allowed")
self.numElements = self.numElements * self.size[i]
end
self.output = nil
self.gradInput = nil
+ self.numInputDims = nil
end
-local function isMinibatch(input, numElements)
- return input:dim() > 1 and
- input:nElement()/input:size(1) == numElements
+function View:setNumInputDims(numInputDims)
+ self.numInputDims = numInputDims
+ return self
+end
+
+local function isMinibatch(input, numInputDims, numElements)
+ if numInputDims then
+ return input:dim() == numInputDims+1
+ else
+ return input:dim() > 1 and
+ input:nElement()/input:size(1) == numElements
+ end
end
function View:updateOutput(input)
- if isMinibatch(input, self.numElements) then
+ if isMinibatch(input, self.numInputDims, self.numElements) then
self.output = input:view(input:size(1), unpack(self.size:totable()))
else
self.output = input:view(self.size)