Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    vgire <vincent.gire@gmail.com>  2015-11-09 21:08:14 +0300
committer vgire <vincent.gire@gmail.com>  2015-11-10 04:06:59 +0300
commit    beb3ba0348c8a5a59776d3dff23a8c316c9b73ae (patch)
tree      b15b3d7852bbfaf50e7379b77b3eb678f083ef00 /Max.lua
parent    b9782bae6feb3307c8d79df8f4ff8d830bdad24a (diff)
Add support for negative dimension and both batch and non batch inputs for nn.Min, nn.Max and nn.Mean
Diffstat (limited to 'Max.lua')
-rw-r--r--Max.lua24
1 file changed, 19 insertions, 5 deletions
diff --git a/Max.lua b/Max.lua
index 079081d..4b5e8f2 100644
--- a/Max.lua
+++ b/Max.lua
@@ -1,9 +1,21 @@
local Max, parent = torch.class('nn.Max', 'nn.Module')
-function Max:__init(dimension)
+function Max:__init(dimension, nInputDims)
parent.__init(self)
dimension = dimension or 1
self.dimension = dimension
+ -- do not assign default value to nInputDims or it will break backward compatibility
+ self.nInputDims = nInputDims
+end
+
+function Max:_getPositiveDimension(input)
+ local dimension = self.dimension
+ if dimension < 0 then
+ dimension = input:dim() + dimension + 1
+ elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+ return dimension
end
function Max:_lazyInit()
@@ -14,9 +26,10 @@ end
function Max:updateOutput(input)
self:_lazyInit()
- torch.max(self._output, self._indices, input, self.dimension)
+ local dimension = self:_getPositiveDimension(input)
+ torch.max(self._output, self._indices, input, dimension)
if input:dim() > 1 then
- self.output = self._output:select(self.dimension, 1)
+ self.output = self._output:select(dimension, 1)
else
self.output = self._output
end
@@ -25,13 +38,14 @@ end
function Max:updateGradInput(input, gradOutput)
self:_lazyInit()
+ local dimension = self:_getPositiveDimension(input)
local gradOutputView
if input:dim() > 1 then
- gradOutputView = nn.utils.addSingletonDimension(gradOutput, self.dimension)
+ gradOutputView = nn.utils.addSingletonDimension(gradOutput, dimension)
else
gradOutputView = gradOutput
end
- self.gradInput:resizeAs(input):zero():scatter(self.dimension, self._indices, gradOutputView)
+ self.gradInput:resizeAs(input):zero():scatter(dimension, self._indices, gradOutputView)
return self.gradInput
end