diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-26 01:13:21 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-26 01:13:21 +0400 |
commit | 2e010324c996dde64d0e3c437ea941b6a591069f (patch) | |
tree | 7bbbeb9edb375730a796443a74a7b40bd2066a09 /Reshape.lua | |
parent | 590a77573fee782060177adfcd0afc97d3c30521 (diff) |
added optional last argument batchMode to nn.Reshape
Diffstat (limited to 'Reshape.lua')
-rw-r--r-- | Reshape.lua | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/Reshape.lua b/Reshape.lua index 3b988ac..0baab17 100644 --- a/Reshape.lua +++ b/Reshape.lua @@ -6,6 +6,10 @@ function Reshape:__init(...) self.size = torch.LongStorage() self.batchsize = torch.LongStorage() + if torch.type(arg[#arg]) == 'boolean' then + self.batchMode = arg[#arg] + table.remove(arg, #arg) + end local n = #arg if n == 1 and torch.typename(arg[1]) == 'torch.LongStorage' then self.size:resize(#arg[1]):copy(arg[1]) @@ -35,7 +39,10 @@ function Reshape:updateOutput(input) input = self._input end - if input:nElement() == self.nelement then + if (self.batchMode == false) or ( + (self.batchMode == nil) and + (input:nElement() == self.nelement and input:size(1) ~= 1) + ) then self.output:view(input, self.size) else self.batchsize[1] = input:size(1) |