diff options
author | Francisco Massa <fvsmassa@gmail.com> | 2015-09-07 00:08:28 +0300 |
---|---|---|
committer | Francisco Massa <fvsmassa@gmail.com> | 2015-12-15 02:02:42 +0300 |
commit | 738057ed24c83a47242ab4eb70c30335add804ac (patch) | |
tree | d934f43292171b7cc0c4330a674277c7999bb18e /SpatialAveragePooling.lua | |
parent | 06e2d8c106237a71fa89b3b5c37edab501e1dffa (diff) |
SpatialAveragePooling supports padding, ceil mode and exclude_pad division
Generalizes SpatialAveragePooling. When using padding or ceil mode, the number of elements in a region might be different from . The method considers only the number of elements in the pooling region for the averaging, whereas always divides by .
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r-- | SpatialAveragePooling.lua | 47 |
1 files changed, 43 insertions, 4 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua index 88603a7..d8ef41f 100644 --- a/SpatialAveragePooling.lua +++ b/SpatialAveragePooling.lua @@ -1,16 +1,50 @@ local SpatialAveragePooling, parent = torch.class('nn.SpatialAveragePooling', 'nn.Module') -function SpatialAveragePooling:__init(kW, kH, dW, dH) +function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH) parent.__init(self) self.kW = kW self.kH = kH self.dW = dW or 1 self.dH = dH or 1 + self.padW = padW or 0 + self.padH = padH or 0 + self.ceil_mode = false + self.count_include_pad = true self.divide = true end +function SpatialAveragePooling:ceil() + self.ceil_mode = true + return self +end + +function SpatialAveragePooling:floor() + self.ceil_mode = false + return self +end + +function SpatialAveragePooling:setCountIncludePad() + self.count_include_pad = true + return self +end + +function SpatialAveragePooling:setCountExcludePad() + self.count_include_pad = false + return self +end + +local function backwardCompatible(self) + if self.ceil_mode == nil then + self.ceil_mode = false + self.count_include_pad = true + self.padH = 0 + self.padW = 0 + end +end + function SpatialAveragePooling:updateOutput(input) + backwardCompatible(self) input.nn.SpatialAveragePooling_updateOutput(self, input) -- for backward compatibility with saved models -- which are not supposed to have "divide" field @@ -25,13 +59,18 @@ function SpatialAveragePooling:updateGradInput(input, gradOutput) input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput) -- for backward compatibility if not self.divide then - self.gradInput:mul(self.kW*self.kH) + self.gradInput:mul(self.kW*self.kH) end return self.gradInput end end function SpatialAveragePooling:__tostring__() - return string.format('%s(%d,%d,%d,%d)', torch.type(self), - self.kW, self.kH, self.dW, self.dH) + local s = string.format('%s(%d,%d,%d,%d', torch.type(self), + self.kW, self.kH, self.dW, self.dH) + if (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then + s = s .. ',' .. self.padW .. ','.. self.padH + end + s = s .. ')' + return s end |