github.com/soumith/cudnn.torch.git

author     Soumith Chintala <soumith@gmail.com>   2014-09-11 01:03:02 +0400
committer  Soumith Chintala <soumith@gmail.com>   2014-09-11 01:03:02 +0400
commit     a8bf53f8738fd319568593718280343c3ebc93e6 (patch)
tree       58f632560c26f31c9d87e81d6f14c546c3020bcd /SpatialConvolution.lua

    first commit

Diffstat (limited to 'SpatialConvolution.lua')
 -rw-r--r--  SpatialConvolution.lua  114
 1 file changed, 114 insertions, 0 deletions

diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
new file mode 100644
index 0000000..c28fc57
--- /dev/null
+++ b/SpatialConvolution.lua
@@ -0,0 +1,114 @@
+local SpatialConvolution, parent = torch.class('cudnn.SpatialConvolution', 'nn.SpatialConvolution')
+local ffi = require 'ffi'
+local C = cudnn.C
+local errcheck = cudnn.errcheck
+
+function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH)
+   self.padW = padW or 0
+   self.padH = padH or 0
+   self:reset()
+   self:cuda()
+   self.iSize = torch.LongStorage(4):fill(0)
+   self:resetWeightDescriptors()
+   self.alpha = torch.FloatTensor({1}); -- constant scale factor reused by cudnnAddTensor4d in updateOutput
+end
+
+-- if you change the configuration of the module manually, call this
+function SpatialConvolution:resetWeightDescriptors()
+   -- create filterDescriptor for weight
+   self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+   errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
+   errcheck('cudnnSetFilterDescriptor', self.weightDesc[0], 'CUDNN_DATA_FLOAT',
+            self.nOutputPlane, self.nInputPlane, self.kH, self.kW);
+   local function destroyWDesc(d)
+      errcheck('cudnnDestroyFilterDescriptor', d[0]); -- destroy the descriptor being collected, not self.weightDesc
+   end
+   ffi.gc(self.weightDesc, destroyWDesc)
+
+   -- create descriptor for bias
+   self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane, 1, 1))
+end
+
+function SpatialConvolution:createIODescriptors(input)
+   -- (re)create descriptors and resize buffers only when the input shape changes
+   if input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+      self.iSize = input:size()
+      -- resize gradInput
+      self.gradInput:resizeAs(input)
+      -- create input descriptor
+      self.iDesc = cudnn.toDescriptor(input)
+      -- create conv descriptor
+      self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+      errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
+      errcheck('cudnnSetConvolutionDescriptor', self.convDesc[0], self.iDesc[0],
+               self.weightDesc[0], self.padH, self.padW,
+               self.dH, self.dW, 1, 1, 'CUDNN_CROSS_CORRELATION');
+      local function destroyConvDesc(d)
+         errcheck('cudnnDestroyConvolutionDescriptor', d[0]); -- destroy the descriptor being collected
+      end
+      ffi.gc(self.convDesc, destroyConvDesc)
+
+      -- create output descriptor and resize output
+      local oSize = torch.IntTensor(4):fill(0)
+      local oSizeD = oSize:data()
+      errcheck('cudnnGetOutputTensor4dDim', self.convDesc[0], 'CUDNN_CONVOLUTION_FWD',
+               oSizeD, oSizeD+1, oSizeD+2, oSizeD+3)
+      self.output:resize(oSize:long():storage())
+      -- create descriptor for output
+      self.oDesc = cudnn.toDescriptor(self.output)
+   end
+end
+
+function SpatialConvolution:updateOutput(input)
+   assert(input:dim() == 4 and input:isContiguous());
+   self:createIODescriptors(input)
+   errcheck('cudnnConvolutionForward', cudnn.handle[0],
+            self.iDesc[0], input:data(),
+            self.weightDesc[0], self.weight:data(),
+            self.convDesc[0], self.oDesc[0], self.output:data(),
+            'CUDNN_RESULT_NO_ACCUMULATE');
+   errcheck('cudnnAddTensor4d', cudnn.handle[0], 'CUDNN_ADD_SAME_C',
+            self.alpha:data(), self.biasDesc[0], self.bias:data(),
+            self.oDesc[0], self.output:data());
+   return self.output
+end
+
+function SpatialConvolution:updateGradInput(input, gradOutput)
+   assert(input:dim() == 4 and input:isContiguous());
+   assert(gradOutput:dim() == 4 and gradOutput:isContiguous());
+   errcheck('cudnnConvolutionBackwardData', cudnn.handle[0],
+            self.weightDesc[0], self.weight:data(),
+            self.oDesc[0], gradOutput:data(),
+            self.convDesc[0],
+            self.iDesc[0], self.gradInput:data(),
+            'CUDNN_RESULT_NO_ACCUMULATE');
+   return self.gradInput
+end
+
+function SpatialConvolution:accGradParameters(input, gradOutput, scale)
+   assert(scale == nil or scale == 1)
+   assert(input:dim() == 4 and input:isContiguous());
+   assert(gradOutput:dim() == 4 and gradOutput:isContiguous());
+   -- gradBias
+   errcheck('cudnnConvolutionBackwardBias', cudnn.handle[0],
+            self.oDesc[0], gradOutput:data(),
+            self.biasDesc[0], self.gradBias:data(),
+            'CUDNN_RESULT_ACCUMULATE');
+   -- gradWeight
+   errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[0],
+            self.iDesc[0], input:data(),
+            self.oDesc[0], gradOutput:data(),
+            self.convDesc[0],
+            self.weightDesc[0], self.gradWeight:data(),
+            'CUDNN_RESULT_ACCUMULATE');
+end
+--[[
+function SpatialConvolution:zeroGradParameters()
+   -- gradWeight, gradBias to zero
+   local alpha = torch.FloatTensor({0});
+   errcheck('cudnnSetTensor4d', self.weightDesc, self.gradWeight:data(), alpha:data());
+   errcheck('cudnnSetTensor4d', self.biasDesc, self.gradBias:data(), alpha:data());
+end
+]]--
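
For context, a minimal usage sketch (not part of this commit): it assumes `require 'cudnn'` loads this package's init, which is expected to set up the cudnn.C, cudnn.handle and cudnn.toDescriptor used above, and that cutorch/nn are installed. The layer sizes and tensor shapes below are illustrative only.

   require 'cutorch'
   require 'nn'
   require 'cudnn'   -- this package

   -- 3 input planes -> 16 output planes, 5x5 kernel, stride 1x1, 1-pixel zero padding
   local conv = cudnn.SpatialConvolution(3, 16, 5, 5, 1, 1, 1, 1)

   -- the layer expects contiguous 4D CUDA tensors (batch x plane x height x width)
   local input  = torch.randn(8, 3, 32, 32):cuda()
   local output = conv:forward(input)            -- 8 x 16 x 30 x 30

   -- accGradParameters accumulates (CUDNN_RESULT_ACCUMULATE), so zero the gradients
   -- first; backward() calls updateGradInput and accGradParameters defined above
   conv:zeroGradParameters()
   local gradInput = conv:backward(input, output:clone():fill(1))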