diff options
-rw-r--r-- | RNN.lua | 2 | ||||
-rw-r--r-- | SpatialConvolution.lua | 200 | ||||
-rw-r--r-- | algo.lua | 122 | ||||
-rw-r--r-- | functional.lua | 2 | ||||
-rw-r--r-- | test/test.lua | 35 | ||||
-rw-r--r-- | test/test_groups.lua | 2 |
6 files changed, 159 insertions, 204 deletions
@@ -383,7 +383,7 @@ function RNN:updateOutput(input) if self.cellOutput then self.cellInput = self.cellOutput:clone() end - end + end if (self.batchFirst) then self.output = self.output:transpose(1, 2) end diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 58c78b2..dfd52e2 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -1,13 +1,9 @@ local SpatialConvolution, parent = torch.class('cudnn.SpatialConvolution', 'nn.SpatialConvolution') local ffi = require 'ffi' +local algo = require 'algo' local errcheck = cudnn.errcheck -local autotunerCache = {} -autotunerCache[1] = {} -- forward -autotunerCache[2] = {} -- backwardFilter -autotunerCache[3] = {} -- backwardData - function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, groups) local delayedReset = self.reset @@ -98,8 +94,7 @@ function SpatialConvolution:createIODescriptors(input) end assert(input:dim() == 4 and input:isContiguous()); self.iSize = self.iSize or torch.LongStorage(4):fill(0) - if not self.iDesc or not self.oDesc or - input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] + if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() @@ -143,188 +138,6 @@ function SpatialConvolution:createIODescriptors(input) self.oDesc = cudnn.toDescriptor(output_slice) self.oDescForBias = cudnn.toDescriptor(self.output) - ----------------------------------------------------------------------- - local function shape(x) - local sz = x:size() - local str = '' - for i=1,sz:size() do - str = str .. sz[i] .. 'x' - end - if #str > 0 then - str = str:sub(1, #str-1) - end - return str - end - local autotunerHash = shape(self.weight) .. ';' - .. shape(input_slice) .. ';' - .. shape(output_slice) - - local maxBufSize = 0 - - -- create forwardAlgorithm descriptors - local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' - local algWorkspaceLimit = self.workspace_limit - or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float. - - if self.fastest_mode or cudnn.fastest == true then - algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' - end - - if cudnn.benchmark then -- the manual auto-tuner is run - if autotunerCache[1][autotunerHash] then - algType[0] = autotunerCache[1][autotunerHash] - if cudnn.verbose then - print('Autotuning SC FW: using cached algo = ', algType[0], ' for: ', autotunerHash) - end - else - local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1) - local intt = torch.IntTensor(1); - errcheck('cudnnFindConvolutionForwardAlgorithm', - cudnn.getHandle(), - self.iDesc[0], self.weightDesc[0], - self.convDesc[0], self.oDesc[0], - 1, intt:data(), perfResults) - algType[0] = perfResults[0].algo - autotunerCache[1][autotunerHash] = perfResults[0].algo - if cudnn.verbose then - print(string.format( - "\nAutotuning SC Forward: Time: %3.5f Memory: %8d Algorithm: %d" - .. " Weight: %15s Input: %15s Output: %15s", - perfResults[0].time, tonumber(perfResults[0].memory), - tonumber(perfResults[0].algo), - shape(self.weight), shape(input_slice), - shape(output_slice))) - end - end - else - errcheck('cudnnGetConvolutionForwardAlgorithm', - cudnn.getHandle(), - self.iDesc[0], self.weightDesc[0], - self.convDesc[0], self.oDesc[0], - algSearchMode, algWorkspaceLimit, algType) - end - algType[0] = self.fmode or algType[0] - self.fwdAlgType = algType - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionForwardWorkspaceSize', - cudnn.getHandle(), - self.iDesc[0], self.weightDesc[0], - self.convDesc[0], self.oDesc[0], - algType[0], bufSize:data()) - maxBufSize = math.max(maxBufSize, bufSize[1]) - - -- create backwardFilterAlgorithm descriptors - local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE' - local algWorkspaceLimit = self.workspace_limit - or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float. - if self.fastest_mode or cudnn.fastest == true then - algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' - end - - if cudnn.benchmark then -- the manual auto-tuner is run - if autotunerCache[2][autotunerHash] then - algType[0] = autotunerCache[2][autotunerHash] - if cudnn.verbose then - print('Autotuning SC BW: using cached algo = ', algType[0], ' for: ', autotunerHash) - end - else - local perfResults = ffi.new("cudnnConvolutionBwdFilterAlgoPerf_t[?]", 1) - local intt = torch.IntTensor(1); - errcheck('cudnnFindConvolutionBackwardFilterAlgorithm', - cudnn.getHandle(), - self.iDesc[0], self.oDesc[0], - self.convDesc[0], self.weightDesc[0], - 1, intt:data(), perfResults) - algType[0] = perfResults[0].algo - autotunerCache[2][autotunerHash] = perfResults[0].algo - if cudnn.verbose then - print(string.format( - "Autotuning backwardFilter: Time: %3.5f Memory: %8d Algorithm: %d" - .. " Weight: %15s Input: %15s Output: %15s", - perfResults[0].time, tonumber(perfResults[0].memory), - tonumber(perfResults[0].algo), - shape(self.weight), shape(input_slice), - shape(output_slice))) - end - end - else - errcheck('cudnnGetConvolutionBackwardFilterAlgorithm', - cudnn.getHandle(), - self.iDesc[0], self.oDesc[0], - self.convDesc[0], self.weightDesc[0], - algSearchMode, algWorkspaceLimit, algType) - end - algType[0] = self.bwmode or algType[0] - self.bwdFilterAlgType = algType - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize', - cudnn.getHandle(), - self.iDesc[0], self.oDesc[0], - self.convDesc[0], self.weightDesc[0], - algType[0], bufSize:data()) - maxBufSize = math.max(maxBufSize, bufSize[1]) - - -- create backwardDataAlgorithm descriptors - local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE' - local algWorkspaceLimit = self.workspace_limit - or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float. - if self.fastest_mode or cudnn.fastest == true then - algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' - end - if cudnn.benchmark then -- the manual auto-tuner is run - if autotunerCache[3][autotunerHash] then - algType[0] = autotunerCache[3][autotunerHash] - if cudnn.verbose then - print('Autotuning SC BWD: using cached algo = ', algType[0], ' for: ', autotunerHash) - end - else - local perfResults = ffi.new("cudnnConvolutionBwdDataAlgoPerf_t[?]", 1) - local intt = torch.IntTensor(1); - errcheck('cudnnFindConvolutionBackwardDataAlgorithm', - cudnn.getHandle(), - self.weightDesc[0], self.oDesc[0], - self.convDesc[0], self.iDesc[0], - 1, intt:data(), perfResults) - algType[0] = perfResults[0].algo - autotunerCache[3][autotunerHash] = perfResults[0].algo - if cudnn.verbose then - print(string.format( - "Autotuning backwardData: Time: %3.5f Memory: %8d Algorithm: %d" - .. " Weight: %15s Input: %15s Output: %15s\n", - perfResults[0].time, tonumber(perfResults[0].memory), - tonumber(perfResults[0].algo), - shape(self.weight), shape(input_slice), - shape(output_slice))) - end - end - else - errcheck('cudnnGetConvolutionBackwardDataAlgorithm', - cudnn.getHandle(), - self.weightDesc[0], self.oDesc[0], - self.convDesc[0], self.iDesc[0], - algSearchMode, algWorkspaceLimit, algType) - end - algType[0] = self.bdmode or algType[0] - self.bwdDataAlgType = algType - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize', - cudnn.getHandle(), - self.weightDesc[0], self.oDesc[0], - self.convDesc[0], self.iDesc[0], - algType[0], bufSize:data()) - maxBufSize = math.max(maxBufSize, bufSize[1]) - - self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace() - self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float - if maxBufSize > self.extraBufferSizeInBytes then - self.extraBuffer:resize(math.ceil(maxBufSize/4)) - self.extraBufferSizeInBytes = maxBufSize - end - - ----------------------------------------------------------------------- -- create offsets for groups local iH, iW = input:size(3), input:size(4) local kH, kW = self.kH, self.kW @@ -338,6 +151,9 @@ function SpatialConvolution:createIODescriptors(input) self.output:size(3), self.output:size(4)) end + + -- setup forward algo right away + algo.setupForwardAlgorithm(self, input_slice, output_slice) end end @@ -392,6 +208,9 @@ function SpatialConvolution:updateGradInput(input, gradOutput) assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D'); if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) + if not self.bwdDataAlgType then + algo.setupBackwardDataAlgorithm(self) + end for g = 0,self.groups - 1 do errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(), @@ -419,6 +238,9 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D'); if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) + if not self.bwdFilterAlgType then + algo.setupBackwardFilterAlgorithm(self) + end -- gradBias if self.bias then diff --git a/algo.lua b/algo.lua new file mode 100644 index 0000000..a88a839 --- /dev/null +++ b/algo.lua @@ -0,0 +1,122 @@ +local ffi = require 'ffi' +local errcheck = cudnn.errcheck + +local algo = {} +local autotunerCache = {} +autotunerCache[1] = {} -- forward +autotunerCache[2] = {} -- backwardFilter +autotunerCache[3] = {} -- backwardData + +local function setupAlgo(self, algo_t, perf_t, findAPI, getAPI, wsAPI, algSearchMode, params) + -- create forwardAlgorithm descriptors + local algType = ffi.new(algo_t, 1) + + self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace() + self.extraBufferSizeInBytes = self.extraBuffer:nElement() * self.extraBuffer.elementSize() + + local algWorkspaceLimit = self.workspace_limit + or (self.nInputPlane * self.kH * self.kW * self.extraBuffer.elementSize()) + + + if cudnn.benchmark then -- the manual auto-tuner is run + if autotunerCache[1][self.autotunerHash] then + algType[0] = autotunerCache[1][self.autotunerHash] + if cudnn.verbose then + print('\nAutotuning ', algo_t, ' using cached algo = ' , algType[0] , ' for: ', self.autotunerHash) + end + else + local perfResults = ffi.new(perf_t, 1) + local intt = torch.IntTensor(1) + errcheck(findAPI, + cudnn.getHandle(), + params[1], params[2], params[3], params[4], + 1, intt:data(), perfResults) + algType[0] = perfResults[0].algo + autotunerCache[1][self.autotunerHash] = perfResults[0].algo + if cudnn.verbose then + print(string.format( + "\nAutotuning " .. algo_t .. " Time: %3.5f Memory: %8d Algorithm: %d" + .. " hash: %45s\n", + perfResults[0].time, tonumber(perfResults[0].memory), + tonumber(perfResults[0].algo), self.autotunerHash )) + + end + end + else + errcheck(getAPI, + cudnn.getHandle(), + params[1], params[2], params[3], params[4], + algSearchMode, algWorkspaceLimit, algType) + end + local bufSize = torch.LongTensor(1) + errcheck(wsAPI, + cudnn.getHandle(), + params[1], params[2], params[3], params[4], + algType[0], bufSize:data()) + if self.extraBufferSizeInBytes < bufSize[1] then + self.extraBuffer:resize(math.ceil(bufSize[1]/self.extraBuffer.elementSize())) + self.extraBufferSizeInBytes = bufSize[1] + end + return algType +end + +function algo.setupForwardAlgorithm(self, input_slice, output_slice) + local function shape(x) + local sz = x:size() + local str = '' + for i=1,sz:size() do + str = str .. sz[i] .. 'x' + end + if #str > 0 then + str = str:sub(1, #str-1) + end + return str + end + + self.autotunerHash = shape(self.weight) .. ';' + .. shape(input_slice) .. ';' + .. shape(output_slice) + + local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' + if self.fastest_mode or cudnn.fastest == true then + algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' + end + + local params = { self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0] } + local algType = setupAlgo(self,"cudnnConvolutionFwdAlgo_t[?]", "cudnnConvolutionFwdAlgoPerf_t[?]", + 'cudnnFindConvolutionForwardAlgorithm', 'cudnnGetConvolutionForwardAlgorithm', + 'cudnnGetConvolutionForwardWorkspaceSize', algSearchMode, params) + algType[0] = self.fmode or algType[0] + self.fwdAlgType = algType + self.bwdDataAlgType = nil + self.bwdFilterAlgType = nil +end + +function algo.setupBackwardFilterAlgorithm(self) + local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE' + if self.fastest_mode or cudnn.fastest == true then + algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' + end + local params = { self.iDesc[0], self.oDesc[0], self.convDesc[0], self.weightDesc[0] } + local algType = setupAlgo(self,"cudnnConvolutionBwdFilterAlgo_t[?]", "cudnnConvolutionBwdFilterAlgoPerf_t[?]", + 'cudnnFindConvolutionBackwardFilterAlgorithm', 'cudnnGetConvolutionBackwardFilterAlgorithm', + 'cudnnGetConvolutionBackwardFilterWorkspaceSize', algSearchMode, + params) + algType[0] = self.bwmode or algType[0] + self.bwdFilterAlgType = algType +end + +function algo.setupBackwardDataAlgorithm(self) + local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE' + if self.fastest_mode or cudnn.fastest == true then + algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' + end + local params = { self.weightDesc[0], self.oDesc[0], self.convDesc[0], self.iDesc[0] } + local algType = setupAlgo(self,"cudnnConvolutionBwdDataAlgo_t[?]", "cudnnConvolutionBwdDataAlgoPerf_t[?]", + 'cudnnFindConvolutionBackwardDataAlgorithm', 'cudnnGetConvolutionBackwardDataAlgorithm', + 'cudnnGetConvolutionBackwardDataWorkspaceSize', algSearchMode, params) + algType[0] = self.bdmode or algType[0] + self.bwdDataAlgType = algType +end + +return algo diff --git a/functional.lua b/functional.lua index cea9df9..8eee7c9 100644 --- a/functional.lua +++ b/functional.lua @@ -105,7 +105,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' local algWorkspaceLimit = 0 if workspace then - algWorkspaceLimit = workspace:nElement() * 4 -- 4 = sizeof float + algWorkspaceLimit = workspace:nElement() * workspace:elementSize() end errcheck('cudnnGetConvolutionForwardAlgorithm', handle, diff --git a/test/test.lua b/test/test.lua index a4f9bcb..aef4daa 100644 --- a/test/test.lua +++ b/test/test.lua @@ -33,6 +33,14 @@ local testparams_double = { precision_io = 1e-5, } +local testparams_double = { + test_type = 'torch.CudaDoubleTensor', + precision_forward = 1e-4, + precision_backward = 2e-2, + precision_jac = 1e-3, + precision_io = 1e-5, +} + local testparams = testparams_half local function cast(input) @@ -1043,7 +1051,7 @@ function cudnntest.SpatialCrossMapLRN_batch() local size = math.random(1,3)*2+1 local nbfeatures = math.random(3,8) local alpha = math.random(1,100)/100 - local beta = math.random(0,100)/100 + local beta = math.random(1,100)/100 local k = math.random(1,3) local tm = {} @@ -1507,10 +1515,10 @@ math.randomseed(os.time()) mytester = torch.Tester() mytester:add(cudnntest) -if torch.random(1,2) == 1 then - cudnn.benchmark = true -- run manual auto-tuner --- cudnn.verbose = true -end +-- if torch.random(1,2) == 1 then +-- cudnn.benchmark = true -- run manual auto-tuner + cudnn.verbose = true +--end for i=1,cutorch.getDeviceCount() do @@ -1520,19 +1528,22 @@ for i=1,cutorch.getDeviceCount() do cutorch.setDevice(i) - print'Testing torch.CudaHalfTensor' - testparams = testparams_half - mytester:run() - print'Testing torch.CudaTensor' - testparams = testparams_float - mytester:run() +-- double tensor may be broken --- double tensor may be broken at some places, gets NaNs. -- print'Testing torch.CudaDoubleTensor' +-- torch.setdefaulttensortype('torch.DoubleTensor') -- testparams = testparams_double -- mytester:run() + print'Testing torch.CudaTensor' + testparams = testparams_float + mytester:run() + + print'Testing torch.CudaHalfTensor' + testparams = testparams_half + mytester:run() + end os.execute('rm -f modelTemp.t7') diff --git a/test/test_groups.lua b/test/test_groups.lua index 8b386b9..1675fdd 100644 --- a/test/test_groups.lua +++ b/test/test_groups.lua @@ -34,6 +34,6 @@ ccn2_gradWeight = ccn2_conv.gradWeight:t() assert((cudnn_output - ccn2_output):abs():max() < 1e-4) assert((cudnn_gradInput - ccn2_gradInput):abs():max() < 1e-4) -assert((cudnn_gradWeight - ccn2_gradWeight):abs():max() < 5e-2) +assert((cudnn_gradWeight - ccn2_gradWeight):abs():max() < 1e-1) print 'no assertions' |