diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-19 02:36:12 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-19 23:28:13 +0300 |
commit | ad958c0e268d876ee4d713510b8c3ef83b37bca0 (patch) | |
tree | 0defbe1196f778c9fb3f79f5f6e7a1da9ae92cda /SpatialConvolution.lua | |
parent | d290c4cb9d632120d3fba97caefb3afb961081bf (diff) |
everything works with R2. all unit tests pass. Maxpooling has free zero-padding
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- | SpatialConvolution.lua | 78 |
1 files changed, 55 insertions, 23 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index e939592..794653c 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -17,8 +17,10 @@ function SpatialConvolution:resetWeightDescriptors() -- create filterDescriptor for weight self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]') errcheck('cudnnCreateFilterDescriptor', self.weightDesc) - errcheck('cudnnSetFilterDescriptor', self.weightDesc[0], 'CUDNN_DATA_FLOAT', - self.nOutputPlane, self.nInputPlane, self.kH, self.kW); + local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane, self.kH, self.kW}) + errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], + 'CUDNN_DATA_FLOAT', 4, + desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); end @@ -46,22 +48,40 @@ function SpatialConvolution:createIODescriptors(input) -- create conv descriptor self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') errcheck('cudnnCreateConvolutionDescriptor', self.convDesc) - errcheck('cudnnSetConvolutionDescriptor', self.convDesc[0], self.iDesc[0], - self.weightDesc[0], self.padH, self.padW, - self.dH, self.dW, 1, 1, 'CUDNN_CROSS_CORRELATION'); + local pad = torch.IntTensor({self.padH, self.padW}) + local stride = torch.IntTensor({self.dH, self.dW}) + local upscale = torch.IntTensor({1,1}) + errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 2, pad:data(), + stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION'); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end ffi.gc(self.convDesc, destroyConvDesc) -- create output descriptor and resize output - local oSize = torch.IntTensor(4):fill(0) + local oSize = torch.IntTensor(4) local oSizeD = oSize:data() - errcheck('cudnnGetOutputTensor4dDim', self.convDesc[0], 'CUDNN_CONVOLUTION_FWD', - oSizeD, oSizeD+1, oSizeD+2, oSizeD+3) + errcheck('cudnnGetConvolutionNdForwardOutputDim', self.convDesc[0], self.iDesc[0], + self.weightDesc[0], 4, oSizeD) self.output:resize(oSize:long():storage()) -- create descriptor for output self.oDesc = cudnn.toDescriptor(self.output) + + -- create forwardAlgorithm descriptors for + local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) + errcheck('cudnnGetConvolutionForwardAlgorithm', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType) + self.algType = algType + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionForwardWorkspaceSize', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + algType[0], bufSize:data()) + self.extraBuffer = self.extraBuffer or input.new(1) + if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end + if not batch then self.gradInput = self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3), @@ -73,17 +93,22 @@ function SpatialConvolution:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function SpatialConvolution:updateOutput(input) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], + one:data(), self.iDesc[0], input:data(), self.weightDesc[0], self.weight:data(), - self.convDesc[0], self.oDesc[0], self.output:data(), - 'CUDNN_RESULT_NO_ACCUMULATE'); - local alpha = torch.FloatTensor({1}); - errcheck('cudnnAddTensor4d', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', - alpha:data(), self.biasDesc[0], self.bias:data(), + self.convDesc[0], self.algType[0], + self.extraBuffer:data(), self.extraBuffer:nElement(), + zero:data(), + self.oDesc[0], self.output:data()); + errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', + one:data(), self.biasDesc[0], self.bias:data(), one:data(), self.oDesc[0], self.output:data()); return self.output end @@ -95,39 +120,46 @@ function SpatialConvolution:updateGradInput(input, gradOutput) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], + one:data(), self.weightDesc[0], self.weight:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], - self.iDesc[0], self.gradInput:data(), - 'CUDNN_RESULT_NO_ACCUMULATE'); + zero:data(), + self.iDesc[0], self.gradInput:data()); return self.gradInput end +local scaleT = torch.FloatTensor(1):fill(1.0) function SpatialConvolution:accGradParameters(input, gradOutput, scale) - assert(scale == nil or scale == 1) + scale = scale or 1.0 + scaleT[1] = scale assert((gradOutput:dim() == 3 or gradOutput:dim() == 4) and gradOutput:isContiguous()); self:createIODescriptors(input) if not self.weightDesc then self:resetWeightDescriptors() end -- gradBias errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], + scaleT:data(), self.oDesc[0], gradOutput:data(), - self.biasDesc[0], self.gradBias:data(), - 'CUDNN_RESULT_ACCUMULATE'); + one:data(), + self.biasDesc[0], self.gradBias:data()); -- gradWeight errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1], + scaleT:data(), self.iDesc[0], input:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], - self.weightDesc[0], self.gradWeight:data(), - 'CUDNN_RESULT_ACCUMULATE'); + one:data(), + self.weightDesc[0], self.gradWeight:data()); end + --[[ function SpatialConvolution:zeroGradParameters() -- gradWeight, gradBias to zero - local alpha = torch.FloatTensor({0}); - errcheck('cudnnSetTensor4d', self.weightDesc, self.gradWeight:data(), alpha:data()); - errcheck('cudnnSetTensor4d', self.biasDesc, self.gradBias:data(), alpha:data()); + errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + self.weightDesc, self.gradWeight:data(), zero:data()); + errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + self.biasDesc, self.gradBias:data(), zero:data()); end ]]-- |