Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialAveragePooling.lua - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 21fd334e27a8dbdeeb4b69249884a8e214b72922 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
-- Register the cudnn-backed average-pooling class on top of cudnn._Pooling.
-- `parent` is kept so overridden methods can delegate to the base class.
local SpatialAveragePooling, parent =
   torch.class('cudnn.SpatialAveragePooling', 'cudnn._Pooling')

--- Fill in pooling fields that are missing on instances lacking `ceil_mode`
-- (presumably modules created/serialized by older nn or cudnn versions —
-- TODO confirm).
-- If `ceil_mode` is already present the instance is left untouched; otherwise
-- it gets floor-mode output sizing, count-include-padding averaging and zero
-- padding.
-- @param self pooling module instance (mutated in place)
local function backwardCompatible(self)
   -- Already up to date: nothing to patch.
   if self.ceil_mode ~= nil then
      return
   end
   self.ceil_mode = false
   self.count_include_pad = true
   self.padH = 0
   self.padW = 0
end

--- Forward pass: pick the cuDNN averaging mode, then defer to cudnn._Pooling.
-- The instance is first patched by backwardCompatible so that the
-- ceil/padding fields exist even on instances that lack `ceil_mode`.
-- @param input passed unchanged to parent.updateOutput
-- @return whatever parent.updateOutput returns
-- @raise 'not supported' when self.divide is explicitly false
-- @raise 'This mode is untested in cudnn' when count_include_pad is false
function SpatialAveragePooling:updateOutput(input)
   -- for nn <> cudnn conversion
   backwardCompatible(self)

   -- A `divide` field set to false is rejected; nil (field absent) or any
   -- truthy value is accepted. (Presumably divide == false was an older
   -- nn sum-pooling option — TODO confirm.)
   assert(self.divide ~= false, 'not supported')

   -- Default to counting padded cells only when the field is absent.
   -- The previous `x ~= nil and x or true` idiom turned an explicit `false`
   -- into `true`, silently ignoring a request to exclude padding and making
   -- the else-branch below unreachable.
   if self.count_include_pad == nil then
      self.count_include_pad = true
   end

   if self.count_include_pad then
      self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
   else
      -- CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING would be the matching
      -- mode, but it is untested here. Refuse loudly rather than silently
      -- computing include-padding averages when exclude was requested.
      error'This mode is untested in cudnn'
   end
   return parent.updateOutput(self, input)
end

--- Human-readable description of the module.
-- Reuses nn.SpatialAveragePooling's formatter, applied to this instance.
-- @return the string produced by nn.SpatialAveragePooling.__tostring__
function SpatialAveragePooling:__tostring__()
   local nnPooling = nn.SpatialAveragePooling
   return nnPooling.__tostring__(self)
end