
github.com/soumith/cudnn.torch.git
Diffstat (limited to 'VolumetricConvolution.lua')
-rw-r--r--  VolumetricConvolution.lua  166
1 file changed, 166 insertions, 0 deletions
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
new file mode 100644
index 0000000..05a08d4
--- /dev/null
+++ b/VolumetricConvolution.lua
@@ -0,0 +1,166 @@
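+-- cudnn.VolumetricConvolution wraps nn.VolumetricConvolution and dispatches
+-- the actual computation to the cuDNN library via LuaJIT FFI. It expects
+-- torch.CudaTensor inputs of shape C x T x H x W (single sample) or
+-- N x C x T x H x W (batch).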
+local VolumetricConvolution, parent = torch.class('cudnn.VolumetricConvolution', 'nn.VolumetricConvolution')
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
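+-- kT/kW/kH: kernel size, dT/dW/dH: stride, padT/padW/padH: zero-padding
+-- (each padding value defaults to 0 when omitted)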
+function VolumetricConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH)
+ parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH)
+ self.padT = padT or 0
+ self.padW = padW or 0
+ self.padH = padH or 0
+ self:reset()
+ self.iSize = torch.LongStorage(5):fill(0)
+end
+
+-- if you change the configuration of the module manually, call this
+function VolumetricConvolution:resetWeightDescriptors()
+   assert(torch.typename(self.weight) == 'torch.CudaTensor', 'weight must be a torch.CudaTensor')
+   assert(torch.typename(self.bias) == 'torch.CudaTensor', 'bias must be a torch.CudaTensor')
+ -- create filterDescriptor for weight
+ self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+ errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
+ local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane, self.kT, self.kH, self.kW})
+ errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
+ 'CUDNN_DATA_FLOAT', 5,
+ desc:data());
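+   -- tie the descriptor's lifetime to the FFI object: destroy it on GC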
+ local function destroyWDesc(d)
+ errcheck('cudnnDestroyFilterDescriptor', d[0]);
+ end
+ ffi.gc(self.weightDesc, destroyWDesc)
+
+ -- create descriptor for bias
+ self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane, 1, 1))
+end
+
+function VolumetricConvolution:createIODescriptors(input)
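+   -- treat a 4D input as a batch of one; the extra dimension is removed again below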
+ local batch = true
+ if input:dim() == 4 then
+ input = input:view(1, input:size(1), input:size(2), input:size(3), input:size(4))
+ batch = false
+ end
+ assert(input:dim() == 5 and input:isContiguous());
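+   -- (re)build descriptors only when the input geometry changed since the last call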
+ if not self.iDesc or not self.oDesc or
+ input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+ or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
+ or input:size(5) ~= self.iSize[5] then
+ self.iSize = input:size()
+ -- resize gradInput
+ if self.gradInput then self.gradInput:resizeAs(input); end
+ -- create input descriptor
+ self.iDesc = cudnn.toDescriptor(input)
+ -- create conv descriptor
+ self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+ errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
+ local pad = torch.IntTensor({self.padT, self.padH, self.padW})
+ local stride = torch.IntTensor({self.dT, self.dH, self.dW})
+ local upscale = torch.IntTensor({1,1,1})
+ errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 3, pad:data(),
+ stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION');
+ local function destroyConvDesc(d)
+ errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
+ end
+ ffi.gc(self.convDesc, destroyConvDesc)
+
+ -- create output descriptor and resize output
+ local oSize = torch.IntTensor(5)
+ local oSizeD = oSize:data()
+ errcheck('cudnnGetConvolutionNdForwardOutputDim', self.convDesc[0], self.iDesc[0],
+ self.weightDesc[0], 5, oSizeD)
+ self.output:resize(oSize:long():storage())
+ -- create descriptor for output
+ self.oDesc = cudnn.toDescriptor(self.output)
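+      -- the bias is broadcast through a 4D descriptor (N x C x (D*H) x W),
+      -- since cudnn.toDescriptor builds descriptors for 4D views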
+ self.oDescBias = cudnn.toDescriptor(self.output:view(self.output:size(1),
+ self.output:size(2),
+ self.output:size(3)
+ *self.output:size(4),
+ self.output:size(5)))
+
+      -- choose the fastest forward algorithm and query its workspace needs
+ local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+ errcheck('cudnnGetConvolutionForwardAlgorithm',
+ cudnn.handle[cutorch.getDevice()-1],
+ self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0],
+ 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType)
+ self.algType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionForwardWorkspaceSize',
+ cudnn.handle[cutorch.getDevice()-1],
+ self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0],
+ algType[0], bufSize:data())
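+      -- the returned workspace size is in bytes; resizing a float tensor to
+      -- that many elements over-allocates, which is safe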
+ self.extraBuffer = self.extraBuffer or input.new(1)
+ if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end
+
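+      -- undo the fake batch dimension added at the top of this function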
+ if not batch then
+ self.gradInput = self.gradInput:view(self.gradInput:size(2),
+ self.gradInput:size(3),
+ self.gradInput:size(4),
+ self.gradInput:size(5))
+ self.output = self.output:view(self.output:size(2),
+ self.output:size(3),
+ self.output:size(4),
+ self.output:size(5))
+ end
+ end
+end
+
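+-- alpha/beta scaling factors shared by the cuDNN calls below; cuDNN reads
+-- them from host memory, hence plain FloatTensors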
+local one = torch.FloatTensor({1});
+local zero = torch.FloatTensor({0});
+
+function VolumetricConvolution:updateOutput(input)
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ self:createIODescriptors(input)
+ errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1],
+ one:data(),
+ self.iDesc[0], input:data(),
+ self.weightDesc[0], self.weight:data(),
+ self.convDesc[0], self.algType[0],
+ self.extraBuffer:data(), self.extraBuffer:nElement(),
+ zero:data(),
+ self.oDesc[0], self.output:data());
+ errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C',
+ one:data(), self.biasDesc[0], self.bias:data(), one:data(),
+ self.oDescBias[0], self.output:data());
+ return self.output
+end
+
+function VolumetricConvolution:updateGradInput(input, gradOutput)
+ if not self.gradInput then return end
+ assert((gradOutput:dim() == 4 or gradOutput:dim() == 5)
+ and gradOutput:isContiguous());
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ self:createIODescriptors(input)
+ errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1],
+ one:data(),
+ self.weightDesc[0], self.weight:data(),
+ self.oDesc[0], gradOutput:data(),
+ self.convDesc[0],
+ zero:data(),
+ self.iDesc[0], self.gradInput:data());
+ return self.gradInput
+end
+
+function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
+ self.scaleT = self.scaleT or torch.FloatTensor(1):fill(1.0)
+   self.scaleT = self.scaleT:float() -- cuDNN reads the scale factor from host memory, so keep this tensor on the CPU
+
+ scale = scale or 1.0
+ self.scaleT[1] = scale
+ assert((gradOutput:dim() == 4 or gradOutput:dim() == 5)
+ and gradOutput:isContiguous());
+ self:createIODescriptors(input)
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ -- gradBias
+ errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1],
+ self.scaleT:data(),
+ self.oDescBias[0], gradOutput:data(),
+ one:data(),
+ self.biasDesc[0], self.gradBias:data());
+ -- gradWeight
+ errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1],
+ self.scaleT:data(),
+ self.iDesc[0], input:data(),
+ self.oDesc[0], gradOutput:data(),
+ self.convDesc[0],
+ one:data(),
+ self.weightDesc[0], self.gradWeight:data());
+
+end
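+
+-- Minimal usage sketch (illustrative only; the sizes below are made up
+-- for the example):
+--   local m = cudnn.VolumetricConvolution(3, 16, 3,3,3, 1,1,1, 1,1,1):cuda()
+--   local x = torch.CudaTensor(8, 3, 16, 32, 32)   -- N x C x T x H x W
+--   local y = m:forward(x)
+--   local gx = m:backward(x, y:clone():fill(1))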