-- Max.lua (github.com/torch/nn.git)

-- nn.Max applies torch.max along a configurable dimension of the input.
local Max, parent = torch.class('nn.Max', 'nn.Module')

function Max:__init(dimension, nInputDims)
   parent.__init(self)
   dimension = dimension or 1
   self.dimension = dimension
   -- do not assign a default value to nInputDims, or it will break backward compatibility
   self.nInputDims = nInputDims
end
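
-- Usage sketch (illustrative, not part of the upstream file): nInputDims
-- lets one module instance handle both batched and non-batched input of
-- the same layout, because a leading batch dimension shifts the reduced one.
--   local m = nn.Max(1, 3)             -- max over dim 1 of 3D samples
--   m:forward(torch.rand(3, 4, 5))     -- non-batched: output is 4x5
--   m:forward(torch.rand(2, 3, 4, 5))  -- batched: reduces dim 2, output is 2x4x5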

-- Resolve the configured dimension to a positive index: negative values
-- count from the last dimension, and the index shifts by one when the
-- input carries an extra (batch) dimension beyond nInputDims.
function Max:_getPositiveDimension(input)
   local dimension = self.dimension
   if dimension < 0 then
      dimension = input:dim() + dimension + 1
   elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
      dimension = dimension + 1
   end
   return dimension
end
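
-- For example (illustrative): with self.dimension = -1 and a 3D input,
-- the resolved dimension is 3 + (-1) + 1 = 3, i.e. the last dimension.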

-- Lazily allocate the output and index buffers. For CUDA tensors, indices
-- use torch.CudaLongTensor when the installed cutorch provides it, with a
-- fallback to torch.CudaTensor; on CPU they are a torch.LongTensor.
function Max:_lazyInit()
   self._output = self._output or self.output.new()
   if not self._indices then
      if torch.typename(self.output):find('torch%.Cuda.*Tensor') then
         self._indices = torch.CudaLongTensor and torch.CudaLongTensor() or torch.CudaTensor()
      else
         self._indices = torch.LongTensor()
      end
   end
end

-- Forward pass: compute the max (and argmax indices) along the resolved
-- dimension, then drop the singleton dimension torch.max leaves behind.
function Max:updateOutput(input)
   self:_lazyInit()
   local dimension = self:_getPositiveDimension(input)
   torch.max(self._output, self._indices, input, dimension)
   if input:dim() > 1 then
      self.output:set(self._output:select(dimension, 1))
   else
      self.output:set(self._output)
   end
   return self.output
end
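
-- A small worked forward example (illustrative):
--   local m = nn.Max(2)
--   m:forward(torch.Tensor{{1, 3}, {4, 2}})  --> values {3, 4}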

-- Backward pass: scatter each output gradient back to the input position
-- that produced the max; every other position receives zero gradient.
function Max:updateGradInput(input, gradOutput)
   self:_lazyInit()
   local dimension = self:_getPositiveDimension(input)
   local gradOutputView
   if input:dim() > 1 then
      gradOutputView = nn.utils.addSingletonDimension(gradOutput, dimension)
   else
      gradOutputView = gradOutput
   end
   self.gradInput:resizeAs(input):zero():scatter(dimension, self._indices, gradOutputView)
   return self.gradInput
end
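
-- Continuing the example above (illustrative), gradient flows only to the
-- argmax positions:
--   m:backward(torch.Tensor{{1, 3}, {4, 2}}, torch.Tensor{1, 1})
--   --> {{0, 1}, {1, 0}}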

-- A type conversion invalidates the cached index buffer (its type depends
-- on the tensor type), so drop it and let _lazyInit recreate it.
function Max:type(type, tensorCache)
   self._indices = nil
   parent.type(self, type, tensorCache)
   return self
end

-- Release the internal buffers, e.g. before serializing the model.
function Max:clearState()
   nn.utils.clear(self, '_indices', '_output')
   return parent.clearState(self)
end