diff options
author | vgire <vincent.gire@gmail.com> | 2015-11-09 21:08:14 +0300 |
---|---|---|
committer | vgire <vincent.gire@gmail.com> | 2015-11-10 04:06:59 +0300 |
commit | beb3ba0348c8a5a59776d3dff23a8c316c9b73ae (patch) | |
tree | b15b3d7852bbfaf50e7379b77b3eb678f083ef00 /Max.lua | |
parent | b9782bae6feb3307c8d79df8f4ff8d830bdad24a (diff) |
Add support for negative dimension and both batch and non batch inputs for nn.Min, nn.Max and nn.Mean
Diffstat (limited to 'Max.lua')
-rw-r--r-- | Max.lua | 24 |
1 files changed, 19 insertions, 5 deletions
@@ -1,9 +1,21 @@ local Max, parent = torch.class('nn.Max', 'nn.Module') -function Max:__init(dimension) +function Max:__init(dimension, nInputDims) parent.__init(self) dimension = dimension or 1 self.dimension = dimension + -- do not assign default value to nInputDims or it will break backward compatibility + self.nInputDims = nInputDims +end + +function Max:_getPositiveDimension(input) + local dimension = self.dimension + if dimension < 0 then + dimension = input:dim() + dimension + 1 + elseif self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + return dimension end function Max:_lazyInit() @@ -14,9 +26,10 @@ end function Max:updateOutput(input) self:_lazyInit() - torch.max(self._output, self._indices, input, self.dimension) + local dimension = self:_getPositiveDimension(input) + torch.max(self._output, self._indices, input, dimension) if input:dim() > 1 then - self.output = self._output:select(self.dimension, 1) + self.output = self._output:select(dimension, 1) else self.output = self._output end @@ -25,13 +38,14 @@ end function Max:updateGradInput(input, gradOutput) self:_lazyInit() + local dimension = self:_getPositiveDimension(input) local gradOutputView if input:dim() > 1 then - gradOutputView = nn.utils.addSingletonDimension(gradOutput, self.dimension) + gradOutputView = nn.utils.addSingletonDimension(gradOutput, dimension) else gradOutputView = gradOutput end - self.gradInput:resizeAs(input):zero():scatter(self.dimension, self._indices, gradOutputView) + self.gradInput:resizeAs(input):zero():scatter(dimension, self._indices, gradOutputView) return self.gradInput end |