diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-09-29 21:33:35 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-29 21:33:35 +0300 |
commit | 01e765d8ef8cadb079fb5063918b1524061b3241 (patch) | |
tree | 5171cacdc7eccd8572d2826d7d6a30e0ae64d6d8 | |
parent | 6b454379d7b784c56b1d4710bcac3c0f5dbc471c (diff) | |
parent | 40a9ab188affe55e2faa2a0cf4bf226199f3c454 (diff) |
Merge pull request #265 from szagoruyko/fastest-functional
cudnn.fastest in functional
-rw-r--r-- | functional.lua | 74 |
1 file changed, 62 insertions, 12 deletions
diff --git a/functional.lua b/functional.lua index d721b0d..84bd7a5 100644 --- a/functional.lua +++ b/functional.lua @@ -103,23 +103,39 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou -- create forwardAlgorithm descriptors for local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' - local algWorkspaceLimit = 0 - if workspace then - algWorkspaceLimit = workspace:nElement() * 4 -- 4 = sizeof float + if cudnn.fastest then + algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end + local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(weight) + errcheck('cudnnGetConvolutionForwardAlgorithm', handle, iDesc[0], weightDesc[0], convDesc[0], oDesc[0], algSearchMode, algWorkspaceLimit, algType) + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionForwardWorkspaceSize', + handle, + iDesc[0], weightDesc[0], + convDesc[0], oDesc[0], + algType[0], bufSize:data()) + local maxBufSize = bufSize[1] + + local extraBuffer = workspace or cudnn.getSharedWorkspace() + local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float + if maxBufSize > extraBufferSizeInBytes then + extraBuffer:resize(math.ceil(maxBufSize / 4)) + extraBufferSizeInBytes = maxBufSize + end + -- do convolution errcheck('cudnnConvolutionForward', handle, cudnn.scalar(input, 1), iDesc[0], input:data(), weightDesc[0], weight:data(), convDesc[0], algType[0], - workspace and workspace:data() or nil, algWorkspaceLimit, + extraBuffer:data(), extraBufferSizeInBytes, cudnn.scalar(input, 0), oDesc[0], output:data()); end @@ -167,12 +183,31 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1) local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE' + if cudnn.fastest then + algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' + end + local algWorkspaceLimit 
= nInputPlane * kH * kW * cudnn.sizeof(weight) errcheck('cudnnGetConvolutionBackwardDataAlgorithm', - cudnn.getHandle(), + handle, weightDesc[0], oDesc[0], convDesc[0], iDesc[0], - algSearchMode, 0, algType) + algSearchMode, algWorkspaceLimit, algType) + + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize', + handle, + weightDesc[0], oDesc[0], + convDesc[0], iDesc[0], + algType[0], bufSize:data()) + local maxBufSize = bufSize[1] + + local extraBuffer = cudnn.getSharedWorkspace() + local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float + if maxBufSize > extraBufferSizeInBytes then + extraBuffer:resize(math.ceil(maxBufSize / 4)) + extraBufferSizeInBytes = maxBufSize + end -- do convolution errcheck('cudnnConvolutionBackwardData', handle, @@ -181,11 +216,9 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, oDesc[0], gradOutput:data(), convDesc[0], algType[0], - NULL, 0, + extraBuffer:data(), extraBufferSizeInBytes, cudnn.scalar(input, 0), iDesc[0], gradInput:data()); - - end -- accumulates the gradients into gradWeight. 
@@ -232,14 +265,31 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1) local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE' - local algWorkspaceLimit = 0 + if cudnn.fastest then + algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' + end + local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(gradWeight) errcheck('cudnnGetConvolutionBackwardFilterAlgorithm', - cudnn.getHandle(), + handle, iDesc[0], oDesc[0], convDesc[0], weightDesc[0], algSearchMode, algWorkspaceLimit, algType) + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize', + handle, + iDesc[0], oDesc[0], + convDesc[0], weightDesc[0], + algType[0], bufSize:data()) + local maxBufSize = bufSize[1] + + local extraBuffer = cudnn.getSharedWorkspace() + local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float + if maxBufSize > extraBufferSizeInBytes then + extraBuffer:resize(math.ceil(maxBufSize / 4)) + extraBufferSizeInBytes = maxBufSize + end -- do convolution errcheck('cudnnConvolutionBackwardFilter', handle, @@ -248,7 +298,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW oDesc[0], gradOutput:data(), convDesc[0], algType[0], - NULL, 0, + extraBuffer:data(), extraBufferSizeInBytes, cudnn.scalar(input, 1), weightDesc[0], gradWeight:data()); end |