Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2015-12-24 16:40:21 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2015-12-25 16:01:28 +0300
commit8021a68c445969d3afe7c03f6711a49d4454f1bf (patch)
tree4c1452d7b43042ce8d6aa454def8f6f749793c33
parentefefd6bbd1526468acce4ff79eb7c83c00e2d773 (diff)
make avg-pooling swappable with nn
-rw-r--r--SpatialAveragePooling.lua28
1 files changed, 25 insertions, 3 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
index e94affe..d25138d 100644
--- a/SpatialAveragePooling.lua
+++ b/SpatialAveragePooling.lua
@@ -1,9 +1,31 @@
local SpatialAveragePooling, parent
= torch.class('cudnn.SpatialAveragePooling', 'cudnn._Pooling')
-function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH)
- parent.__init(self, kW, kH, dW, dH, padW, padH)
- self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
+local function backwardCompatible(self)
+ if self.ceil_mode == nil then
+ self.ceil_mode = false
+ self.count_include_pad = true
+ self.padH = 0
+ self.padW = 0
+ end
+end
+
+function SpatialAveragePooling:updateOutput(input)
+ -- for nn <> cudnn conversion
+ backwardCompatible(self)
+ if self.divide ~= nil then
+ assert(self.divide, 'not supported')
+ end
+
+ self.count_include_pad = self.count_include_pad ~= nil and
+ self.count_include_pad or true
+ if self.count_include_pad then
+ self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
+ else
+ error'This mode is untested in cudnn'
+ self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING'
+ end
+ return parent.updateOutput(self, input)
end
function SpatialAveragePooling:__tostring__()