diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-21 04:42:38 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-21 04:53:35 +0400 |
commit | d189525fac46fa122011318c87cf149e9ff19e19 (patch) | |
tree | 1ac297f42189709647b15daec3eabd6521de59f4 /SpatialMaxPooling.lua | |
parent | dcb0eee5c99ddb8264afa55f262ed23bb243281b (diff) |
fixed #3
Diffstat (limited to 'SpatialMaxPooling.lua')
-rw-r--r-- | SpatialMaxPooling.lua | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index bd9039d..5a330a7 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -11,6 +11,10 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH) self.dH = dH or kW self:cuda() self.iSize = torch.LongStorage(4):fill(0) + self:resetPoolDescriptors() +end + +function SpatialMaxPooling:resetPoolDescriptors() -- create pooling descriptor self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') errcheck('cudnnCreatePoolingDescriptor', self.poolDesc) @@ -20,7 +24,6 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH) errcheck('cudnnDestroyPoolingDescriptor', d[0]); end ffi.gc(self.poolDesc, destroyPoolDesc) - end function SpatialMaxPooling:createIODescriptors(input) @@ -42,6 +45,7 @@ end function SpatialMaxPooling:updateOutput(input) assert(input:dim() == 4 and input:isContiguous()); + if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], self.iDesc[0], input:data(), |