diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-11 01:03:02 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-11 01:03:02 +0400 |
commit | a8bf53f8738fd319568593718280343c3ebc93e6 (patch) | |
tree | 58f632560c26f31c9d87e81d6f14c546c3020bcd /SpatialConvolution.lua |
first commit
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- | SpatialConvolution.lua | 114 |
1 file changed, 114 insertions, 0 deletions
-- cuDNN-accelerated SpatialConvolution for Torch.
-- Wraps cuDNN's convolution forward/backward kernels behind the standard
-- nn.SpatialConvolution interface. Inputs must be contiguous 4D NCHW
-- CUDA tensors (batch x plane x height x width).

local SpatialConvolution, parent = torch.class('cudnn.SpatialConvolution', 'nn.SpatialConvolution')
local ffi = require 'ffi'
local C = cudnn.C
local errcheck = cudnn.errcheck

-- Constructor. Same signature as nn.SpatialConvolution plus optional
-- zero-padding amounts padW/padH (both default to 0).
function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH)
   self.padW = padW or 0
   self.padH = padH or 0
   self:reset()
   self:cuda()
   -- cached input geometry; all zeros forces descriptor creation on the
   -- first forward pass (see createIODescriptors)
   self.iSize = torch.LongStorage(4):fill(0)
   self:resetWeightDescriptors()
   -- reusable scale factor (constant 1) for cudnnAddTensor4d, so
   -- updateOutput does not allocate a fresh tensor on every call
   self.alpha = torch.FloatTensor({1})
end

-- if you change the configuration of the module manually, call this
function SpatialConvolution:resetWeightDescriptors()
   -- create filterDescriptor for weight
   self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
   errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
   errcheck('cudnnSetFilterDescriptor', self.weightDesc[0], 'CUDNN_DATA_FLOAT',
            self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
   -- Destroy the descriptor array that is actually being collected (d),
   -- NOT whatever self.weightDesc currently points to: the original closed
   -- over self, so if this function were called again the finalizer of the
   -- old array would destroy the NEW descriptor.
   local function destroyWDesc(d)
      errcheck('cudnnDestroyFilterDescriptor', d[0])
   end
   ffi.gc(self.weightDesc, destroyWDesc)

   -- create descriptor for bias (viewed as a 1 x nOutputPlane x 1 x 1 tensor)
   self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane, 1, 1))
end

-- (Re)create the cuDNN input/convolution/output descriptors whenever the
-- input geometry changes, resizing self.gradInput and self.output to match.
-- No-op when the input has the same shape as the cached self.iSize.
function SpatialConvolution:createIODescriptors(input)
   if input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
      -- Remember the new geometry so descriptors are only rebuilt on change.
      -- (The original compared against self.iSize:size(i) -- LongStorage:size()
      -- takes no dimension argument -- and never updated the cache, so the
      -- descriptors were recreated on every single forward pass.)
      self.iSize = input:size()
      -- resize gradInput
      self.gradInput:resizeAs(input)
      -- create input descriptor
      self.iDesc = cudnn.toDescriptor(input)
      -- create conv descriptor
      self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
      errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
      errcheck('cudnnSetConvolutionDescriptor', self.convDesc[0], self.iDesc[0],
               self.weightDesc[0], self.padH, self.padW,
               self.dH, self.dW, 1, 1, 'CUDNN_CROSS_CORRELATION')
      -- destroy the collected descriptor itself (see resetWeightDescriptors)
      local function destroyConvDesc(d)
         errcheck('cudnnDestroyConvolutionDescriptor', d[0])
      end
      ffi.gc(self.convDesc, destroyConvDesc)

      -- query the output shape from cuDNN and resize output accordingly
      local oSize = torch.IntTensor(4):fill(0)
      local oSizeD = oSize:data()
      errcheck('cudnnGetOutputTensor4dDim', self.convDesc[0], 'CUDNN_CONVOLUTION_FWD',
               oSizeD, oSizeD+1, oSizeD+2, oSizeD+3)
      self.output:resize(oSize:long():storage())
      -- create descriptor for output
      self.oDesc = cudnn.toDescriptor(self.output)
   end
end

-- Forward pass: self.output = conv(input, weight) + bias.
function SpatialConvolution:updateOutput(input)
   assert(input:dim() == 4 and input:isContiguous(),
          'input has to be a contiguous 4D tensor')
   self:createIODescriptors(input)
   errcheck('cudnnConvolutionForward', cudnn.handle[0],
            self.iDesc[0], input:data(),
            self.weightDesc[0], self.weight:data(),
            self.convDesc[0], self.oDesc[0], self.output:data(),
            'CUDNN_RESULT_NO_ACCUMULATE')
   -- add bias, scaled by the cached constant 1 (guard keeps instances
   -- deserialized from before the field existed working)
   self.alpha = self.alpha or torch.FloatTensor({1})
   errcheck('cudnnAddTensor4d', cudnn.handle[0], 'CUDNN_ADD_SAME_C',
            self.alpha:data(), self.biasDesc[0], self.bias:data(),
            self.oDesc[0], self.output:data())
   return self.output
end

-- Backward pass w.r.t. the input. Assumes updateOutput was already called
-- with the same input geometry (the descriptors are reused, not rebuilt).
function SpatialConvolution:updateGradInput(input, gradOutput)
   assert(input:dim() == 4 and input:isContiguous())
   assert(gradOutput:dim() == 4 and gradOutput:isContiguous())
   errcheck('cudnnConvolutionBackwardData', cudnn.handle[0],
            self.weightDesc[0], self.weight:data(),
            self.oDesc[0], gradOutput:data(),
            self.convDesc[0],
            self.iDesc[0], self.gradInput:data(),
            'CUDNN_RESULT_NO_ACCUMULATE')
   return self.gradInput
end

-- Accumulate gradients w.r.t. weight and bias into gradWeight/gradBias.
-- scale must be nil or 1: cuDNN's ACCUMULATE mode cannot apply an
-- arbitrary scale factor here.
function SpatialConvolution:accGradParameters(input, gradOutput, scale)
   assert(scale == nil or scale == 1)
   assert(input:dim() == 4 and input:isContiguous())
   assert(gradOutput:dim() == 4 and gradOutput:isContiguous())
   -- gradBias
   errcheck('cudnnConvolutionBackwardBias', cudnn.handle[0],
            self.oDesc[0], gradOutput:data(),
            self.biasDesc[0], self.gradBias:data(),
            'CUDNN_RESULT_ACCUMULATE')
   -- gradWeight
   errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[0],
            self.iDesc[0], input:data(),
            self.oDesc[0], gradOutput:data(),
            self.convDesc[0],
            self.weightDesc[0], self.gradWeight:data(),
            'CUDNN_RESULT_ACCUMULATE')
end

--[[
function SpatialConvolution:zeroGradParameters()
   -- gradWeight, gradBias to zero
   local alpha = torch.FloatTensor({0});
   errcheck('cudnnSetTensor4d', self.weightDesc, self.gradWeight:data(), alpha:data());
   errcheck('cudnnSetTensor4d', self.biasDesc, self.gradBias:data(), alpha:data());
end
]]--