diff options
author | vgire <vincent.gire@gmail.com> | 2015-11-09 21:08:14 +0300 |
---|---|---|
committer | vgire <vincent.gire@gmail.com> | 2015-11-10 04:06:59 +0300 |
commit | beb3ba0348c8a5a59776d3dff23a8c316c9b73ae (patch) | |
tree | b15b3d7852bbfaf50e7379b77b3eb678f083ef00 /Mean.lua | |
parent | b9782bae6feb3307c8d79df8f4ff8d830bdad24a (diff) |
Add support for negative dimension and both batch and non batch inputs for nn.Min, nn.Max and nn.Mean
Diffstat (limited to 'Mean.lua')
-rw-r--r-- | Mean.lua | 25 |
1 file changed, 19 insertions, 6 deletions
local Mean, parent = torch.class('nn.Mean', 'nn.Module')

-- nn.Mean averages its input along one dimension.
-- `dimension` may be negative (counted from the last dimension), and when
-- `nInputDims` is supplied the module also accepts batched input that has
-- nInputDims + 1 dimensions, shifting the reduced dimension by one.
function Mean:__init(dimension, nInputDims)
   parent.__init(self)
   dimension = dimension or 1
   self.dimension = dimension
   -- do not assign default value to nInputDims or it will break backward compatibility
   self.nInputDims = nInputDims
   self._gradInput = torch.Tensor()
end

-- Resolve self.dimension against the actual input:
--   * negative dimensions count back from input:dim()
--   * when nInputDims is set and the input carries one extra (batch)
--     dimension, the configured dimension is shifted by one
function Mean:_getPositiveDimension(input)
   local dimension = self.dimension
   if dimension < 0 then
      dimension = input:dim() + dimension + 1
   elseif self.nInputDims and input:dim() == (self.nInputDims + 1) then
      dimension = dimension + 1
   end
   return dimension
end

function Mean:updateOutput(input)
   local dimension = self:_getPositiveDimension(input)
   self.output:mean(input, dimension)
   if self.output:nDimension() > 1 then
      -- mean() keeps the reduced dimension as size 1; drop it
      self.output = self.output:select(dimension, 1)
   end
   return self.output
end

function Mean:updateGradInput(input, gradOutput)
   local dimension = self:_getPositiveDimension(input)
   -- each element contributes 1/size(dimension) to the mean, so the
   -- gradient is gradOutput scaled by that factor and broadcast back
   -- over the reduced dimension
   self._gradInput:resizeAs(gradOutput):copy(gradOutput)
   self._gradInput:mul(1 / input:size(dimension))
   if input:nDimension() > 1 then
      self._gradInput = nn.utils.addSingletonDimension(self._gradInput, dimension)
   end
   self.gradInput = self._gradInput:expandAs(input)
   return self.gradInput
end