diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-12-24 16:40:21 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-12-25 16:01:28 +0300 |
commit | 8021a68c445969d3afe7c03f6711a49d4454f1bf (patch) | |
tree | 4c1452d7b43042ce8d6aa454def8f6f749793c33 | |
parent | efefd6bbd1526468acce4ff79eb7c83c00e2d773 (diff) |
make avg-pooling swappable with nn
-rw-r--r-- | SpatialAveragePooling.lua | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua index e94affe..d25138d 100644 --- a/SpatialAveragePooling.lua +++ b/SpatialAveragePooling.lua @@ -1,9 +1,31 @@ local SpatialAveragePooling, parent = torch.class('cudnn.SpatialAveragePooling', 'cudnn._Pooling') -function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH) - parent.__init(self, kW, kH, dW, dH, padW, padH) - self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' +local function backwardCompatible(self) + if self.ceil_mode == nil then + self.ceil_mode = false + self.count_include_pad = true + self.padH = 0 + self.padW = 0 + end +end + +function SpatialAveragePooling:updateOutput(input) + -- for nn <> cudnn conversion + backwardCompatible(self) + if self.divide ~= nil then + assert(self.divide, 'not supported') + end + + self.count_include_pad = self.count_include_pad ~= nil and + self.count_include_pad or true + if self.count_include_pad then + self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' + else + error'This mode is untested in cudnn' + self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING' + end + return parent.updateOutput(self, input) end function SpatialAveragePooling:__tostring__() |