Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialFractionalMaxPooling.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 884751d4142541aa7ccee087d34f7ca80166428f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
-- Register the class with torch. `parent` is the handle the methods below
-- use to reach the base implementations (parent.__init, parent.clearState).
local SpatialFractionalMaxPooling, parent = torch.class(
   'nn.SpatialFractionalMaxPooling', 'nn.Module')

-- Usage:
-- nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, outW, outH)
--   the output has the exact size (outH x outW)
-- nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, ratioW, ratioH)
--   the output has the size (floor(inH x ratioH) x floor(inW x ratioW))
--   ratios are numbers in the open interval (0, 1)
--
-- The two forms are told apart by value: when both trailing arguments are
-- >= 1 they are taken as an output size; otherwise both are taken as ratios
-- and each must lie strictly between 0 and 1.
--
-- Raises an error when a pooling size is missing or smaller than 2, when the
-- trailing arguments are missing, or when they mix a size with a ratio.
function SpatialFractionalMaxPooling:__init(poolSizeW, poolSizeH, arg1, arg2)
   parent.__init(self)

   -- Check for nil first so the caller gets a readable message instead of
   -- "attempt to compare nil with number".
   assert(poolSizeW ~= nil and poolSizeH ~= nil,
          'poolSizeW and poolSizeH are required')
   assert(poolSizeW >= 2, 'poolSizeW must be at least 2')
   assert(poolSizeH >= 2, 'poolSizeH must be at least 2')
   assert(arg1 ~= nil and arg2 ~= nil,
          'an output size (outW, outH) or ratios (ratioW, ratioH) are required')

   -- Pool size (how wide the pooling for each output unit is)
   self.poolSizeW = poolSizeW
   self.poolSizeH = poolSizeH

   -- Random samples are drawn for all
   -- batch * plane * (height, width; i.e., 2) points. This determines
   -- the 2d "pseudorandom" overlapping pooling regions for each
   -- (batch element x input plane). A new set of random samples is
   -- drawn every updateOutput call, unless we disable it via
   -- :fixPoolingRegions().
   self.randomSamples = nil

   -- When true, the existing random samples are kept between calls instead
   -- of being redrawn (see initSampleBuffer_). Despite its name, false means
   -- "draw a new pooling every call". For testing purposes.
   self.newRandomPool = false

   if arg1 >= 1 and arg2 >= 1 then
      -- Desired output size: the input tensor will determine the reduction
      -- ratio
      self.outW = arg1
      self.outH = arg2
   else
      -- Reduction ratio, applied to whatever input size is seen later
      self.ratioW = arg1
      self.ratioH = arg2

      -- Each ratio must be strictly between 0 and 1. This also rejects a
      -- call that mixes an output size (>= 1) with a ratio.
      assert(self.ratioW > 0 and self.ratioW < 1,
             'ratioW must be between 0 and 1 (exclusive); ' ..
             'pass either two output sizes >= 1 or two ratios')
      assert(self.ratioH > 0 and self.ratioH < 1,
             'ratioH must be between 0 and 1 (exclusive); ' ..
             'pass either two output sizes >= 1 or two ratios')
   end
end

-- Computes the size of the random-sample buffer needed for `input`.
-- Returns a torch.LongStorage {batch, planes, 2}, where batch is 1 for a
-- 3-dimensional input and input:size(1) for a 4-dimensional one.
-- Raises an error for any other number of dimensions.
function SpatialFractionalMaxPooling:getBufferSize_(input)
   local nDim = input:nDimension()
   if nDim ~= 3 and nDim ~= 4 then
      error('input must be dim 3 or 4')
   end

   -- The plane dimension sits two positions before the last one:
   -- index 1 for 3D input, index 2 for 4D input.
   local batches = (nDim == 4) and input:size(1) or 1
   local planes = input:size(nDim - 2)

   return torch.LongStorage({batches, planes, 2})
end

-- Makes sure self.randomSamples exists, has the right size for `input`, and
-- holds the samples to use for the coming forward pass.
--   * no buffer yet            -> allocate one of input's type and fill it
--   * batch/plane size changed -> resize and refill
--   * otherwise                -> refill, unless regions were fixed via
--                                 :fixPoolingRegions()
function SpatialFractionalMaxPooling:initSampleBuffer_(input)
   local wanted = self:getBufferSize_(input)
   local samples = self.randomSamples

   if samples == nil then
      self.randomSamples = input.new():resize(wanted):uniform()
      return
   end

   local sizeChanged = samples:size(1) ~= wanted[1] or
                       samples:size(2) ~= wanted[2]
   if sizeChanged then
      samples:resize(wanted):uniform()
   elseif not self.newRandomPool then
      -- Subsequent call with an unchanged shape: draw new pooling windows
      samples:uniform()
   end
end

-- Returns the output width and height (outW, outH) to use for `input`.
-- In ratio mode they are floor(input size x ratio), taken over the last
-- (width) and second-to-last (height) dimensions of a 3D or 4D input.
-- Otherwise the fixed sizes given at construction are returned.
-- Raises an error on other input ranks in ratio mode, or when a computed
-- size would be smaller than 1.
function SpatialFractionalMaxPooling:getOutputSizes_(input)
   -- Fixed-size mode: nothing depends on the input
   if self.ratioW == nil or self.ratioH == nil then
      local fixedW, fixedH = self.outW, self.outH
      assert(fixedW ~= nil and fixedH ~= nil)
      return fixedW, fixedH
   end

   local nDim = input:nDimension()
   if nDim ~= 3 and nDim ~= 4 then
      error('input must be dim 3 or 4')
   end

   -- Width is the last dimension, height the one before it
   local scaledW = math.floor(input:size(nDim) * self.ratioW)
   local scaledH = math.floor(input:size(nDim - 1) * self.ratioH)

   -- Neither can be smaller than 1
   assert(scaledW > 0, 'reduction ratio or input width too small')
   assert(scaledH > 0, 'reduction ratio or input height too small')

   return scaledW, scaledH
end

-- Controls whether the random pooling regions are redrawn on every
-- updateOutput call. Pass nothing (or true) to keep the current regions,
-- false to go back to drawing new ones each call.
-- Returns self so the call can be chained.
function SpatialFractionalMaxPooling:fixPoolingRegions(val)
   -- A missing argument means "fix the regions"
   self.newRandomPool = (val == nil) or val
   return self
end

-- Forward pass. Prepares the index buffer and the random samples, then
-- hands the work to the THNN backend of `input`.
-- Returns self.output.
function SpatialFractionalMaxPooling:updateOutput(input)
   -- Create the index buffer on first use
   if not self.indices then
      self.indices = torch.LongTensor()
   end

   -- Convert the index buffer according to the input's type name: long for
   -- non-CUDA input, CUDA long for CUDA input when that tensor type exists.
   local isCuda = torch.typename(input):find('torch%.Cuda.*Tensor')
   if not isCuda then
      self.indices = self.indices:long()
   elseif torch.CudaLongTensor then
      self.indices = self.indices:cudaLong()
   end

   self:initSampleBuffer_(input)
   local outputW, outputH = self:getOutputSizes_(input)

   input.THNN.SpatialFractionalMaxPooling_updateOutput(
      input:cdata(),
      self.output:cdata(),
      outputW, outputH, self.poolSizeW, self.poolSizeH,
      self.indices:cdata(), self.randomSamples:cdata())

   return self.output
end

-- Backward pass with respect to the input. Requires a previous forward
-- call, since it reuses the indices recorded there.
-- Returns self.gradInput.
function SpatialFractionalMaxPooling:updateGradInput(input, gradOutput)
   -- randomSamples is only set by the forward pass (and reset by clearState)
   assert(self.randomSamples ~= nil,
          'must call updateOutput/forward first')

   local outputW, outputH = self:getOutputSizes_(input)

   input.THNN.SpatialFractionalMaxPooling_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      outputW, outputH, self.poolSizeW, self.poolSizeH,
      self.indices:cdata())

   return self.gradInput
end

-- Backward-compatible alias for clearState().
-- Returns whatever clearState() returns, so the alias can be chained the
-- same way as clearState() and fixPoolingRegions().
function SpatialFractionalMaxPooling:empty()
   return self:clearState()
end

-- Drops the buffers this module creates during a forward pass (the pooling
-- indices and the random samples), then lets the parent class clear its own
-- state. Returns the parent's result.
function SpatialFractionalMaxPooling:clearState()
   self.indices, self.randomSamples = nil, nil
   return parent.clearState(self)
end

-- Human-readable description of the module:
--   '<type>(<outW>x<outH>, <poolSizeW>,<poolSizeH>)'   in fixed-size mode
--   '<type>(<ratioW>x<ratioH>, <poolSizeW>,<poolSizeH>)' in ratio mode
function SpatialFractionalMaxPooling:__tostring__()
   local sizeSpec
   if self.outW ~= nil then
      sizeSpec = string.format('%dx%d', self.outW, self.outH)
   else
      -- Ratios are fractions in (0, 1). '%d' would truncate them to 0 on
      -- Lua 5.1/LuaJIT and raise on Lua 5.3+, so format them as floats.
      sizeSpec = string.format('%gx%g', self.ratioW, self.ratioH)
   end

   return string.format('%s(%s, %d,%d)', torch.type(self), sizeSpec,
                        self.poolSizeW, self.poolSizeH)
end