blob: 21fd334e27a8dbdeeb4b69249884a8e214b72922 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
-- cuDNN-backed spatial average pooling. torch.class registers the class
-- 'cudnn.SpatialAveragePooling' as a subclass of 'cudnn._Pooling' and returns
-- the new class table followed by its parent class.
local SpatialAveragePooling, parent =
   torch.class('cudnn.SpatialAveragePooling', 'cudnn._Pooling')
--- Fill in pooling fields that may be missing on an instance.
-- A nil `ceil_mode` is taken as the marker of an instance lacking the newer
-- fields (presumably one built by older code or converted from nn — confirm).
-- In that case it sets ceil_mode = false, count_include_pad = true and zero
-- padding. Instances that already have `ceil_mode` are left untouched.
-- @param self the pooling module instance (mutated in place)
local function backwardCompatible(self)
   if self.ceil_mode ~= nil then
      return
   end
   self.ceil_mode = false
   self.count_include_pad = true
   self.padH, self.padW = 0, 0
end
--- Forward pass: select the cuDNN average-pooling mode, then defer to
-- cudnn._Pooling's updateOutput.
-- Only count-include-padding averaging is supported. A module whose
-- `count_include_pad` is explicitly false raises an error instead of
-- silently averaging with padding included.
-- @param input forwarded unchanged to parent.updateOutput
-- @return whatever parent.updateOutput returns
function SpatialAveragePooling:updateOutput(input)
   -- for nn <> cudnn conversion
   backwardCompatible(self)
   if self.divide ~= nil then
      -- NOTE(review): presumably a legacy nn flag; only a truthy value is accepted
      assert(self.divide, 'not supported')
   end
   -- Default to true only when the field is unset. The previous
   -- `x ~= nil and x or true` idiom also turned an explicit false into true,
   -- silently ignoring the caller's choice and making the error below unreachable.
   if self.count_include_pad == nil then
      self.count_include_pad = true
   end
   if not self.count_include_pad then
      -- CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING has not been validated here.
      -- Fail loudly rather than return count-include-padding results.
      error('This mode is untested in cudnn')
   end
   self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
   return parent.updateOutput(self, input)
end
--- Textual representation: reuse nn.SpatialAveragePooling's formatter on this
-- instance, so the cudnn module prints the same way the nn one does.
-- @return the string produced by nn.SpatialAveragePooling.__tostring__
function SpatialAveragePooling:__tostring__()
   local nnFormatter = nn.SpatialAveragePooling.__tostring__
   return nnFormatter(self)
end
|