author | Boris Fomitchev <bfomitchev@nvidia.com> | 2016-09-22 12:19:15 +0300
committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2016-09-23 04:06:43 +0300
commit | 9465aae4f41734c8218adaf2d50c7b3f5c9e80f7 (patch)
tree | f847cee2bf66726d11e1f5a6e402f936a108a401
parent | a17af4f12cbeb87103dbc514408eb64e1be85ba7 (diff)
Revamped workspace handling in find.lua
Retired functional.lua: it was impossible to maintain consistently with Find.
Simplified the FindEx state machine: replaced it with a warmup-iterations concept, controllable by the user.
FindEx still needs some work.
Improved cache handling and debug printing.
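The headline change is the user-controllable warmup. Below is a minimal usage sketch, not part of the commit itself, assuming a typical Torch training setup; `net`, `input`, and `gradOutput` are placeholders:

```lua
local cudnn = require 'cudnn'
cudnn.benchmark = true   -- take the autotuning (Find) path
cudnn.find.reset(3)      -- request 3 warmup iterations (defaults to 0)

-- While the warmup counter is non-zero, algorithms are re-measured on each
-- new iteration and kept out of the autotuner cache; once it reaches zero,
-- results are cached per device and per layer hash.
for i = 1, 3 do
   local output = net:forward(input)
   net:backward(input, gradOutput)
end
```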
-rw-r--r-- | SpatialConvolution.lua     |  25
-rw-r--r-- | SpatialFullConvolution.lua |  31
-rw-r--r-- | TemporalConvolution.lua    |   3
-rw-r--r-- | VolumetricConvolution.lua  |  15
-rw-r--r-- | find.lua                   | 274
-rw-r--r-- | functional.lua             | 375
-rw-r--r-- | init.lua                   |  66
-rw-r--r-- | test/test.lua              | 149
8 files changed, 235 insertions(+), 703 deletions(-)
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index f2ab112..512e7c2 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -126,12 +126,12 @@ function SpatialConvolution:createIODescriptors(input) self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') self.padH, self.padW = self.padH or 0, self.padW or 0 - local pad = torch.IntTensor({self.padH, self.padW}) - local stride = torch.IntTensor({self.dH, self.dW}) + self.pad = torch.IntTensor({self.padH, self.padW}) + self.stride = torch.IntTensor({self.dH, self.dW}) local upscale = torch.IntTensor({1,1}) errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', + 2, self.pad:data(), + self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', cudnn.configmap(torch.type(self.weight))); @@ -188,9 +188,8 @@ function SpatialConvolution:updateOutput(input) self:createIODescriptors(input) local finder = find.get() -- force recalculation - if not (self.fmode and finder.useCalculatedWorkspaceSize) then - self.fmode = finder:forwardAlgorithm(self, { self.iDesc[0], self.input_slice, self.weightDesc[0], self.weight, self.convDesc[0], self.oDesc[0], self.output_slice}) - end + self.fmode = finder:forwardAlgorithm(self, { self.iDesc[0], self.input_slice, self.weightDesc[0], self.weight, self.convDesc[0], self.oDesc[0], self.output_slice}) + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() for g = 0, self.groups - 1 do errcheck(self,'cudnnConvolutionForward', cudnn.getHandle(), @@ -221,9 +220,8 @@ function SpatialConvolution:updateGradInput(input, gradOutput) input, gradOutput = makeContiguous(self, input, gradOutput) self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bdmode) then - self.bdmode = finder:backwardDataAlgorithm(self, { self.weightDesc[0], self.weight, self.oDesc[0], self.output_slice, self.convDesc[0], self.iDesc[0], self.input_slice }) - end + self.bdmode = finder:backwardDataAlgorithm(self, { self.weightDesc[0], self.weight, self.oDesc[0], self.output_slice, self.convDesc[0], self.iDesc[0], self.input_slice }) + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() for g = 0,self.groups - 1 do errcheck(self,'cudnnConvolutionBackwardData', cudnn.getHandle(), @@ -249,10 +247,8 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) input, gradOutput = makeContiguous(self, input, gradOutput) self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bmode) then - self.bmode=finder:backwardFilterAlgorithm(self, { self.iDesc[0], self.input_slice, self.oDesc[0], self.output_slice, self.convDesc[0], self.weightDesc[0], self.weight}) - end - + self.bmode=finder:backwardFilterAlgorithm(self, { self.iDesc[0], self.input_slice, self.oDesc[0], self.output_slice, self.convDesc[0], self.weightDesc[0], self.weight}) + finder:setCalculatedWorkspaceSize(true) -- gradBias if self.bias then errcheck(self,'cudnnConvolutionBackwardBias', cudnn.getHandle(), @@ -261,6 +257,7 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) cudnn.scalar(input, 1), self.biasDesc[0], self.gradBias:data()) end + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() 
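The hunks above all apply the same simplification: the guard on `finder.useCalculatedWorkspaceSize` is gone, the algorithm is requested unconditionally (caching now lives inside find.lua), and each pass immediately applies the calculated workspace size. The resulting pattern, paraphrased from `SpatialConvolution:updateOutput`:

```lua
local finder = find.get()
-- always ask the finder; cache hits are resolved inside find:setupAlgo
self.fmode = finder:forwardAlgorithm(self, {
   self.iDesc[0], self.input_slice,
   self.weightDesc[0], self.weight,
   self.convDesc[0],
   self.oDesc[0], self.output_slice })
-- grow (never shrink) the shared workspace to the calculated per-stream sizes
finder:setCalculatedWorkspaceSize(true)
```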
for g = 0, self.groups - 1 do -- gradWeight diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua index 16ccedb..b83f253 100644 --- a/SpatialFullConvolution.lua +++ b/SpatialFullConvolution.lua @@ -45,12 +45,12 @@ function SpatialFullConvolution:createIODescriptors(input) -- create conv descriptor self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') - local pad = torch.IntTensor({self.padH, self.padW}) - local stride = torch.IntTensor({self.dH, self.dW}) + self.pad = torch.IntTensor({self.padH, self.padW}) + self.stride = torch.IntTensor({self.dH, self.dW}) local upscale = torch.IntTensor({1,1}) errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', + 2, self.pad:data(), + self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', cudnn.configmap(torch.type(self.weight))); -- get output shape, resize output @@ -83,12 +83,11 @@ end function SpatialFullConvolution:updateOutput(input) self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bdmode) then - self.bdmode = finder:backwardDataAlgorithm(self, {self.weightDesc[0], self.weight, + self.bdmode = finder:backwardDataAlgorithm(self, {self.weightDesc[0], self.weight, self.iDesc[0],self.input_slice, self.convDesc[0], self.oDesc[0], self.output_slice}) - end + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() -- Because SpatialFullConvolution is performing the adjoint of the forward @@ -120,11 +119,10 @@ function SpatialFullConvolution:updateGradInput(input, gradOutput) assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous') self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.fmode) then - self.fmode = finder:forwardAlgorithm(self, {self.oDesc[0], self.output_slice, - self.weightDesc[0], self.weight, - self.convDesc[0], self.iDesc[0], self.input_slice}) - end + self.fmode = finder:forwardAlgorithm(self, {self.oDesc[0], self.output_slice, + self.weightDesc[0], self.weight, + self.convDesc[0], self.iDesc[0], self.input_slice}) + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() errcheck(self,'cudnnConvolutionForward', cudnn.getHandle(), cudnn.scalar(input, 1), @@ -151,11 +149,9 @@ function SpatialFullConvolution:accGradParameters(input, gradOutput, scale) assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous') self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bmode) then - self.bmode = finder:backwardFilterAlgorithm(self, {self.oDesc[0], self.output_slice, - self.iDesc[0], self.input_slice, - self.convDesc[0], self.weightDesc[0], self.weight}) - end + self.bmode = finder:backwardFilterAlgorithm(self, {self.oDesc[0], self.output_slice, + self.iDesc[0], self.input_slice, + self.convDesc[0], self.weightDesc[0], self.weight}) -- gradBias if self.bias then errcheck(self,'cudnnConvolutionBackwardBias', cudnn.getHandle(), @@ -164,6 +160,7 @@ function SpatialFullConvolution:accGradParameters(input, gradOutput, scale) cudnn.scalar(input, 1), self.biasDesc[0], self.gradBias:data()) end + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() -- gradWeight 
errcheck(self,'cudnnConvolutionBackwardFilter', cudnn.getHandle(), diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua index 0654a39..cee0f44 100644 --- a/TemporalConvolution.lua +++ b/TemporalConvolution.lua @@ -52,7 +52,6 @@ local function inputview(input) end function TemporalConvolution:updateOutput(input) - find.get():verifyWorkspaceSize(self) local _input = inputview(input) assert(_input:size(4) == self.inputFrameSize,'invalid input frame size') self.buffer = self.buffer or input.new() @@ -86,7 +85,6 @@ end function TemporalConvolution:updateGradInput(input, gradOutput) if not self.gradInput then return end - find.get():verifyWorkspaceSize(self) local _gradOutput = transposeGradOutput(gradOutput,self.buffer) local _input = inputview(input) self.gradInput = Convolution.updateGradInput(self, _input, _gradOutput) @@ -99,7 +97,6 @@ function TemporalConvolution:updateGradInput(input, gradOutput) end function TemporalConvolution:accGradParameters(input,gradOutput,scale) - find.get():verifyWorkspaceSize(self) --2d (4d) view of input local _input = inputview(input) -- transpose gradOutput (it will likely be transposed twice, hopefully, no big deal diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index 6a06075..d38125b 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -37,13 +37,18 @@ function VolumetricConvolution:createIODescriptors(input) -- create conv descriptor self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') - local pad = torch.IntTensor({self.padT, self.padH, self.padW}) - local stride = torch.IntTensor({self.dT, self.dH, self.dW}) + self.pad = torch.IntTensor({self.padT, self.padH, self.padW}) + self.stride = torch.IntTensor({self.dT, self.dH, self.dW}) local upscale = torch.IntTensor({1,1,1}) + local mathtype=cudnn.configmap(torch.type(self.weight)) + -- 3D convolutions do not work in 16 bits + if mathtype == 'CUDNN_DATA_HALF' then + mathtype = 'CUDNN_DATA_FLOAT' + end errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0], - 3, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap(torch.type(self.weight))); + 3, self.pad:data(), + self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', + mathtype); -- create output descriptor and resize output local oSize = torch.IntTensor(5) @@ -6,6 +6,10 @@ find.__index = find -- constants to index array tables below local Fwd, BwdFilter, BwdData = 1, 2, 3 +local warmupIterations = 0 + +local Meg = 1024*1024 + -- cudnnGetxxx APIs: default, when cudnn.benchmark == false local getAlgos = {'cudnnGetConvolutionForwardAlgorithm', 'cudnnGetConvolutionBackwardFilterAlgorithm', @@ -25,9 +29,44 @@ local findExAlgos = {'cudnnFindConvolutionForwardAlgorithmEx', 'cudnnFindConvolutionBackwardDataAlgorithmEx'} +local fwdAlgoNames = { + "IMPLICIT_GEMM", + "IMPLICIT_PRECOMP_GEMM", + "GEMM", + "DIRECT", + "FFT", + "FFT_TILING", + "WINOGRAD", + "WINOGRAD_NONFUSED" +} + +local bwdFilterAlgoNames = { + "ALGO_0", + "ALGO_1", + "FFT", + "ALGO_3", + "WINOGRAD", + "WINOGRAD_NONFUSED" +} + +local bwdDataAlgoNames = { + "ALGO_0", + "ALGO_1", + "FFT", + "FFT_TILING", + "WINOGRAD", + "WINOGRAD_NONFUSED" +} + +local algoNames = {fwdAlgoNames, bwdFilterAlgoNames, bwdDataAlgoNames} + local function call(layer, f, ...) + if find.verbose then + + print("find:call: calling " .. f .. ", hash: ", layer.autotunerHash) + end local status = cudnn.call(f, ...) 
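Two details from the hunks above are worth pulling out. First, `VolumetricConvolution` now forces float math when the tensor type maps to half, since 3D convolutions do not work in 16-bit math. Second, the new `algoNames` tables let verbose output print algorithm names instead of bare enum values. A small sketch of both; the `algo` value is illustrative:

```lua
-- 3D convolutions: downgrade half-precision math to float (from the hunk above)
local mathtype = cudnn.configmap(torch.type(self.weight))
if mathtype == 'CUDNN_DATA_HALF' then
   mathtype = 'CUDNN_DATA_FLOAT'
end

-- cuDNN algo enums are 0-based, Lua tables are 1-based, hence the +1
local algo = 1                        -- illustrative cudnnConvolutionFwdAlgo_t value
print(fwdAlgoNames[algo + 1])         -- prints "IMPLICIT_PRECOMP_GEMM"
```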
- if status ~= ffi.C.CUDNN_STATUS_SUCCESS and cudnn.verbose then + if status ~= ffi.C.CUDNN_STATUS_SUCCESS and (find.verbose or find.verboseError) then local stride = ffi.new('int[8]') local upscale = ffi.new('int[8]') local dim = ffi.new('int[8]') @@ -37,9 +76,9 @@ local function call(layer, f, ...) 4, dim, pad, stride, upscale, mode, datatype) print("find:call:" .. f .. " failed: ", tonumber(status) , ' mode : ', tonumber(mode[0]), ' datatype : ', tonumber(datatype[0])) - if layer.autotunerHash then - print("Hash sizes: " , layer.autotunerHash) - end + end + if find.verbose then + print("find:call: success, " .. f ) end return status end @@ -55,7 +94,7 @@ end find.errcheck = errcheck local function noFallback(layer) - if cudnn.verbose then + if find.verbose then print("find.defaultFallback: call failed for: ", layer.autotunerHash) end return false @@ -75,7 +114,7 @@ local function defaultFallback(layer, replay) upscale, mode, datatype) if datatype[0] == ffi.C.CUDNN_DATA_HALF then - if cudnn.verbose then + if find.verbose then if replay then print("find.defaultFallback: replay for ", layer.autotunerHash) else @@ -84,7 +123,7 @@ local function defaultFallback(layer, replay) end errcheck(layer,'cudnnSetConvolutionNdDescriptor', layer.convDesc[0], dim[0], pad, stride, - upscale, mode[0], 'CUDNN_DATA_FLOAT') + upscale, mode[0], ffi.C.CUDNN_DATA_FLOAT) return true else return false @@ -96,25 +135,15 @@ function find.create(id) local finder = {} setmetatable(finder,find) finder.id = id - finder:resetStateMachine() finder:resetAlgorithmCache() + finder:resetStateMachine() if cutorch.hasHalf then finder.fallback = defaultFallback end return finder end --- FindEx State Machine cycle works as follows: --- iteration #0(useDefaultWorkspaceSize) : call FindEx with default WS size (let everybody allocate I/O, weights etc) --- iteration #1(useMaxWorkspaceSize) : call FindEx with maximum WS size, calculate common target WS using largest WS requested --- iteration #2+(useCalculatedWorkspaceSize) : set calculated WS. 
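The fallback hook seen above follows a simple contract: given a layer (and a `replay` flag for cache hits), it may rewrite the convolution descriptor and return `true` to signal that the failed cuDNN call should be retried. A condensed, self-contained sketch of `defaultFallback`; the source uses find.lua's own layer-aware `errcheck`, `cudnn.errcheck` is substituted here:

```lua
local ffi = require 'ffi'

local function defaultFallback(layer, replay)   -- replay marks a cache-hit re-application
   local stride   = ffi.new('int[8]')
   local upscale  = ffi.new('int[8]')
   local dim      = ffi.new('int[8]')
   local pad      = ffi.new('int[8]')
   local mode     = ffi.new('cudnnConvolutionMode_t[8]')
   local datatype = ffi.new('cudnnDataType_t[8]')
   cudnn.errcheck('cudnnGetConvolutionNdDescriptor', layer.convDesc[0],
                  4, dim, pad, stride, upscale, mode, datatype)
   if datatype[0] == ffi.C.CUDNN_DATA_HALF then
      -- rewrite the descriptor with float math and signal "retry the call"
      cudnn.errcheck('cudnnSetConvolutionNdDescriptor', layer.convDesc[0],
                     dim[0], pad, stride, upscale, mode[0],
                     ffi.C.CUDNN_DATA_FLOAT)
      return true
   end
   return false
end
```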
call FindEx again with calculated WS size, cache the result --- note: calculatedWorkspaceSize array is attribute of the cache (maximum WS of the cached algos) and reset separately - --- This resets SM of particular device to cycle 0 : useDefaultWorkspaceSize function find:resetStateMachine() - self.useDefaultWorkspaceSize = true - self.useMaxWorkspaceSize = false - self.useCalculatedWorkspaceSize = false self.iteration = 0 end @@ -122,14 +151,15 @@ local finders = nil -- this resets algorithm cache for device function find:resetAlgorithmCache() self.calculatedWorkspaceSize = {} - self.maxWorkspaceSize = 0 + self:calculateMaxWorkspaceSize() self.useFindEx = cudnn.useFindEx and (cudnn.benchmark or cudnn.fastest) self.autotunerCache = {{}, {}, {}} end -function find.reset() +function find.reset(warmup) cutorch:synchronizeAll() finders = {} + warmupIterations = warmup or 0 end function find.get() @@ -143,80 +173,68 @@ function find.get() end function find:lookup(layer, findAPI_idx) - if self.useFindEx and not self.useCalculatedWorkspaceSize then - return nil - else - return self.autotunerCache[findAPI_idx][layer.autotunerHash] - end + return self.autotunerCache[findAPI_idx][layer.autotunerHash] end -- record algo, memory in cache function find:store(layer, findAPI_idx, cachedAlgo) - self.autotunerCache[findAPI_idx][layer.autotunerHash] = cachedAlgo + if warmupIterations==0 then + self.autotunerCache[findAPI_idx][layer.autotunerHash] = cachedAlgo + end end -function find:setMaxWorkspaceSize(reserve, fraction) +function find:calculateMaxWorkspaceSize(reserve, fraction) if not reserve or reserve < cudnn.reservedGPUBytes then reserve = cudnn.reservedGPUBytes end local max_fraction = cudnn.maxWorkspaceGPUMemPercent/100 if not fraction or fraction > max_fraction then fraction = max_fraction end + local buf, curSize = cudnn.getSharedWorkspace() -- check current usage local freeMemory, totalMemory = cutorch.getMemoryUsage(self.id) - - local ws, curSize = cudnn.getSharedWorkspace() local newSize= (freeMemory+curSize-reserve) * fraction - if (newSize > curSize) then - self.maxWorkspaceSize = newSize - cudnn.setSharedWorkspaceSize(newSize) - else - self.maxWorkspaceSize = curSize - end - self.useMaxWorkspaceSize = true - if cudnn.verbose then - print("setMaxWorkspaceSize Memory: ", freeMemory, totalMemory, self.maxWorkspaceSize) + self.maxWorkspaceSize = newSize + if find.verbose then + print("calculateMaxWorkspaceSize Memory: ", freeMemory/Meg, "M free, " , totalMemory/Meg, "M total, " , self.maxWorkspaceSize/Meg, "M Workspace" ) end end function find:setCalculatedWorkspaceSize(greater) - for i,bytes in pairs (self.calculatedWorkspaceSize) do - cudnn.setSharedWorkspaceSize(bytes, greater) + local device = cutorch.getDevice() + for stream,bytes in pairs (self.calculatedWorkspaceSize) do + cudnn.setSharedWorkspaceSize(bytes, greater, device, stream) end - self.useCalculatedWorkspaceSize = true end --- adjusts workspace immediately in no FindEx function find:registerWorkspaceSize(cachedAlgo) - local stream = cutorch.getStream() - if self.useFindEx then - if not self.calculatedWorkspaceSize[stream] then - self.calculatedWorkspaceSize[stream] = 0 - end - -- find algo with a size that keeps the sum of stream sizes within ws size - for a=1,#cachedAlgo do - local algoSize = cachedAlgo[a].memory - local delta = algoSize - self.calculatedWorkspaceSize[stream] - if delta > 0 then - -- check if we still fit - local totalWS = 0 - for s,sz in pairs(self.calculatedWorkspaceSize) do - totalWS = totalWS + sz - end - if 
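The workspace ceiling is now computed once per cache reset instead of being tracked by the retired state machine. The arithmetic in `calculateMaxWorkspaceSize`, restated from the hunk above:

```lua
local buf, curSize = cudnn.getSharedWorkspace()              -- current buffer and size
local freeMemory, totalMemory = cutorch.getMemoryUsage(self.id)
local reserve  = cudnn.reservedGPUBytes                      -- floor kept free for others
local fraction = cudnn.maxWorkspaceGPUMemPercent / 100       -- e.g. 0.95
-- all currently free memory, plus what the old workspace already occupies,
-- minus the reserve, scaled by the allowed fraction:
self.maxWorkspaceSize = (freeMemory + curSize - reserve) * fraction
```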
totalWS + delta < self.maxWorkspaceSize then - self.calculatedWorkspaceSize[stream] = algoSize + delta - if cudnn.verbose then - print("find:registerWorkspaceSize: calculated ", self.calculatedWorkspaceSize[stream], " delta = ", delta, "max : " , self.maxWorkspaceSize) - end - return cachedAlgo[a].algo - end - else - return cachedAlgo[a].algo - end -- delta - end - return nil - else - -- no FindEx - do not rely on find stored data - cudnn.setSharedWorkspaceSize(cachedAlgo[1].memory, true) - return cachedAlgo[1].algo - end + local stream = cutorch.getStream() + + if not self.calculatedWorkspaceSize[stream] then + self.calculatedWorkspaceSize[stream] = 0 + end + + if self.calculatedWorkspaceSize[stream] > self.maxWorkspaceSize then + self.calculatedWorkspaceSize[stream] = self.maxWorkspaceSize + end + + -- find algo with a size that keeps the sum of stream sizes within ws size + for a=1,#cachedAlgo do + local algoSize = cachedAlgo[a].memory + local delta = algoSize - self.calculatedWorkspaceSize[stream] + if delta > 0 then + -- check if we still fit + local totalWS = 0 + for s,sz in pairs(self.calculatedWorkspaceSize) do + totalWS = totalWS + sz + end + if totalWS + delta < self.maxWorkspaceSize then + self.calculatedWorkspaceSize[stream] = algoSize + return a + end + else + -- keep previously calculated WS size for the stream + return a + end -- delta + end + return 0 end function find:reserveBytes(layer) @@ -226,8 +244,7 @@ function find:reserveBytes(layer) return reserve end - -function find:verifyWorkspaceSize(layer) +function find:verifyReserveForWeights(layer) local freeMemory, totalMemory = cutorch.getMemoryUsage(self.id) local reserve = self:reserveBytes(layer) if freeMemory < reserve then @@ -236,12 +253,11 @@ function find:verifyWorkspaceSize(layer) end end -function find:newIteration(layer) - if self.useCalculatedWorkspaceSize or not self.useFindEx then - --- end state - return false - end + +function find:advanceStateMachine(layer, findAPI_idx) + if warmupIterations == 0 then return end if not layer.iteration then layer.iteration = {0,0,0} end + -- find last iteration local max_iter = 0 for k,v in pairs(layer.iteration) do @@ -250,33 +266,9 @@ function find:newIteration(layer) if (self.iteration < max_iter and max_iter > 1) then self.iteration = max_iter - if cudnn.verbose then print ("CUDNN Find SM: iteration ", self.iteration) end - return true - else - return false + if find.verbose then print ("CUDNN Find SM: iteration #", self.iteration) end + if warmupIterations > 0 then warmupIterations = warmupIterations -1 end end -end - -function find:advanceStateMachine(layer, findAPI_idx) - if not self.useFindEx then return end - if self:newIteration(layer) then - -- iteration changed, advance state machine - if self.useMaxWorkspaceSize then - if cudnn.verbose then print ("CUDNN Find SM: max->calculated ", self.calculatedWorkspaceSize) end - self:setCalculatedWorkspaceSize() - self.useMaxWorkspaceSize = false - end - if self.useDefaultWorkspaceSize then - if self.useFindEx then - if cudnn.verbose then print ("CUDNN Find SM: default->max") end - self:setMaxWorkspaceSize(self:reserveBytes(layer)) - else - if cudnn.verbose then print ("CUDNN Find SM: default->calculated ", self.calculatedWorkspaceSize) end - self:setCalculatedWorkspaceSize(true) - end - self.useDefaultWorkspaceSize = false - end - end layer.iteration[findAPI_idx] = layer.iteration[findAPI_idx] + 1 end @@ -301,7 +293,7 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) local extraBuffer, extraBufferSize 
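Note that `registerWorkspaceSize` now returns an index into `cachedAlgo` (0 when nothing fits) rather than an algorithm id, and keeps a per-stream size table so the sum over all streams stays under the ceiling. A standalone restatement of the fit test, with illustrative sizes:

```lua
-- sizes are illustrative; the real code works in bytes
local function fits(perStream, stream, algoSize, maxWS)
   local delta = algoSize - (perStream[stream] or 0)
   if delta <= 0 then return true end        -- stream is already big enough
   local totalWS = 0
   for _, sz in pairs(perStream) do totalWS = totalWS + sz end
   return totalWS + delta < maxWS            -- sum over streams must stay under the cap
end

local per = { [0] = 300, [1] = 200 }         -- two streams at 300 MB and 200 MB
print(fits(per, 0, 600, 1000))  -- true  (500 + 300 < 1000)
print(fits(per, 0, 900, 1000))  -- false (500 + 600 >= 1000)
```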
= cudnn.getSharedWorkspace() local validResults = 0 - local API = cudnn.useFindEx and findExAlgos[findAPI_idx] + local API = self.useFindEx and findExAlgos[findAPI_idx] or ( (cudnn.benchmark or cudnn.fastest) and findNoExAlgos[findAPI_idx] or getAlgos[findAPI_idx]) local perfResults = perfResultsArray[findAPI_idx] @@ -312,7 +304,7 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) validResults = #cachedAlgo useFallback = cachedAlgo[1].fallback -- need to replay fallback on cache hit - if useFallback then self.fallback(layer) end + if useFallback then self.fallback(layer, true) end else cacheHit = '' cachedAlgo = {} @@ -322,13 +314,15 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) if findAPI_idx == BwdFilter then params[7] = params[7]:clone() end + self:calculateMaxWorkspaceSize() + cudnn.setSharedWorkspaceSize(self.maxWorkspaceSize) end local function callCudnn(layer) local ret = 0 validResults = 0 if cudnn.benchmark or cudnn.fastest then - if cudnn.useFindEx then + if self.useFindEx then ret = call(layer, API, cudnn.getHandle(), params[1], params[2]:data(), params[3], params[4]:data(), layer.convDesc[0], params[6], params[7]:data(), @@ -349,9 +343,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) params[1], params[3], layer.convDesc[0], params[6], algSearchMode, algWorkspaceLimit, algType[findAPI_idx]) local retAlgo = algType[findAPI_idx][0] - if cudnn.verbose then + if find.verbose then print(string.format( - "\n" .. getAPI .. ": %d (ws limit: %d) mode = %s", + "\n" .. API .. ": %d (ws limit: %d) mode = %s", tonumber(retAlgo), algWorkspaceLimit, algSearchMode)) @@ -361,9 +355,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) cudnn.getHandle(), params[1], params[3], layer.convDesc[0], params[6], retAlgo, bufSize:data()) - if cudnn.verbose then + if find.verbose then print(string.format( - "\n" .. wsAPI .. ": bufSize: %d, current ws: %d", + "\n" .. getWSAlgos[findAPI_idx] .. ": bufSize: %d, current ws: %d", tonumber(bufSize[1]), tonumber(extraBufferSize))) end perfResults[0].algo = retAlgo @@ -371,11 +365,11 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) perfResults[0].status = ret end - if cudnn.verbose then + if find.verbose then print("\ncallCudnn: ", API, "returned ", numPerfResults[0], " results , status = " , ret, "status[0] = " , perfResults[0].status, "\n") end - if ret ~= 0 or numPerfResults[0] < 1 then + if ret ~= 0 then return ret end @@ -388,22 +382,23 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) time = tonumber(res.time), status = tonumber(res.status), fallback = useFallback} - if cudnn.verbose and find.verbose then + if find.verbose then local fallback = '' if (useFallback) then fallback = "[FALLBACK]" end print(string.format( - "\n" .. API .. " algo: %d (status: %d), memory: %8d, count: %d" + "\n" .. API .. " algo: %s (%d, status: %d), memory: %8d, count: %d" .. " hash: %45s " .. cacheHit .. fallback, - cachedAlgo[validResults].algo, cachedAlgo[validResults].status, + algoNames[findAPI_idx][cachedAlgo[validResults].algo+1], cachedAlgo[validResults].algo, cachedAlgo[validResults].status, cachedAlgo[validResults].memory, r, layer.autotunerHash)) end end end - if validResults < 1 and cudnn.verbose then + if validResults < 1 and find.verbose then print("Could not find any valid convolution algorithms for sizes: " .. 
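How `setupAlgo` now selects the cuDNN entry point, lifted from the hunk above: FindEx is a per-finder decision (`self.useFindEx`) rather than the global flag, with benchmarking falling back to the plain Find and Get families:

```lua
-- findAPI_idx is 1 (Fwd), 2 (BwdFilter) or 3 (BwdData)
local API = self.useFindEx and findExAlgos[findAPI_idx]
   or ((cudnn.benchmark or cudnn.fastest) and findNoExAlgos[findAPI_idx]
        or getAlgos[findAPI_idx])

-- on a cache hit the fallback is replayed (second argument true) so the
-- descriptor is rewritten the same way it was when the algo was measured:
if useFallback then self.fallback(layer, true) end
```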
layer.autotunerHash) -- todo: add case of multi-stream not fitting in size + return 1 end - return ret + return 0 end -- do the actual call @@ -419,32 +414,43 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) end end self:store(layer, findAPI_idx, cachedAlgo) + if self.useFindEx then + cudnn.setSharedWorkspaceSize(extraBufferSize) + end end -- this may return different algo if size does not fit retAlgo = self:registerWorkspaceSize(cachedAlgo) - - - -- this may return different algo if size does not fit - retAlgo = self:registerWorkspaceSize(cachedAlgo) - - if cudnn.verbose then + if retAlgo==0 then + -- TODO: fallback to recalculate + error("No algorithms found that would fit in free memory") + return -1 + end + if cudnn.verbose or find.verbose then local freeMemory, totalMemory = cutorch.getMemoryUsage(self.id) local fallback = "" if (useFallback) then fallback = "[FALLBACK]" end print(string.format( - "\n" .. API .. ": %d(%d) Workspace: %8d (current ws size %d, free: %d) hash: %45s" .. cacheHit .. fallback, - retAlgo, #cachedAlgo, tonumber(cachedAlgo[1].memory), extraBufferSize, freeMemory, layer.autotunerHash)) + "\n" .. API .. ": %s(%d)[%d of %d] Workspace: %8fM (current ws size %fM, max: %dM free: %dM) hash: %45s" .. cacheHit .. fallback, + algoNames[findAPI_idx][cachedAlgo[retAlgo].algo+1], cachedAlgo[retAlgo].algo, retAlgo, #cachedAlgo, + tonumber(cachedAlgo[retAlgo].memory)/Meg, extraBufferSize/Meg, self.maxWorkspaceSize/Meg, freeMemory/Meg, layer.autotunerHash)) end - return retAlgo + return cachedAlgo[retAlgo].algo end function find:prepare(layer, input_slice, output_slice) local function shape(x) - return table.concat(x:size():totable(),'x') + return table.concat(x:size():totable(),',') + end + local function vals(x) + return table.concat(x:totable(),',') end - layer.autotunerHash = shape(layer.weight) .. ';' - .. shape(input_slice) .. ';' - .. shape(output_slice) .. "[" .. layer.padH .. ":" .. layer.padW .. ']' .. cudnn.configmap(torch.type(layer.weight)) + layer.autotunerHash = + '-dimA' .. shape(input_slice) + ..' -filtA' .. shape(layer.weight) + ..' ' .. shape(output_slice) + ..' -padA' .. vals(layer.pad) + ..' -convStrideA' .. vals(layer.stride) + .. ' ' .. cudnn.configmap(torch.type(layer.weight)) layer:resetMode() layer.iteration = nil diff --git a/functional.lua b/functional.lua index 4955c09..dfc5d8e 100644 --- a/functional.lua +++ b/functional.lua @@ -5,378 +5,5 @@ local cudnn = require 'cudnn.env' local ffi = require 'ffi' local errcheck = cudnn.errcheck - -local NULL -if not jit then - NULL = ffi.C.NULL -end - cudnn.functional = {} - - - - -local function Batch2D(t) - return t:view(1, t:size(1), t:size(2), t:size(3)) -end - --- accumulates the bias into output. --- output is assumed to be allocated and given. -cudnn.functional.bias2D_updateOutput = function(handle, bias, output) - output = output:dim() == 3 and Batch2D(output) or output - - local biasDesc = cudnn.toDescriptor(bias:view(1, bias:nElement(),1,1)) - local oDesc = cudnn.toDescriptor(output) - errcheck('cudnnAddTensor', handle, - cudnn.scalar(output, 1), biasDesc[0], bias:data(), - cudnn.scalar(output, 1), oDesc[0], output:data()) -end - --- accumulates the gradients into gradBias. --- gradBias is assumed to be allocated and given. 
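The cache key built in `find:prepare` now records pad and stride explicitly, in a notation resembling cuDNN's convolution sample parameters. An illustrative key (sizes assumed, not from the commit):

```lua
local function shape(x) return table.concat(x:size():totable(), ',') end
-- For a 64x3x7x7 filter on a 32x3x224x224 batch, pad 3, stride 2, float math,
-- layer.autotunerHash comes out as:
--   "-dimA32,3,224,224 -filtA64,3,7,7 32,64,112,112 -padA3,3 -convStrideA2,2 CUDNN_DATA_FLOAT"
```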
-cudnn.functional.bias2D_accGradParameters = function(handle, gradOutput, gradBias, scale) - gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput - scale = scale or 1.0 - local scaleT = torch.FloatTensor({scale}) - local oDesc = cudnn.toDescriptor(gradOutput) - local biasDesc = cudnn.toDescriptor(gradBias:view(1, gradBias:nElement(),1,1)) - errcheck('cudnnConvolutionBackwardBias', handle, - scaleT:data(), - oDesc[0], gradOutput:data(), - cudnn.scalar(gradOutput, 1), - biasDesc[0], gradBias:data()) -end - --- Does a 2D Convolution (updateOutput) on input, weight --- output is assumed to be allocated and given. -cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, output, - strideH, strideW, padH, padW, workspace) - input = input:dim() == 3 and Batch2D(input) or input - output = output:dim() == 3 and Batch2D(output) or output - - -- create a weight descriptor - local weightDesc = ffi.new('struct cudnnFilterStruct*[1]') - errcheck('cudnnCreateFilterDescriptor', weightDesc) - local nOutputPlane, nInputPlane, kH, kW - = weight:size(1), weight:size(2), weight:size(3), weight:size(4) - local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, - desc:data()); - local function destroyWDesc(d) - errcheck('cudnnDestroyFilterDescriptor', d[0]); - end - ffi.gc(weightDesc, destroyWDesc) - - -- create a convolution descriptor - local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') - errcheck('cudnnCreateConvolutionDescriptor', convDesc) - local pad = torch.IntTensor({padH, padW}) - local stride = torch.IntTensor({strideH, strideW}) - local upscale = torch.IntTensor({1,1}) - errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap(torch.type(weight))); - local function destroyConvDesc(d) - errcheck('cudnnDestroyConvolutionDescriptor', d[0]); - end - ffi.gc(convDesc, destroyConvDesc) - - -- create input descriptor - local iDesc = cudnn.toDescriptor(input) - - -- create output descriptor - local oSize = torch.IntTensor(4) - errcheck('cudnnGetConvolutionNdForwardOutputDim', - convDesc[0], iDesc[0], - weightDesc[0], 4, oSize:data()) - oSize = oSize:long() - assert(output:dim() == 4 and - output:size(1) == oSize[1] and - output:size(2) == oSize[2] and - output:size(3) == oSize[3] and - output:size(4) == oSize[4], - 'Output is of wrong size') - -- create descriptor for output - local oDesc = cudnn.toDescriptor(output) - - -- create forwardAlgorithm descriptors for - local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' - local algWorkspaceLimit = 0 - if workspace then - algWorkspaceLimit = workspace:nElement() * workspace:elementSize() - end - errcheck('cudnnGetConvolutionForwardAlgorithm', - handle, - iDesc[0], weightDesc[0], - convDesc[0], oDesc[0], - algSearchMode, algWorkspaceLimit, algType) - - -- do convolution - errcheck('cudnnConvolutionForward', handle, - cudnn.scalar(input, 1), - iDesc[0], input:data(), - weightDesc[0], weight:data(), - convDesc[0], algType[0], - workspace and workspace:data() or nil, algWorkspaceLimit, - cudnn.scalar(input, 0), - oDesc[0], output:data()); -end - --- Does a 2D Convolution (updateGradInput) on input, weight, output, gradOutput --- gradInput is assumed to be allocated and given. 
-cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, output, gradOutput, gradInput, - strideH, strideW, padH, padW) - input = input:dim() == 3 and Batch2D(input) or input - output = output:dim() == 3 and Batch2D(output) or output - gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput - gradInput = gradInput:dim() == 3 and Batch2D(gradInput) or gradInput - - -- create a weight descriptor - local weightDesc = ffi.new('struct cudnnFilterStruct*[1]') - errcheck('cudnnCreateFilterDescriptor', weightDesc) - local nOutputPlane, nInputPlane, kH, kW - = weight:size(1), weight:size(2), weight:size(3), weight:size(4) - local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, - desc:data()); - local function destroyWDesc(d) - errcheck('cudnnDestroyFilterDescriptor', d[0]); - end - ffi.gc(weightDesc, destroyWDesc) - - -- create a convolution descriptor - local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') - errcheck('cudnnCreateConvolutionDescriptor', convDesc) - local pad = torch.IntTensor({padH, padW}) - local stride = torch.IntTensor({strideH, strideW}) - local upscale = torch.IntTensor({1,1}) - errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap(torch.type(weight))); - local function destroyConvDesc(d) - errcheck('cudnnDestroyConvolutionDescriptor', d[0]); - end - ffi.gc(convDesc, destroyConvDesc) - - -- create input, output descriptor - local iDesc = cudnn.toDescriptor(input) - local oDesc = cudnn.toDescriptor(output) - - local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE' - - errcheck('cudnnGetConvolutionBackwardDataAlgorithm', - cudnn.getHandle(), - weightDesc[0], oDesc[0], - convDesc[0], iDesc[0], - algSearchMode, 0, algType) - - -- do convolution - errcheck('cudnnConvolutionBackwardData', handle, - cudnn.scalar(input, 1), - weightDesc[0], weight:data(), - oDesc[0], gradOutput:data(), - convDesc[0], - algType[0], - NULL, 0, - cudnn.scalar(input, 0), - iDesc[0], gradInput:data()); - - -end - --- accumulates the gradients into gradWeight. --- gradWeight is assumed to be allocated and given. 
-local scaleT = torch.FloatTensor(1):fill(1.0) -cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradWeight, gradOutput, - strideH, strideW, padH, padW, scale) - input = input:dim() == 3 and Batch2D(input) or input - gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput - - scale = scale or 1.0 - scaleT[1] = scale - -- create a weight descriptor - local weightDesc = ffi.new('struct cudnnFilterStruct*[1]') - errcheck('cudnnCreateFilterDescriptor', weightDesc) - local nOutputPlane, nInputPlane, kH, kW - = gradWeight:size(1), gradWeight:size(2), gradWeight:size(3), gradWeight:size(4) - local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, - desc:data()); - local function destroyWDesc(d) - errcheck('cudnnDestroyFilterDescriptor', d[0]); - end - ffi.gc(weightDesc, destroyWDesc) - - -- create a convolution descriptor - local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') - errcheck('cudnnCreateConvolutionDescriptor', convDesc) - local pad = torch.IntTensor({padH, padW}) - local stride = torch.IntTensor({strideH, strideW}) - local upscale = torch.IntTensor({1,1}) - errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap(torch.type(gradWeight))); - local function destroyConvDesc(d) - errcheck('cudnnDestroyConvolutionDescriptor', d[0]); - end - ffi.gc(convDesc, destroyConvDesc) - - -- create input, output descriptor - local iDesc = cudnn.toDescriptor(input) - local oDesc = cudnn.toDescriptor(gradOutput) - - local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE' - local algWorkspaceLimit = 0 - - errcheck('cudnnGetConvolutionBackwardFilterAlgorithm', - cudnn.getHandle(), - iDesc[0], oDesc[0], - convDesc[0], weightDesc[0], - algSearchMode, algWorkspaceLimit, algType) - - - -- do convolution - errcheck('cudnnConvolutionBackwardFilter', handle, - scaleT:data(), - iDesc[0], input:data(), - oDesc[0], gradOutput:data(), - convDesc[0], - algType[0], - NULL, 0, - cudnn.scalar(input, 1), - weightDesc[0], gradWeight:data()); -end - - - --- Does a 2D Pooling (updateOutput) on input, weight --- output is assumed to be allocated and given. -cudnn.functional.Pooling_updateOutput = function(handle, mode, input, output, - kH, kW, dH, dW, padH, padW, ceil_mode) - input = input:dim() == 3 and Batch2D(input) or input - output = output:dim() == 3 and Batch2D(output) or output - - padH = padH or 0 - padW = padW or 0 - ceil_mode = ceil_mode or false - - local oW, oH - if ceil_mode then - oW = math.ceil((input:size(4)+padW*2 - kW)/dW + 1) - oH = math.ceil((input:size(3)+padH*2 - kH)/dH + 1) - else - oW = math.floor((input:size(4)+padW*2 - kW)/dW + 1) - oH = math.floor((input:size(3)+padH*2 - kH)/dH + 1) - end - assert(oH == output:size(3) and oW == output:size(4), - 'size mismatch: ' .. oH .. 'x' .. oW .. ' vs ' .. - output:size(3) .. 'x' .. 
output:size(4)) - - -- create pooling descriptor - local poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') - errcheck('cudnnCreatePoolingDescriptor', poolDesc) - local ker = torch.IntTensor({kH, kW}) - local str = torch.IntTensor({dH, dW}) - local pad = torch.IntTensor({padH, padW}) - errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 2, - ker:data(), pad:data(), str:data()); - local function destroyPoolDesc(d) - errcheck('cudnnDestroyPoolingDescriptor', d[0]); - end - ffi.gc(poolDesc, destroyPoolDesc) - - -- create input, output descriptor - local iDesc = cudnn.toDescriptor(input) - local oDesc = cudnn.toDescriptor(output) - - -- pool - errcheck('cudnnPoolingForward', handle, - poolDesc[0], - cudnn.scalar(input, 1), - iDesc[0], input:data(), - cudnn.scalar(input, 0), - oDesc[0], output:data()); -end - -cudnn.functional.MaxPooling2D_updateOutput = function(handle, input, output, - kH, kW, dH, dW, padH, padW, ceil_mode) - cudnn.functional.Pooling_updateOutput(handle, 'CUDNN_POOLING_MAX', input, output, - kH, kW, dH, dW, padH, padW, ceil_mode); -end - -cudnn.functional.AveragePooling2D_updateOutput = function(handle, input, output, - kH, kW, dH, dW, padH, padW, ceil_mode) - cudnn.functional.Pooling_updateOutput(handle, 'CUDNN_POOLING_AVERAGE', input, output, - kH, kW, dH, dW, padH, padW, ceil_mode); -end - --- Does a 2D Pooling (updateGradInput) on input, weight --- output is assumed to be allocated and given. -cudnn.functional.Pooling_updateGradInput = function(handle, mode, input, output, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, ceil_mode) - input = input:dim() == 3 and Batch2D(input) or input - output = output:dim() == 3 and Batch2D(output) or output - gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput - gradInput = gradInput:dim() == 3 and Batch2D(gradInput) or gradInput - - padH = padH or 0 - padW = padW or 0 - ceil_mode = ceil_mode or false - - local oW, oH - if ceil_mode then - oW = math.ceil((input:size(4)+padW*2 - kW)/dW + 1) - oH = math.ceil((input:size(3)+padH*2 - kH)/dH + 1) - else - oW = math.floor((input:size(4)+padW*2 - kW)/dW + 1) - oH = math.floor((input:size(3)+padH*2 - kH)/dH + 1) - end - assert(oH == output:size(3) and oW == output:size(4), - 'size mismatch: ' .. oH .. 'x' .. oW .. ' vs ' .. - output:size(3) .. 'x' .. 
output:size(4)) - - -- create pooling descriptor - local poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') - errcheck('cudnnCreatePoolingDescriptor', poolDesc) - local ker = torch.IntTensor({kH, kW}) - local str = torch.IntTensor({dH, dW}) - local pad = torch.IntTensor({padH, padW}) - errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 2, - ker:data(), pad:data(), str:data()); - local function destroyPoolDesc(d) - errcheck('cudnnDestroyPoolingDescriptor', d[0]); - end - ffi.gc(poolDesc, destroyPoolDesc) - - -- create input, output descriptor - local iDesc = cudnn.toDescriptor(input) - local oDesc = cudnn.toDescriptor(output) - - -- pool - errcheck('cudnnPoolingBackward', - handle, poolDesc[0], - cudnn.scalar(input, 1), - oDesc[0], output:data(), - oDesc[0], gradOutput:data(), - iDesc[0], input:data(), - cudnn.scalar(input, 0), - iDesc[0], gradInput:data()); -end - -cudnn.functional.MaxPooling2D_updateGradInput = function(handle, input, output, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, ceil_mode) - cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_MAX', input, output, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, ceil_mode); -end - -cudnn.functional.AveragePooling2D_updateGradInput = function(handle, input, output, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, ceil_mode) - cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_AVERAGE', input, output, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, ceil_mode); -end +error('cudnn.functional is obsolete, should not be used!') @@ -190,12 +190,6 @@ end local sharedBuffer = {} local nextBufferSize = {} -local function setNextSize(buf, size, ifGreater) - if size > buf.nextSize or not ifGreater then - buf.nextSize = size - end -end - -- may reassign currentSize local function allocateStorage(buf, ifGreater) @@ -210,9 +204,6 @@ local function allocateStorage(buf, ifGreater) if buf.storage then if (newelem == buf.storage:size()) or (ifGreater and newelem < buf.storage:size()) then else - if cudnn.verbose then - print( "allocateStorage: new WS size is ", buf.nextSize) - end -- resize to just to make sure we return memory buf.storage:resize(0) buf.storage:resize(newelem) @@ -228,24 +219,26 @@ local function allocateStorage(buf, ifGreater) buf.nextSize = -1 end -local function sharedBufForCurrentStream() - local device = cutorch.getDevice() - local stream = cutorch.getStream() -- starts from 0 - if not sharedBuffer[device] then sharedBuffer[device] = {} end - local buf = sharedBuffer[device][stream] - if not buf then - buf = { - currentSize = cudnn.initialWorkspaceBytes, - nextSize = -1 - } - allocateStorage(buf) - sharedBuffer[device][stream] = buf - end - return buf +local function sharedBufForStream(device, stream) + device = device or cutorch.getDevice() + stream = stream or cutorch.getStream() -- starts from 0 + if not sharedBuffer[device] then sharedBuffer[device] = {} end + local buf = sharedBuffer[device][stream] + if not buf then + buf = { + currentSize = cudnn.initialWorkspaceBytes, + nextSize = -1 + } + allocateStorage(buf) + sharedBuffer[device][stream] = buf + end + return buf end -function cudnn.getSharedWorkspace() - local buf = sharedBufForCurrentStream() +function cudnn.getSharedWorkspace(device, stream) + device = device or cutorch.getDevice() + stream = stream or cutorch.getStream() + local buf = sharedBufForStream(device, stream) return buf.data, buf.currentSize end @@ -257,21 +250,25 @@ function cudnn.externalizeString(luaStr) return cStr end 
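With functional.lua reduced to a hard error, callers must migrate to the module API. A migration sketch, assuming a caller that previously used `cudnn.functional.Convolution2D_updateOutput`; the tensor shapes are illustrative:

```lua
-- before (now raises "cudnn.functional is obsolete, should not be used!"):
-- cudnn.functional.Convolution2D_updateOutput(cudnn.getHandle(), input, weight,
--                                             output, dH, dW, padH, padW)

-- after: the stateful module, which also goes through find.lua autotuning
local conv = cudnn.SpatialConvolution(3, 16, 5, 5, 1, 1, 2, 2):cuda()
local input = torch.randn(10, 3, 32, 32):cuda()
local output = conv:forward(input)
```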
-function cudnn.adjustSharedWorkspaceSize(bytesDelta) - local buf = sharedBufForCurrentStream() - setNextSize(buf, buf.currentSize + bytesDelta) +function cudnn.adjustSharedWorkspaceSize(bytesDelta, device, stream) + local buf = sharedBufForStream(device, stream) + buf.nextSize = buf.currentSize + bytesDelta allocateStorage(buf) end -function cudnn.setSharedWorkspaceSize(bytes, ifGreater) - local buf = sharedBufForCurrentStream() - ifGreater = ifGreater or false +function cudnn.setNextWorkspaceSize(bytes, device, stream) + local buf = sharedBufForStream(device, stream) + buf.nextSize = bytes + return buf +end + +function cudnn.setSharedWorkspaceSize(bytes, ifGreater, device, stream) bytes = bytes or cudnn.initialWorkspaceBytes - setNextSize(buf, bytes, ifGreater) + local buf = cudnn.setNextWorkspaceSize(bytes, device, stream) allocateStorage(buf, ifGreater) end -local find = require('cudnn.find') +cudnn.find = require('cudnn.find') require('cudnn.SpatialConvolution') require('cudnn.VolumetricConvolution') @@ -307,7 +304,6 @@ require('cudnn.BLSTM') require('cudnn.LSTM') require('cudnn.BGRU') require('cudnn.GRU') -require('cudnn.functional') require('cudnn.convert') function cudnn.reset() @@ -322,7 +318,7 @@ function cudnn.reset() end collectgarbage() -- this resets internal algorithm finder state machine and cache - find.reset() + cudnn.find.reset() end return cudnn diff --git a/test/test.lua b/test/test.lua index eed428f..7526831 100644 --- a/test/test.lua +++ b/test/test.lua @@ -126,20 +126,20 @@ function cudnntest.SpatialConvolution() local bs = math.random(1,32) local from = math.random(1,32) local to = math.random(1,64) - local ki = math.random(1,9) - local kj = math.random(1,9) + local ki = math.random(1,15) + local kj = math.random(1,15) local si = math.random(1,ki) local sj = math.random(1,kj) local outi = math.random(1,64) local outj = math.random(1,64) local ini = (outi-1)*si+ki local inj = (outj-1)*sj+kj + local scale = math.random() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,to,outj,outi):cuda() local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() - local gconv = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj)) - + local gconv = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj)):fastest() gconv.weight:copy(sconv.weight) gconv.bias:copy(sconv.bias) @@ -154,12 +154,11 @@ function cudnntest.SpatialConvolution() end function cudnntest.SpatialFullConvolution() - local bs = math.random(1,32) local from = math.random(1,32) local to = math.random(1,64) - local ki = math.random(1,9) - local kj = math.random(1,9) + local ki = math.random(1,15) + local kj = math.random(1,15) local si = math.random(1,ki) local sj = math.random(1,kj) local ini = math.random(1,64) @@ -186,12 +185,12 @@ function cudnntest.SpatialFullConvolution() end function cudnntest.TemporalConvolution() - local bs = math.random(2,32) + local bs = math.random(1,32) local inputFrameSize = math.random(1,64) local outputFrameSize = math.random(1,64) - local ki = math.random(2,6) - local si = math.random(2,ki) - local outi = math.random(2,9) + local ki = math.random(1,15) + local si = math.random(1,ki) + local outi = math.random(1,15) local ini = (outi - 1) * si + ki local scale = math.random() @@ -208,13 +207,13 @@ function cudnntest.TemporalConvolution() end function cudnntest.TemporalConvolution_padding_batch() - local bs = math.random(2,32) - local inputFrameSize = math.random(2,64) - local outputFrameSize = math.random(2,64) - local ki = math.random(2,9) + local bs = 
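The shared-workspace helpers in init.lua are now addressable per (device, stream) pair, defaulting to the current ones; this is what lets `find:setCalculatedWorkspaceSize` size each stream's buffer individually. Usage sketch:

```lua
-- current device and stream (both arguments optional):
local data, size = cudnn.getSharedWorkspace()

-- grow (ifGreater = true) stream 1's workspace on device 2 to at least 64 MB:
cudnn.setSharedWorkspaceSize(64 * 1024 * 1024, true, 2, 1)

-- reset the current stream's workspace to the default size:
cudnn.setSharedWorkspaceSize()
```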
math.random(1,32) + local inputFrameSize = math.random(1,64) + local outputFrameSize = math.random(1,64) + local ki = math.random(2,15) local pad_h = math.floor(ki/2) - local si = math.random(1,ki,1) - local outi = math.random(2,9) + local si = math.random(1,ki) + local outi = math.random(2,15) local ini = (outi-1)*si+ki local scale = math.random() @@ -264,9 +263,9 @@ end function cudnntest.TemporalConvolution_reduceBatchSize() local inputFrameSize = math.random(1,64) local outputFrameSize = math.random(1,64) - local ki = math.random(1,9) + local ki = math.random(1,15) local si = math.random(1,ki) - local outi = math.random(2,9) + local outi = math.random(1,15) local ini = (outi-1)*si+ki local batchSize = 128 local smallerBatchSize = batchSize/2 @@ -288,9 +287,9 @@ function cudnntest.VolumetricConvolution() local bs = math.random(1,32) local from = math.random(1,16) local to = math.random(1,16) - local ki = math.random(3,5,3) - local kj = math.random(3,5,3) - local kk = math.random(3,5,3) + local ki = math.random(3,5) + local kj = math.random(3,5) + local kk = math.random(3,5) local si = math.random(1,ki-1) local sj = math.random(1,kj-1) local sk = math.random(1,kk-1) @@ -298,18 +297,10 @@ function cudnntest.VolumetricConvolution() local outj = math.random(1,17) local outk = math.random(1,5) - if testparams.test_type == 'torch.CudaHalfTensor' then - --- CUDNN causes some corruption here - si, sj, sk = 1,1,1 - ki, kj, kk = 3,3,3 - outi, outj, outk = 1,1,1 - --- was not able to restrict parameters so that CUDNN would behave ... - return - end + local ini = outi*si+ki-1 + local inj = outj*sj+kj-1 + local ink = outk*sk+kk-1 - local ini = (outi-1)*si+ki - local inj = (outj-1)*sj+kj - local ink = (outk-1)*sk+kk local scale = math.random() local input = torch.randn(bs,from,ink,inj,ini):cuda() @@ -472,7 +463,7 @@ function cudnntest.SpatialCrossMapLRN_batch() local inputSize = math.random(6,9) local size = math.random(1,3)*2+1 local nbfeatures = math.random(3,8) - local alpha = math.random(1,100)/100 + local alpha = math.random(0,100)/100 local beta = math.random(1,100)/100 local k = math.random(1,3) @@ -773,90 +764,6 @@ function cudnntest.VolumetricCrossEntropyCriterion() 'error in difference between central difference and :backward') end -function cudnntest.functional_bias2D() - local bs = math.random(1,32) - local from = math.random(1,32) - local to = math.random(1,64) - local ki = math.random(1,15) - local kj = math.random(1,15) - local si = math.random(1,ki) - local sj = math.random(1,kj) - local outi = math.random(1,64) - local outj = math.random(1,64) - local ini = (outi-1)*si+ki - local inj = (outj-1)*sj+kj - local scale = torch.uniform() - local input = torch.zeros(bs,from,inj,ini):cuda() - local mod = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() - mod.weight:zero() - local groundtruth = mod:forward(input) - local result = groundtruth:clone():zero() - cudnn.functional.bias2D_updateOutput(cudnn.getHandle(), mod.bias, result) - local error = result:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), - testparams.precision_forward, 'error on forward ') - - mod:zeroGradParameters() - local gradOutput = groundtruth:clone():normal() - mod:backward(input, gradOutput, scale) - local groundtruth = mod.gradBias - local result = groundtruth:clone():zero() - cudnn.functional.bias2D_accGradParameters(cudnn.getHandle(), gradOutput, result, scale) - error = result:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), - testparams.precision_backward, 'error on 
accGradParameters ') -end - -function cudnntest.functional_convolution2d() - local a=cudnn.SpatialConvolution(3,16,5,5):cuda() - a.bias:zero(); - local input = torch.randn(10,3,10,10):cuda() - a:zeroGradParameters() - a:forward(input); - local output = a.output:clone():normal() - local gradOutput = a.output:clone():normal() - local gradInput = a:backward(input, gradOutput):clone():normal() - local gradWeight = a.gradWeight:clone():zero() - cudnn.functional.Convolution2D_updateOutput(cudnn.getHandle(), input, - a.weight, output, a.dH, - a.dW, a.padH, a.padW) - mytester:assertlt((output - a.output):abs():max(), - testparams.precision_forward, 'error on forward ') - - cudnn.functional.Convolution2D_updateGradInput(cudnn.getHandle(), input, - a.weight, output, gradOutput, - gradInput, - a.dH, a.dW, a.padH, a.padW) - mytester:assertlt((gradInput - a.gradInput):abs():max(), - testparams.precision_forward, 'error on updateGradInput ') - - cudnn.functional.Convolution2D_accGradParameters(cudnn.getHandle(), input, - gradWeight, gradOutput, - a.dH, a.dW, a.padH, a.padW) - mytester:assertlt((gradWeight - a.gradWeight):abs():max(), - testparams.precision_forward, 'error on accGradParameters ') -end - -function cudnntest.functional_maxpooling2d() - local a=cudnn.SpatialMaxPooling(2,2,2,2):cuda() - local input = torch.randn(10,3,10,10):cuda() - a:forward(input); - local output = a.output:clone():normal() - local gradOutput = a.output:clone():normal() - local gradInput = a:backward(input, gradOutput):clone():normal() - cudnn.functional.MaxPooling2D_updateOutput(cudnn.getHandle(), input, - output, a.kH, a.kW, - a.dH, a.dW, a.padH, a.padW) - mytester:assertlt((output - a.output):abs():max(), - testparams.precision_forward, 'error on forward ') - - cudnn.functional.MaxPooling2D_updateGradInput(cudnn.getHandle(), input, - output, gradOutput, gradInput, - a.kH, a.kW, a.dH, a.dW, - a.padH, a.padW) - mytester:assertlt((gradInput - a.gradInput):abs():max(), - testparams.precision_forward, 'error on updateGradInput ') -end torch.setdefaulttensortype('torch.FloatTensor') math.randomseed(os.time()) @@ -864,15 +771,14 @@ mytester = torch.Tester() mytester:add(cudnntest) -- cudnn.verbose=true - --- Developers, do not commit uncommented regions until bindings fixed --- TODO: adapt tests for FindEx +-- cudnn.find.verbose=true -- cudnn.useFindEx=true for i = 1, cutorch.getDeviceCount() do for _, benchmark in ipairs({false, true}) do cudnn.benchmark = benchmark +-- cudnn.reset() local prop = cutorch.getDeviceProperties(i) print('Running test on device: #' .. i .. ' : ' .. prop.name @@ -891,6 +797,7 @@ for i = 1, cutorch.getDeviceCount() do print'Testing torch.CudaDoubleTensor' testparams = testparams_double mytester:run() + end end |
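The commented switches near the bottom of test.lua show how to turn on the new diagnostics when reproducing failures; note that Find verbosity moved from `cudnn.verbose` to the finder-specific `cudnn.find.verbose`:

```lua
cudnn.verbose = true        -- general binding verbosity
cudnn.find.verbose = true   -- per-call Find/FindEx tracing (new in this commit)
cudnn.useFindEx = true      -- opt in to the FindEx path ("still needs some work")
```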