Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-27 21:37:23 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-27 21:37:23 +0400
commita90d4843629736e1aabdd6ca7438b295415bd8b4 (patch)
tree2bb6f083cf8dd864350a774117ace9ca84ad67ed /SpatialAveragePooling.lua
parent2da6353dc28457c5b8dd758d552026c9aebdbcba (diff)
SpatialAveragePooling binding
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r--SpatialAveragePooling.lua20
1 files changed, 20 insertions, 0 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
new file mode 100644
index 0000000..3e09120
--- /dev/null
+++ b/SpatialAveragePooling.lua
@@ -0,0 +1,20 @@
+local SpatialAveragePooling, parent = torch.class('cudnn.SpatialAveragePooling', 'cudnn.SpatialMaxPooling')
+local ffi = require 'ffi'
+local C = cudnn.C
+local errcheck = cudnn.errcheck
+
+function SpatialAveragePooling:__init(kW, kH, dW, dH)
+ parent.__init(self, kW, kH, dW, dH)
+end
+
+function SpatialAveragePooling:resetPoolDescriptors()
+ -- create pooling descriptor
+ self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]')
+ errcheck('cudnnCreatePoolingDescriptor', self.poolDesc)
+ errcheck('cudnnSetPoolingDescriptor', self.poolDesc[0], 'CUDNN_POOLING_AVERAGE',
+ self.kH, self.kW, self.dH, self.dW);
+ local function destroyPoolDesc(d)
+ errcheck('cudnnDestroyPoolingDescriptor', d[0]);
+ end
+ ffi.gc(self.poolDesc, destroyPoolDesc)
+end