github.com/soumith/cudnn.torch.git

author     soumith <soumith@fb.com>    2016-09-30 04:08:31 +0300
committer  soumith <soumith@fb.com>    2016-09-30 04:08:31 +0300
commit     1cfb74acdf32bd13fe27d0853e7ff5c9608afa6c (patch)
tree       c656a6c12b5cb1d791dd85cb7024b161d83c7da0
parent     01e765d8ef8cadb079fb5063918b1524061b3241 (diff)

adding VolumetricFullConvolution (volfullconv)

-rw-r--r--  VolumetricConvolution.lua      |   6
-rw-r--r--  VolumetricFullConvolution.lua  | 420
-rw-r--r--  convert.lua                    |   1
-rw-r--r--  init.lua                       |   1
-rw-r--r--  test/test.lua                  |  35
5 files changed, 460 insertions(+), 3 deletions(-)
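
This commit adds a cuDNN-backed cudnn.VolumetricFullConvolution, i.e. a 3D transposed ("full") convolution that subclasses nn.VolumetricFullConvolution and dispatches the work to cuDNN by swapping the library's forward and backward-data paths. Throughout the new module the output shape follows the usual transposed-convolution size rule; the helper below is a hypothetical stand-alone illustration of that rule, not part of the commit:

    -- illustration only: the per-dimension size rule used by transposed ("full") convolution
    local function fullConvOutputSize(i, k, d, pad, adj)
       -- the forward convolution maps o back to i via i = floor((o + 2*pad - k)/d) + 1,
       -- so the full convolution inverts it as below
       return (i - 1) * d - 2 * pad + k + adj
    end

    print(fullConvOutputSize(10, 3, 2, 0, 0))   -- prints 21
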
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 3f32c3d..3163041 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -33,7 +33,7 @@ end
function VolumetricConvolution:fastest(mode)
if mode == nil then mode = true end
self.fastest_mode = mode
- self.iSize = self.iSize or torch.LongStorage(4)
+ self.iSize = self.iSize or torch.LongStorage(5)
self.iSize:fill(0)
return self
end
@@ -48,7 +48,7 @@ function VolumetricConvolution:setMode(fmode, bdmode, bwmode)
if bwmode ~= nil then
self.bwmode = bwmode
end
- self.iSize = self.iSize or torch.LongStorage(4)
+ self.iSize = self.iSize or torch.LongStorage(5)
self.iSize:fill(0)
return self
end
@@ -68,7 +68,7 @@ function VolumetricConvolution:createIODescriptors(input)
batch = false
end
assert(input:dim() == 5 and input:isContiguous());
- self.iSize = self.iSize or torch.LongStorage(4):fill(0)
+ self.iSize = self.iSize or torch.LongStorage(5):fill(0)
if not self.iDesc or not self.oDesc or
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
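
The three one-line changes above fix the cached input-size bookkeeping in the existing VolumetricConvolution: volumetric inputs are 5D (batch x feature planes x depth x height x width), so the iSize cache needs five slots rather than the four inherited from the spatial code path. A minimal sketch of that bookkeeping, using a plain CPU tensor for illustration only:

    require 'torch'

    local iSize = torch.LongStorage(5):fill(0)       -- one slot per input dimension
    local input = torch.Tensor(2, 3, 8, 16, 16)      -- hypothetical 5D volumetric input
    local changed = false
    for d = 1, input:dim() do
       if input:size(d) ~= iSize[d] then changed = true end
    end
    -- with only four slots the fifth dimension could never be cached or compared
    if changed then iSize = input:size() end         -- descriptors would be rebuilt here
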
diff --git a/VolumetricFullConvolution.lua b/VolumetricFullConvolution.lua
new file mode 100644
index 0000000..3cc43a3
--- /dev/null
+++ b/VolumetricFullConvolution.lua
@@ -0,0 +1,420 @@
+local VolumetricFullConvolution, parent
+ = torch.class('cudnn.VolumetricFullConvolution', 'nn.VolumetricFullConvolution')
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+local autotunerCache = {}
+autotunerCache[1] = {} -- forward
+autotunerCache[2] = {} -- backwardFilter
+autotunerCache[3] = {} -- backwardData
+
+-- if you change the configuration of the module manually, call this
+function VolumetricFullConvolution:resetWeightDescriptors()
+ assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!')
+ assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!')
+ -- create filterDescriptor for weight
+ self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+ errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
+ local desc = torch.IntTensor({self.nInputPlane, self.nOutputPlane,
+ self.kT, self.kH, self.kW})
+ errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
+ cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', 5,
+ desc:data());
+ local function destroyWDesc(d)
+ errcheck('cudnnDestroyFilterDescriptor', d[0]);
+ end
+ ffi.gc(self.weightDesc, destroyWDesc)
+
+ -- create descriptor for bias
+ self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,
+ 1, 1))
+end
+
+function VolumetricFullConvolution:fastest(mode)
+ if mode == nil then mode = true end
+ self.fastest_mode = mode
+ self.iSize = self.iSize or torch.LongStorage(5)
+ self.iSize:fill(0)
+ return self
+end
+
+function VolumetricFullConvolution:setMode(fmode, bdmode, bwmode)
+ if fmode ~= nil then
+ self.fmode = fmode
+ end
+ if bdmode ~= nil then
+ self.bdmode = bdmode
+ end
+ if bwmode ~= nil then
+ self.bwmode = bwmode
+ end
+ self.iSize = self.iSize or torch.LongStorage(5)
+ self.iSize:fill(0)
+ return self
+end
+
+function VolumetricFullConvolution:resetMode()
+ self.fmode = nil
+ self.bdmode = nil
+ self.bwmode = nil
+ return self
+end
+
+function VolumetricFullConvolution:createIODescriptors(input)
+ local batch = true
+ if input:dim() == 4 then
+ input = input:view(1, input:size(1), input:size(2),
+ input:size(3), input:size(4))
+ batch = false
+ end
+ assert(input:dim() == 5 and input:isContiguous());
+ self.iSize = self.iSize or torch.LongStorage(5):fill(0)
+ if not self.iDesc or not self.oDesc or
+ input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+ or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
+ or input:size(5) ~= self.iSize[5] then
+ self.iSize = input:size()
+ -- create input descriptor
+ self.iDesc = cudnn.toDescriptor(input)
+ -- create conv descriptor
+ self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+ errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
+ local pad = torch.IntTensor({self.padT, self.padH, self.padW})
+ local stride = torch.IntTensor({self.dT, self.dH, self.dW})
+ local upscale = torch.IntTensor({1,1,1})
+ errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
+ 3, pad:data(),
+ stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+ cudnn.configmap(torch.type(self.weight)));
+ local function destroyConvDesc(d)
+ errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
+ end
+ ffi.gc(self.convDesc, destroyConvDesc)
+
+ -- get output shape, resize output
+ local iwidth = input:size(5)
+ local iheight = input:size(4)
+ local idepth = input:size(3)
+ local owidth = (iwidth - 1) * self.dW - 2*self.padW + self.kW + self.adjW
+ local oheight = (iheight - 1) * self.dH - 2*self.padH + self.kH + self.adjH
+ local odepth = (idepth - 1) * self.dT - 2*self.padT + self.kT + self.adjT
+ local oSize = torch.IntTensor({input:size(1), self.nOutputPlane, odepth, oheight, owidth})
+ self.output:resize(oSize:long():storage())
+
+ -- create descriptor for output
+ local output_slice = {{},{1,self.nOutputPlane},{},{}}
+ self.oDesc = cudnn.toDescriptor(self.output[output_slice])
+ self.oDescBias = cudnn.toDescriptor(
+ self.output:view(self.output:size(1),
+ self.output:size(2),
+ self.output:size(3)*self.output:size(4),
+ self.output:size(5)))
+
+ -----------------------------------------------------------------------
+ local function shape(x)
+ return table.concat(x:size():totable(),'x')
+ end
+ local autotunerHash = shape(self.weight) .. ';'
+ .. shape(input) .. ';'
+ .. shape(self.output)
+
+ local maxBufSize = 0
+
+ -- create forwardAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nOutputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST'
+ end
+
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[1][autotunerHash] then
+ algType[0] = autotunerCache[1][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC FW: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.oDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[1][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "\nAutotuning VMC Forward: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.oDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.fmode or algType[0]
+ self.fwdAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionForwardWorkspaceSize',
+ cudnn.getHandle(),
+ self.oDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ -- create backwardFilterAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
+ end
+
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[2][autotunerHash] then
+ algType[0] = autotunerCache[2][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC BWF: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionBwdFilterAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.oDesc[0], self.iDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[2][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "Autotuning backwardFilter: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.oDesc[0], self.iDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.bwmode or algType[0]
+ self.bwdFilterAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
+ cudnn.getHandle(),
+ self.oDesc[0], self.iDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ -- create backwardDataAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nOutputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
+ end
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[3][autotunerHash] then
+ algType[0] = autotunerCache[3][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC BWD: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionBwdDataAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.iDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[3][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "Autotuning backwardData: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s\n",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.iDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.bdmode or algType[0]
+ self.bwdDataAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.iDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace()
+ self.extraBuffer = self.extraBuffer:cuda() -- always force float
+ self.extraBufferSizeInBytes =
+ self.extraBuffer:nElement() * 4 -- extraBuffer is always float
+ if maxBufSize > self.extraBufferSizeInBytes then
+ self.extraBuffer:resize(math.ceil(maxBufSize / 4))
+ self.extraBufferSizeInBytes = maxBufSize
+ end
+ -----------------------------------------------------------------------
+
+ if not batch then
+ self.output = self.output:view(self.output:size(2),
+ self.output:size(3),
+ self.output:size(4),
+ self.output:size(5))
+ end
+ end
+end
+
+
+
+
+local function makeContiguous(self, input, gradOutput)
+ if not input:isContiguous() then
+ self._input = self._input or input.new()
+ self._input:typeAs(input):resizeAs(input):copy(input)
+ input = self._input
+ end
+ if gradOutput and not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
+ return input, gradOutput
+end
+
+function VolumetricFullConvolution:updateOutput(input)
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ self:createIODescriptors(input)
+
+ -- Because VolumetricFullConvolution is performing the adjoint of the forward
+ -- convolution operator, we need to swap the forward and backward passes.
+ errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
+ cudnn.scalar(input, 1),
+ self.weightDesc[0], self.weight:data(),
+ self.iDesc[0], input:data(),
+ self.convDesc[0], self.bwdDataAlgType[0],
+ self.extraBuffer:data(), self.extraBufferSizeInBytes,
+ cudnn.scalar(input, 0),
+ self.oDesc[0], self.output:data())
+
+ -- add bias
+ if self.bias then
+ errcheck('cudnnAddTensor', cudnn.getHandle(),
+ cudnn.scalar(input, 1), self.biasDesc[0], self.bias:data(),
+ cudnn.scalar(input, 1), self.oDescBias[0], self.output:data())
+ end
+
+ return self.output
+end
+
+function VolumetricFullConvolution:updateGradInput(input, gradOutput)
+ if not self.gradInput then return end
+ self.gradInput:resizeAs(input)
+
+ assert(gradOutput:dim() == 4 or gradOutput:dim() == 5, 'gradOutput has to be 4D or 5D');
+ assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ self:createIODescriptors(input)
+
+ errcheck('cudnnConvolutionForward', cudnn.getHandle(),
+ cudnn.scalar(input, 1),
+ self.oDesc[0], gradOutput:data(),
+ self.weightDesc[0], self.weight:data(),
+ self.convDesc[0],
+ self.fwdAlgType[0],
+ self.extraBuffer:data(), self.extraBufferSizeInBytes,
+ cudnn.scalar(input, 0),
+ self.iDesc[0], self.gradInput:data());
+ return self.gradInput
+end
+
+function VolumetricFullConvolution:accGradParameters(input, gradOutput, scale)
+ self.scaleT = self.scaleT or self.weight.new(1)
+ -- this line forces this member to always be on CPU (needed for cudnn)
+ self.scaleT = torch.type(self.weight) == 'torch.CudaDoubleTensor'
+ and self.scaleT:double() or self.scaleT:float()
+ scale = scale or 1.0
+ self.scaleT[1] = scale
+
+ input, gradOutput = makeContiguous(self, input, gradOutput)
+ assert(gradOutput:dim() == 4 or gradOutput:dim() == 5,
+ 'gradOutput has to be a 4D or 5D tensor');
+ self:createIODescriptors(input)
+ if not self.weightDesc then self:resetWeightDescriptors() end
+ -- gradBias
+ errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
+ self.scaleT:data(),
+ self.oDescBias[0], gradOutput:data(),
+ cudnn.scalar(input, 1),
+ self.biasDesc[0], self.gradBias:data());
+ -- gradWeight
+ errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
+ self.scaleT:data(),
+ self.oDesc[0], gradOutput:data(),
+ self.iDesc[0], input:data(),
+ self.convDesc[0],
+ self.bwdFilterAlgType[0],
+ self.extraBuffer:data(), self.extraBufferSizeInBytes,
+ cudnn.scalar(input, 1),
+ self.weightDesc[0], self.gradWeight:data());
+end
+
+function VolumetricFullConvolution:clearDesc()
+ self.weightDesc = nil
+ self.biasDesc = nil
+ self.convDesc = nil
+ self.iDesc = nil
+ self.oDesc = nil
+ self.oDescBias = nil
+ self.fwdAlgType = nil
+ self.bwdDataAlgType = nil
+ self.bwdFilterAlgType = nil
+ self.extraBuffer = nil
+ self.extraBufferSizeInBytes = nil
+ self.scaleT = nil
+end
+
+function VolumetricFullConvolution:write(f)
+ self:clearDesc()
+ local var = {}
+ for k,v in pairs(self) do
+ var[k] = v
+ end
+ f:writeObject(var)
+end
+
+function VolumetricFullConvolution:clearState()
+ self:clearDesc()
+ nn.utils.clear(self, 'extraBuffer', '_input', '_gradOutput')
+ return nn.Module.clearState(self)
+end
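
That is the whole of the new module. For reference, a minimal usage sketch (assuming a working cutorch/cudnn install; the layer sizes are arbitrary) showing the output shape produced by the formula in createIODescriptors:

    require 'cudnn'

    -- 16 input planes -> 8 output planes, 3x3x3 kernel, stride 2, no padding/adjustment
    local m = cudnn.VolumetricFullConvolution(16, 8, 3, 3, 3, 2, 2, 2):cuda()
    local input = torch.CudaTensor(4, 16, 5, 10, 10):normal()
    local output = m:forward(input)
    -- each size follows (i - 1)*stride - 2*pad + kernel + adj:
    -- depth (5-1)*2 + 3 = 11, height/width (10-1)*2 + 3 = 21
    print(output:size())   -- 4 x 8 x 11 x 21 x 21
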
diff --git a/convert.lua b/convert.lua
index 5368418..4075122 100644
--- a/convert.lua
+++ b/convert.lua
@@ -14,6 +14,7 @@ local layer_list = {
'LogSoftMax',
'VolumetricBatchNormalization',
'VolumetricConvolution',
+ 'VolumetricFullConvolution',
'VolumetricMaxPooling',
'VolumetricAveragePooling',
}
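
Registering 'VolumetricFullConvolution' in layer_list is what lets cudnn.convert swap nn.VolumetricFullConvolution for the cudnn implementation (and back again). A minimal sketch, assuming cudnn is installed:

    require 'cudnn'

    local net = nn.Sequential()
       :add(nn.VolumetricFullConvolution(16, 8, 3, 3, 3, 2, 2, 2))
    local converted = cudnn.convert(net:cuda(), cudnn)
    print(torch.type(converted:get(1)))   -- cudnn.VolumetricFullConvolution
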
diff --git a/init.lua b/init.lua
index bbb17a3..9b66c4d 100644
--- a/init.lua
+++ b/init.lua
@@ -172,6 +172,7 @@ end
require('cudnn.SpatialConvolution')
require('cudnn.VolumetricConvolution')
require('cudnn.SpatialFullConvolution')
+require('cudnn.VolumetricFullConvolution')
require('cudnn.Pooling')
require('cudnn.SpatialMaxPooling')
require('cudnn.SpatialAveragePooling')
diff --git a/test/test.lua b/test/test.lua
index 7236ddc..403b918 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -318,6 +318,41 @@ function cudnntest.VolumetricConvolution()
testLayer(sconv, gconv, input, gradOutput, scale, true, false)
end
+function cudnntest.VolumetricFullConvolution()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local to = math.random(1,64)
+ local ki = math.random(1,7)
+ local kj = math.random(1,7)
+ local kk = math.random(1,5)
+ local si = math.random(1,ki)
+ local sj = math.random(1,kj)
+ local sk = math.random(1,kk)
+ local ini = math.random(1,32)
+ local inj = math.random(1,32)
+ local ink = math.random(1,10)
+ local outi = (ini-1)*si+ki
+ local outj = (inj-1)*sj+kj
+ local outk = (ink-1)*sk+kk
+ local scale = math.random()
+
+ local input = torch.randn(bs,from,ink,inj,ini):cuda()
+ local gradOutput = torch.randn(bs,to,outk,outj,outi):cuda()
+ local sconv = nn.VolumetricFullConvolution(from,to,kk,ki,kj,sk,si,sj):cuda()
+ local gconv = cast(cudnn.VolumetricFullConvolution(from,to,kk,ki,kj,sk,si,sj):cuda():fastest())
+ gconv.weight:copy(sconv.weight)
+ gconv.bias:copy(sconv.bias)
+
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true) -- batch
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false) -- non-batch
+ local originalTypename = torch.typename(gconv)
+ local gconv = cast(cudnn.convert(sconv, cudnn))
+ mytester:asserteq(torch.typename(gconv),
+ originalTypename, 'conversion type check')
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true)
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false)
+end
+
function cudnntest.VolumetricMaxPooling()
local bs = math.random(1,4)
local from = math.random(1,4)
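
Closing note on the new VolumetricFullConvolution test above: it follows the pattern of the existing convolution tests, building matching nn and cudnn layers with random kernel and stride sizes, checking forward/backward agreement in both batch (5D) and non-batch (4D) mode, and re-checking after a cudnn.convert round trip. The non-batch path relies on the module viewing a 4D input as a one-sample 5D batch, roughly as sketched below (illustration only):

    require 'torch'

    local input4d = torch.Tensor(16, 5, 10, 10)                      -- C x T x H x W
    local input5d = input4d:view(1, input4d:size(1), input4d:size(2),
                                 input4d:size(3), input4d:size(4))   -- 1 x C x T x H x W
    print(input5d:size())
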