diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-27 21:37:23 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-27 21:37:23 +0400 |
commit | a90d4843629736e1aabdd6ca7438b295415bd8b4 (patch) | |
tree | 2bb6f083cf8dd864350a774117ace9ca84ad67ed /SpatialAveragePooling.lua | |
parent | 2da6353dc28457c5b8dd758d552026c9aebdbcba (diff) |
SpatialAveragePooling binding
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r-- | SpatialAveragePooling.lua | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua new file mode 100644 index 0000000..3e09120 --- /dev/null +++ b/SpatialAveragePooling.lua @@ -0,0 +1,20 @@ +local SpatialAveragePooling, parent = torch.class('cudnn.SpatialAveragePooling', 'cudnn.SpatialMaxPooling') +local ffi = require 'ffi' +local C = cudnn.C +local errcheck = cudnn.errcheck + +function SpatialAveragePooling:__init(kW, kH, dW, dH) + parent.__init(self, kW, kH, dW, dH) +end + +function SpatialAveragePooling:resetPoolDescriptors() + -- create pooling descriptor + self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') + errcheck('cudnnCreatePoolingDescriptor', self.poolDesc) + errcheck('cudnnSetPoolingDescriptor', self.poolDesc[0], 'CUDNN_POOLING_AVERAGE', + self.kH, self.kW, self.dH, self.dW); + local function destroyPoolDesc(d) + errcheck('cudnnDestroyPoolingDescriptor', d[0]); + end + ffi.gc(self.poolDesc, destroyPoolDesc) +end |