Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2015-09-07 00:08:28 +0300
committerFrancisco Massa <fvsmassa@gmail.com>2015-12-15 02:02:42 +0300
commit738057ed24c83a47242ab4eb70c30335add804ac (patch)
treed934f43292171b7cc0c4330a674277c7999bb18e /SpatialAveragePooling.lua
parent06e2d8c106237a71fa89b3b5c37edab501e1dffa (diff)
SpatialAveragePooling supports padding, ceil mode and exclude_pad division
Generalizes SpatialAveragePooling. When using padding or ceil mode, the number of elements in a region might be different from . The method considers only the number of elements in the pooling region for the averaging, whereas always divides by .
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r--SpatialAveragePooling.lua47
1 files changed, 43 insertions, 4 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
index 88603a7..d8ef41f 100644
--- a/SpatialAveragePooling.lua
+++ b/SpatialAveragePooling.lua
@@ -1,16 +1,50 @@
local SpatialAveragePooling, parent = torch.class('nn.SpatialAveragePooling', 'nn.Module')
-function SpatialAveragePooling:__init(kW, kH, dW, dH)
+function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH)
parent.__init(self)
self.kW = kW
self.kH = kH
self.dW = dW or 1
self.dH = dH or 1
+ self.padW = padW or 0
+ self.padH = padH or 0
+ self.ceil_mode = false
+ self.count_include_pad = true
self.divide = true
end
+function SpatialAveragePooling:ceil()
+ self.ceil_mode = true
+ return self
+end
+
+function SpatialAveragePooling:floor()
+ self.ceil_mode = false
+ return self
+end
+
+function SpatialAveragePooling:setCountIncludePad()
+ self.count_include_pad = true
+ return self
+end
+
+function SpatialAveragePooling:setCountExcludePad()
+ self.count_include_pad = false
+ return self
+end
+
+local function backwardCompatible(self)
+ if self.ceil_mode == nil then
+ self.ceil_mode = false
+ self.count_include_pad = true
+ self.padH = 0
+ self.padW = 0
+ end
+end
+
function SpatialAveragePooling:updateOutput(input)
+ backwardCompatible(self)
input.nn.SpatialAveragePooling_updateOutput(self, input)
-- for backward compatibility with saved models
-- which are not supposed to have "divide" field
@@ -25,13 +59,18 @@ function SpatialAveragePooling:updateGradInput(input, gradOutput)
input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
-- for backward compatibility
if not self.divide then
- self.gradInput:mul(self.kW*self.kH)
+ self.gradInput:mul(self.kW*self.kH)
end
return self.gradInput
end
end
function SpatialAveragePooling:__tostring__()
- return string.format('%s(%d,%d,%d,%d)', torch.type(self),
- self.kW, self.kH, self.dW, self.dH)
+ local s = string.format('%s(%d,%d,%d,%d', torch.type(self),
+ self.kW, self.kH, self.dW, self.dH)
+ if (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
+ s = s .. ',' .. self.padW .. ','.. self.padH
+ end
+ s = s .. ')'
+ return s
end