Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-12-19 02:36:12 +0300
committerSoumith Chintala <soumith@gmail.com>2014-12-19 23:28:13 +0300
commitad958c0e268d876ee4d713510b8c3ef83b37bca0 (patch)
tree0defbe1196f778c9fb3f79f5f6e7a1da9ae92cda /Pooling.lua
parentd290c4cb9d632120d3fba97caefb3afb961081bf (diff)
everything works with R2. all unit tests pass. Maxpooling has free zero-padding
Diffstat (limited to 'Pooling.lua')
-rw-r--r--Pooling.lua21
1 files changed, 17 insertions, 4 deletions
diff --git a/Pooling.lua b/Pooling.lua
index 0ab7cb4..87d56bf 100644
--- a/Pooling.lua
+++ b/Pooling.lua
@@ -2,12 +2,14 @@ local Pooling, parent = torch.class('cudnn._Pooling', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
-function Pooling:__init(kW, kH, dW, dH)
+function Pooling:__init(kW, kH, dW, dH, padW, padH)
parent.__init(self)
self.kW = kW
self.kH = kH
self.dW = dW or kW
self.dH = dH or kW
+ self.padW = padW or 0
+ self.padH = padH or 0
self.iSize = torch.LongStorage(4):fill(0)
self.ceil_mode = false
end
@@ -26,8 +28,11 @@ function Pooling:resetPoolDescriptors()
-- create pooling descriptor
self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]')
errcheck('cudnnCreatePoolingDescriptor', self.poolDesc)
- errcheck('cudnnSetPoolingDescriptor', self.poolDesc[0], self.mode,
- self.kH, self.kW, self.dH, self.dW);
+ local ker = torch.IntTensor({self.kH, self.kW})
+ local str = torch.IntTensor({self.dH, self.dW})
+ local pad = torch.IntTensor({self.padH, self.padW})
+ errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 2,
+ ker:data(), pad:data(), str:data());
local function destroyPoolDesc(d)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
end
@@ -73,11 +78,17 @@ function Pooling:createIODescriptors(input)
end
end
+local one = torch.FloatTensor({1});
+local zero = torch.FloatTensor({0});
+
function Pooling:updateOutput(input)
if not self.poolDesc then self:resetPoolDescriptors() end
self:createIODescriptors(input)
- errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0],
+ errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1],
+ self.poolDesc[0],
+ one:data(),
self.iDesc[0], input:data(),
+ zero:data(),
self.oDesc[0], self.output:data());
return self.output
end
@@ -92,9 +103,11 @@ function Pooling:updateGradInput(input, gradOutput)
if not self.poolDesc then self:resetPoolDescriptors() end
self:createIODescriptors(input)
errcheck('cudnnPoolingBackward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0],
+ one:data(),
self.oDesc[0], self.output:data(),
self.oDesc[0], gradOutput:data(),
self.iDesc[0], input:data(),
+ zero:data(),
self.iDesc[0], self.gradInput:data());
return self.gradInput
end