Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-26 01:13:21 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-26 01:13:21 +0400
commit2e010324c996dde64d0e3c437ea941b6a591069f (patch)
tree7bbbeb9edb375730a796443a74a7b40bd2066a09 /Reshape.lua
parent590a77573fee782060177adfcd0afc97d3c30521 (diff)
added optional last argument batchMode to nn.Reshape
Diffstat (limited to 'Reshape.lua')
-rw-r--r--Reshape.lua9
1 files changed, 8 insertions, 1 deletions
diff --git a/Reshape.lua b/Reshape.lua
index 3b988ac..0baab17 100644
--- a/Reshape.lua
+++ b/Reshape.lua
@@ -6,6 +6,10 @@ function Reshape:__init(...)
self.size = torch.LongStorage()
self.batchsize = torch.LongStorage()
+ if torch.type(arg[#arg]) == 'boolean' then
+ self.batchMode = arg[#arg]
+ table.remove(arg, #arg)
+ end
local n = #arg
if n == 1 and torch.typename(arg[1]) == 'torch.LongStorage' then
self.size:resize(#arg[1]):copy(arg[1])
@@ -35,7 +39,10 @@ function Reshape:updateOutput(input)
input = self._input
end
- if input:nElement() == self.nelement then
+ if (self.batchMode == false) or (
+ (self.batchMode == nil) and
+ (input:nElement() == self.nelement and input:size(1) ~= 1)
+ ) then
self.output:view(input, self.size)
else
self.batchsize[1] = input:size(1)