-- nn.Max: reduces its input by taking the maximum along one dimension.
-- The reduced (size-1) dimension is removed from `output` when the input
-- has more than one dimension; the argmax indices are kept in
-- `self._indices` so the backward pass can route gradients to them.
local Max, parent = torch.class('nn.Max', 'nn.Module')

-- Constructor.
-- @param dimension   dimension to reduce over (default 1); negative values
--                    count from the last dimension.
-- @param nInputDims  optional number of dimensions of a non-batched input;
--                    when the input has one extra dimension, a positive
--                    `dimension` is shifted by one to skip the batch dimension.
function Max:__init(dimension, nInputDims)
   parent.__init(self)
   self.dimension = dimension or 1
   -- do not assign default value to nInputDims or it will break backward compatibility
   self.nInputDims = nInputDims
end

-- Resolves `self.dimension` into a 1-based positive dimension index for `input`.
-- Negative dimensions are translated relative to `input:dim()`; positive ones
-- are shifted by one when `nInputDims` is set and the input carries one
-- extra (batch) dimension.
function Max:_getPositiveDimension(input)
   local dim = self.dimension
   local rank = input:dim()
   if dim < 0 then
      return rank + dim + 1
   end
   if self.nInputDims and rank == self.nInputDims + 1 then
      return dim + 1
   end
   return dim
end

-- Allocates the internal buffers on first use:
--   `_output`  : same tensor type as `self.output`, receives the max values.
--   `_indices` : LongTensor, or CudaLongTensor (falling back to CudaTensor when
--                CudaLongTensor is unavailable) when `output` is a CUDA tensor.
function Max:_lazyInit()
   if not self._output then
      self._output = self.output.new()
   end
   if self._indices then
      return
   end
   local onCuda = torch.typename(self.output):find('torch%.Cuda.*Tensor')
   if not onCuda then
      self._indices = torch.LongTensor()
   elseif torch.CudaLongTensor then
      self._indices = torch.CudaLongTensor()
   else
      self._indices = torch.CudaTensor()
   end
end

-- Forward pass: computes the max along the resolved dimension.
-- `self.output` is made to share storage with `self._output`; for inputs with
-- more than one dimension the size-1 reduced dimension is selected away,
-- while for 1-D input `_output` is used as-is.
function Max:updateOutput(input)
   self:_lazyInit()
   local dim = self:_getPositiveDimension(input)
   torch.max(self._output, self._indices, input, dim)

   local reduced = self._output
   if input:dim() > 1 then
      reduced = reduced:select(dim, 1)
   end
   self.output:set(reduced)
   return self.output
end

-- Backward pass: writes each element of `gradOutput` into `gradInput` at the
-- position recorded in `self._indices` during the forward pass; all other
-- positions are zero. For inputs with more than one dimension, a singleton
-- dimension is re-inserted into `gradOutput` so its shape lines up with
-- `self._indices`.
function Max:updateGradInput(input, gradOutput)
   self:_lazyInit()
   local dim = self:_getPositiveDimension(input)

   local src = gradOutput
   if input:dim() > 1 then
      src = nn.utils.addSingletonDimension(gradOutput, dim)
   end

   self.gradInput:resizeAs(input):zero():scatter(dim, self._indices, src)
   return self.gradInput
end

-- Converts the module to another tensor type.
-- `_indices` is discarded first so that `_lazyInit` re-creates it with an
-- index tensor type matching the new `output` type (presumably parent.type
-- would otherwise convert it along with the other tensor fields — confirm).
function Max:type(dstType, tensorCache)
   self._indices = nil
   parent.type(self, dstType, tensorCache)
   return self
end

-- Releases the internal buffers before delegating to nn.Module.clearState.
function Max:clearState()
   nn.utils.clear(self, '_indices', '_output')
   return parent.clearState(self)
end