diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-20 09:49:10 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-20 09:49:10 +0300 |
commit | a2def9658f9df262a252ff71e9dfd310ab722b13 (patch) | |
tree | 03710c67942f063162d5486270a7d013c352bd3c | |
parent | 89c1810899912b6b5acbeec738eb3d4a48437147 (diff) | |
parent | 591cebe40db126788b8ab83c98f80ac0c87f401d (diff) |
Merge pull request #11 from soumith/volcol
volumetric convolutions + unit tests
-rw-r--r-- | README.md | 9 | ||||
-rw-r--r-- | VolumetricConvolution.lua | 166 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 84 |
4 files changed, 257 insertions, 3 deletions
@@ -18,9 +18,12 @@ Modules are API compatible their [`nn`](https://github.com/torch/nn) equivalents cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH) cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH) -cudnn.ReLU() -cudnn.Tanh() -cudnn.Sigmoid() +cudnn.ReLU() +cudnn.Tanh() +cudnn.Sigmoid() + +-- Volumetric inputs (4D or 5D batched mode) +cudnn.VolumetricConvolution(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH) -- SoftMax can be run in fast mode or accurate mode. Default is accurate mode. cudnn.SoftMax(fastMode [= false]) -- SoftMax across each image (just like nn.SoftMax) diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua new file mode 100644 index 0000000..05a08d4 --- /dev/null +++ b/VolumetricConvolution.lua @@ -0,0 +1,166 @@ +local VolumetricConvolution, parent = torch.class('cudnn.VolumetricConvolution', 'nn.VolumetricConvolution') +local ffi = require 'ffi' +local errcheck = cudnn.errcheck + +function VolumetricConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH) + parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH) + self.padT = padT or 0 + self.padW = padW or 0 + self.padH = padH or 0 + self:reset() + self.iSize = torch.LongStorage(5):fill(0) +end + +-- if you change the configuration of the module manually, call this +function VolumetricConvolution:resetWeightDescriptors() + assert(torch.typename(self.weight) == 'torch.CudaTensor', 'Only Cuda supported duh!') + assert(torch.typename(self.bias) == 'torch.CudaTensor', 'Only Cuda supported duh!') + -- create filterDescriptor for weight + self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]') + errcheck('cudnnCreateFilterDescriptor', self.weightDesc) + local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane, self.kT, self.kH, self.kW}) + errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], + 'CUDNN_DATA_FLOAT', 5, + desc:data()); + local function destroyWDesc(d) + errcheck('cudnnDestroyFilterDescriptor', d[0]); + end + ffi.gc(self.weightDesc, destroyWDesc) + + -- create descriptor for bias + self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane, 1, 1)) +end + +function VolumetricConvolution:createIODescriptors(input) + local batch = true + if input:dim() == 4 then + input = input:view(1, input:size(1), input:size(2), input:size(3), input:size(4)) + batch = false + end + assert(input:dim() == 5 and input:isContiguous()); + if not self.iDesc or not self.oDesc or + input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] + or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] + or input:size(5) ~= self.iSize[5] then + self.iSize = input:size() + -- resize gradInput + if self.gradInput then self.gradInput:resizeAs(input); end + -- create input descriptor + self.iDesc = cudnn.toDescriptor(input) + -- create conv descriptor + self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') + errcheck('cudnnCreateConvolutionDescriptor', self.convDesc) + local pad = torch.IntTensor({self.padT, self.padH, self.padW}) + local stride = torch.IntTensor({self.dT, self.dH, self.dW}) + local upscale = torch.IntTensor({1,1,1}) + errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 3, pad:data(), + stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION'); + local function destroyConvDesc(d) + errcheck('cudnnDestroyConvolutionDescriptor', d[0]); + end + ffi.gc(self.convDesc, destroyConvDesc) + + -- create output descriptor and resize output + local oSize = torch.IntTensor(5) + local oSizeD = oSize:data() + errcheck('cudnnGetConvolutionNdForwardOutputDim', self.convDesc[0], self.iDesc[0], + self.weightDesc[0], 5, oSizeD) + self.output:resize(oSize:long():storage()) + -- create descriptor for output + self.oDesc = cudnn.toDescriptor(self.output) + self.oDescBias = cudnn.toDescriptor(self.output:view(self.output:size(1), + self.output:size(2), + self.output:size(3) + *self.output:size(4), + self.output:size(5))) + + -- create forwardAlgorithm descriptors for + local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) + errcheck('cudnnGetConvolutionForwardAlgorithm', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType) + self.algType = algType + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionForwardWorkspaceSize', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + algType[0], bufSize:data()) + self.extraBuffer = self.extraBuffer or input.new(1) + if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end + + if not batch then + self.gradInput = self.gradInput:view(self.gradInput:size(2), + self.gradInput:size(3), + self.gradInput:size(4), + self.gradInput:size(5)) + self.output = self.output:view(self.output:size(2), + self.output:size(3), + self.output:size(4), + self.output:size(5)) + end + end +end + +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + +function VolumetricConvolution:updateOutput(input) + if not self.weightDesc then self:resetWeightDescriptors() end + self:createIODescriptors(input) + errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], + one:data(), + self.iDesc[0], input:data(), + self.weightDesc[0], self.weight:data(), + self.convDesc[0], self.algType[0], + self.extraBuffer:data(), self.extraBuffer:nElement(), + zero:data(), + self.oDesc[0], self.output:data()); + errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', + one:data(), self.biasDesc[0], self.bias:data(), one:data(), + self.oDescBias[0], self.output:data()); + return self.output +end + +function VolumetricConvolution:updateGradInput(input, gradOutput) + if not self.gradInput then return end + assert((gradOutput:dim() == 4 or gradOutput:dim() == 5) + and gradOutput:isContiguous()); + if not self.weightDesc then self:resetWeightDescriptors() end + self:createIODescriptors(input) + errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], + one:data(), + self.weightDesc[0], self.weight:data(), + self.oDesc[0], gradOutput:data(), + self.convDesc[0], + zero:data(), + self.iDesc[0], self.gradInput:data()); + return self.gradInput +end + +function VolumetricConvolution:accGradParameters(input, gradOutput, scale) + self.scaleT = self.scaleT or torch.FloatTensor(1):fill(1.0) + self.scaleT = self.scaleT:float() -- this line forces this member to always be on CPU (needed for cudnn) + + scale = scale or 1.0 + self.scaleT[1] = scale + assert((gradOutput:dim() == 4 or gradOutput:dim() == 5) + and gradOutput:isContiguous()); + self:createIODescriptors(input) + if not self.weightDesc then self:resetWeightDescriptors() end + -- gradBias + errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], + self.scaleT:data(), + self.oDescBias[0], gradOutput:data(), + one:data(), + self.biasDesc[0], self.gradBias:data()); + -- gradWeight + errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1], + self.scaleT:data(), + self.iDesc[0], input:data(), + self.oDesc[0], gradOutput:data(), + self.convDesc[0], + one:data(), + self.weightDesc[0], self.gradWeight:data()); + +end @@ -53,6 +53,7 @@ function cudnn.toDescriptor(t) end include 'SpatialConvolution.lua' +include 'VolumetricConvolution.lua' include 'Pooling.lua' include 'SpatialMaxPooling.lua' include 'SpatialAveragePooling.lua' diff --git a/test/test.lua b/test/test.lua index 435ba5f..9931431 100644 --- a/test/test.lua +++ b/test/test.lua @@ -163,6 +163,90 @@ function cudnntest.SpatialConvolution_backward_single() 'error on bias (backward) ') end +function cudnntest.VolumetricConvolution_forward_single() + local from = math.random(1,16) + local to = math.random(1,16) + local ki = math.random(3,5) + local kj = math.random(3,5) + local kk = math.random(3,5) + local si = math.random(1,ki-1) + local sj = math.random(1,kj-1) + local sk = math.random(1,kk-1) + local outi = math.random(1,17) + local outj = math.random(1,17) + local outk = math.random(1,5) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + local ink = (outk-1)*sk+kk + local input = torch.randn(from,ink,inj,ini):cuda() + local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() --:cuda() + local groundtruth = sconv:forward(input:float()) + cutorch.synchronize() + local gconv = cudnn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):cuda() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + local rescuda = gconv:forward(input) + cutorch.synchronize() + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') +end + +function cudnntest.VolumetricConvolution_backward_single() + local from = math.random(1,16) + local to = math.random(1,16) + local ki = math.random(3,5) + local kj = math.random(3,5) + local kk = math.random(3,5) + local si = math.random(1,ki-1) + local sj = math.random(1,kj-1) + local sk = math.random(1,kk-1) + local outi = math.random(1,17) + local outj = math.random(1,17) + local outk = math.random(1,5) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + local ink = (outk-1)*sk+kk + local input = torch.randn(from,ink,inj,ini):cuda() + local gradOutput = torch.randn(to,outk,outj,outi):cuda() + local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() --:cuda() + local groundtruth = sconv:forward(input:float()) + sconv:zeroGradParameters() + local groundgrad = sconv:backward(input:float(), gradOutput:float()) + cutorch.synchronize() + local groundweight = sconv.gradWeight + local groundbias = sconv.gradBias + + local gconv = cudnn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):cuda() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + gconv:forward(input) + cutorch.synchronize() + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + gconv:forward(input) + gconv:zeroGradParameters() + local rescuda = gconv:backward(input, gradOutput) + cutorch.synchronize() + + mytester:asserteq(rescuda:dim(), 4, 'error in dimension') + local weightcuda = gconv.gradWeight + local biascuda = gconv.gradBias + + local error = rescuda:float() - groundgrad:float() + local werror = weightcuda:float() - groundweight:float() + local berror = biascuda:float() - groundbias:float() + + mytester:assertlt(error:abs():max(), precision_backward, + 'error on state (backward) ') + mytester:assertlt(werror:abs():max(), precision_backward, + 'error on weight (backward) ') + mytester:assertlt(berror:abs():max(), precision_backward, + 'error on bias (backward) ') + +end function cudnntest.SpatialMaxPooling_batch() local bs = math.random(1,32) |