Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-21 04:42:38 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-21 04:53:35 +0400
commitd189525fac46fa122011318c87cf149e9ff19e19 (patch)
tree1ac297f42189709647b15daec3eabd6521de59f4 /SpatialMaxPooling.lua
parentdcb0eee5c99ddb8264afa55f262ed23bb243281b (diff)
fixed #3
Diffstat (limited to 'SpatialMaxPooling.lua')
-rw-r--r--SpatialMaxPooling.lua6
1 files changed, 5 insertions, 1 deletions
diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua
index bd9039d..5a330a7 100644
--- a/SpatialMaxPooling.lua
+++ b/SpatialMaxPooling.lua
@@ -11,6 +11,10 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH)
self.dH = dH or kW
self:cuda()
self.iSize = torch.LongStorage(4):fill(0)
+ self:resetPoolDescriptors()
+end
+
+function SpatialMaxPooling:resetPoolDescriptors()
-- create pooling descriptor
self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]')
errcheck('cudnnCreatePoolingDescriptor', self.poolDesc)
@@ -20,7 +24,6 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
end
ffi.gc(self.poolDesc, destroyPoolDesc)
-
end
function SpatialMaxPooling:createIODescriptors(input)
@@ -42,6 +45,7 @@ end
function SpatialMaxPooling:updateOutput(input)
assert(input:dim() == 4 and input:isContiguous());
+ if not self.poolDesc then self:resetPoolDescriptors() end
self:createIODescriptors(input)
errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0],
self.iDesc[0], input:data(),