diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-19 02:36:12 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-19 23:28:13 +0300 |
commit | ad958c0e268d876ee4d713510b8c3ef83b37bca0 (patch) | |
tree | 0defbe1196f778c9fb3f79f5f6e7a1da9ae92cda /Pooling.lua | |
parent | d290c4cb9d632120d3fba97caefb3afb961081bf (diff) |
everything works with R2. all unit tests pass. Maxpooling has free zero-padding
Diffstat (limited to 'Pooling.lua')
-rw-r--r-- | Pooling.lua | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/Pooling.lua b/Pooling.lua index 0ab7cb4..87d56bf 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -2,12 +2,14 @@ local Pooling, parent = torch.class('cudnn._Pooling', 'nn.Module') local ffi = require 'ffi' local errcheck = cudnn.errcheck -function Pooling:__init(kW, kH, dW, dH) +function Pooling:__init(kW, kH, dW, dH, padW, padH) parent.__init(self) self.kW = kW self.kH = kH self.dW = dW or kW self.dH = dH or kW + self.padW = padW or 0 + self.padH = padH or 0 self.iSize = torch.LongStorage(4):fill(0) self.ceil_mode = false end @@ -26,8 +28,11 @@ function Pooling:resetPoolDescriptors() -- create pooling descriptor self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') errcheck('cudnnCreatePoolingDescriptor', self.poolDesc) - errcheck('cudnnSetPoolingDescriptor', self.poolDesc[0], self.mode, - self.kH, self.kW, self.dH, self.dW); + local ker = torch.IntTensor({self.kH, self.kW}) + local str = torch.IntTensor({self.dH, self.dW}) + local pad = torch.IntTensor({self.padH, self.padW}) + errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 2, + ker:data(), pad:data(), str:data()); local function destroyPoolDesc(d) errcheck('cudnnDestroyPoolingDescriptor', d[0]); end @@ -73,11 +78,17 @@ function Pooling:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function Pooling:updateOutput(input) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) - errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], + errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], + self.poolDesc[0], + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -92,9 +103,11 @@ function Pooling:updateGradInput(input, gradOutput) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) errcheck('cudnnPoolingBackward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), self.iDesc[0], input:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end |