diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-24 17:34:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-24 17:34:35 +0300 |
commit | 097a4bee165b12ee9e5060549a361a043f4db541 (patch) | |
tree | 8bbeb32011156b90d289f6839c4defee446c3b1f | |
parent | a5fb4c0b22a5440c836994d2d6ba2905d30655e5 (diff) | |
parent | 4e95c7d8e41ddc26eb3b56c76dec4e4afbcab051 (diff) |
Merge pull request #280 from soumith/functional-findex
Use FindEx in functional
-rw-r--r-- | functional.lua | 156 |
1 file changed, 55 insertions, 101 deletions
diff --git a/functional.lua b/functional.lua index deaf839..e0ca9cd 100644 --- a/functional.lua +++ b/functional.lua @@ -5,6 +5,7 @@ local cudnn = require 'cudnn.env' local ffi = require 'ffi' local errcheck = cudnn.errcheck +local find = require 'cudnn.find' cudnn.functional = {} local function getMathType(weight) @@ -50,20 +51,6 @@ cudnn.functional.bias2D_accGradParameters = function(handle, gradOutput, gradBia end -local function getWsPtrAndSize(workspace, maxBufSize) - local wsPtr, extraBufferSizeInBytes - if workspace then - if maxBufSize > workspace:nElement()*workspace:elementSize() then - local nElems = math.ceil(maxBufSize/workspace:elementSize()) - workspace:resize(nElems) - end - else - cudnn.setSharedWorkspaceSize(maxBufSize,true) - wsPtr, extraBufferSizeInBytes = cudnn.getSharedWorkspace() - end - return wsPtr, extraBufferSizeInBytes -end - -- Does a 2D Convolution (updateOutput) on input, weight -- output is assumed to be allocated and given. cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, output, @@ -102,41 +89,30 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou -- create descriptor for output local oDesc = cudnn.toDescriptor(output) - -- create forwardAlgorithm descriptors for - local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' - if cudnn.fastest then - algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' - end - local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(weight) - - errcheck('cudnnGetConvolutionForwardAlgorithm', - handle, - iDesc[0], weightDesc[0], - convDesc[0], oDesc[0], - algSearchMode, algWorkspaceLimit, algType) - - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionForwardWorkspaceSize', - handle, - iDesc[0], weightDesc[0], - convDesc[0], oDesc[0], - algType[0], bufSize:data()) - local maxBufSize = bufSize[1] - local wsPtr, extraBufferSizeInBytes = getWsPtrAndSize(workspace, maxBufSize) - - if maxBufSize > extraBufferSizeInBytes then - extraBuffer:resize(math.ceil(maxBufSize / extraBuffer:elementSize())) - extraBufferSizeInBytes = maxBufSize - end + local layer = { + convDesc = convDesc, + weight = weight, + nInputPlane = nInputPlane, + nOutputPlane = nOutputPlane, + kW = kW, + kH = kH, + pad = {padH, padW}, + stride = {strideH, strideW}, + } + + local finder = find.get() + find:prepare(layer, input, output) + local fwdAlgo = finder:forwardAlgorithm(layer, {iDesc[0], input, weightDesc[0], + weight, convDesc[0], oDesc[0], output}) + local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() -- do convolution errcheck('cudnnConvolutionForward', handle, cudnn.scalar(input, 1), iDesc[0], input:data(), weightDesc[0], weight:data(), - convDesc[0], algType[0], - wsPtr, extraBufferSizeInBytes, + convDesc[0], fwdAlgo, + extraBuffer, extraBufferSize, cudnn.scalar(input, 0), oDesc[0], output:data()); end @@ -168,41 +144,30 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, local iDesc = cudnn.toDescriptor(input) local oDesc = cudnn.toDescriptor(output) - local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE' - if cudnn.fastest then - algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' - end - local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(weight) - - errcheck('cudnnGetConvolutionBackwardDataAlgorithm', - handle, - weightDesc[0], oDesc[0], - convDesc[0], iDesc[0], - algSearchMode, algWorkspaceLimit, algType) - - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize', - handle, - weightDesc[0], oDesc[0], - convDesc[0], iDesc[0], - algType[0], bufSize:data()) - local maxBufSize = bufSize[1] - local wsPtr, extraBufferSizeInBytes = getWsPtrAndSize(workspace, maxBufSize) - - if maxBufSize > extraBufferSizeInBytes then - extraBuffer:resize(math.ceil(maxBufSize / extraBuffer:elementSize())) - extraBufferSizeInBytes = maxBufSize - end + local layer = { + convDesc = convDesc, + weight = weight, + nInputPlane = nInputPlane, + nOutputPlane = nOutputPlane, + kW = kW, + kH = kH, + pad = {padH, padW}, + stride = {strideH, strideW}, + } + + local finder = find.get() + find:prepare(layer, input, output) + local bwdDataAlgo = finder:backwardDataAlgorithm(layer, {weightDesc[0], weight, oDesc[0], + output, convDesc[0], iDesc[0], input}) + local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() -- do convolution errcheck('cudnnConvolutionBackwardData', handle, cudnn.scalar(input, 1), weightDesc[0], weight:data(), oDesc[0], gradOutput:data(), - convDesc[0], - algType[0], - wsPtr, extraBufferSizeInBytes, + convDesc[0], bwdDataAlgo, + extraBuffer, extraBufferSize, cudnn.scalar(input, 0), iDesc[0], gradInput:data()); end @@ -234,41 +199,30 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW local iDesc = cudnn.toDescriptor(input) local oDesc = cudnn.toDescriptor(gradOutput) - local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1) - local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE' - if cudnn.fastest then - algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' - end - local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(gradWeight) - - errcheck('cudnnGetConvolutionBackwardFilterAlgorithm', - handle, - iDesc[0], oDesc[0], - convDesc[0], weightDesc[0], - algSearchMode, algWorkspaceLimit, algType) - - local bufSize = torch.LongTensor(1) - errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize', - handle, - iDesc[0], oDesc[0], - convDesc[0], weightDesc[0], - algType[0], bufSize:data()) - local maxBufSize = bufSize[1] - local wsPtr, extraBufferSizeInBytes = getWsPtrAndSize(workspace, maxBufSize) - - if maxBufSize > extraBufferSizeInBytes then - extraBuffer:resize(math.ceil(maxBufSize / extraBuffer:elementSize())) - extraBufferSizeInBytes = maxBufSize - end + local layer = { + convDesc = convDesc, + weight = gradWeight, + nInputPlane = nInputPlane, + nOutputPlane = nOutputPlane, + kW = kW, + kH = kH, + pad = {padH, padW}, + stride = {strideH, strideW}, + } + + local finder = find.get() + find:prepare(layer, input, gradOutput) + local bwdFilterAlgo = finder:backwardFilterAlgorithm(layer, {iDesc[0], input, oDesc[0], + gradOutput, convDesc[0], weightDesc[0], gradWeight}) + local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() -- do convolution errcheck('cudnnConvolutionBackwardFilter', handle, scaleT:data(), iDesc[0], input:data(), oDesc[0], gradOutput:data(), - convDesc[0], - algType[0], - wsPtr, extraBufferSizeInBytes, + convDesc[0], bwdFilterAlgo, + extraBuffer, extraBufferSize, cudnn.scalar(input, 1), weightDesc[0], gradWeight:data()); end |