author     Soumith Chintala <soumith@gmail.com>      2015-06-26 22:11:24 +0300
committer  soumith <soumith@fb.com>                  2015-08-02 20:38:44 +0300
commit     f6f22a3bf2ee4b920b7a38a61d0be911377f0d47 (patch)
tree       1f301fa2023b9e9a2bfbae90c93b0a89dc9e0906
parent     3e6e918dac9e94d2f104da6e36f749312e5c3951 (diff)
working R3 bindings for non-new modules
-rw-r--r--   SpatialAveragePooling.lua  |   2
-rw-r--r--   SpatialConvolution.lua     | 102
-rw-r--r--   VolumetricConvolution.lua  | 144
-rw-r--r--   ffi.lua                    | 340
-rw-r--r--   test/benchmark.lua         |  72
-rw-r--r--   test/test.lua              |   4
6 files changed, 533 insertions(+), 131 deletions(-)
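
The user-visible change in this commit is per-layer control over which cuDNN convolution algorithms run: :fastest() lets cuDNN pick the fastest algorithm regardless of workspace size, :setMode(fmode, bdmode, bwmode) pins the forward, backward-data, and backward-filter algorithms explicitly, :resetMode() returns to automatic selection, and an optional workspace_limit field caps the workspace that algorithm selection may assume. A minimal usage sketch against the API added below (the layer sizes are arbitrary examples, not taken from this commit):

    require 'cudnn'

    -- Auto-tune, allowing any amount of workspace:
    local conv = cudnn.SpatialConvolution(3, 64, 9, 9, 1, 1):fastest():cuda()

    -- Or pin the forward, backward-data, and backward-filter algorithms:
    conv:setMode('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM',
                 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
                 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')

    -- Optionally cap the workspace (in bytes) considered during selection:
    conv.workspace_limit = 64 * 1024 * 1024

    -- Back to automatic selection:
    conv:resetMode()
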
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
index 5da3db2..51a4119 100644
--- a/SpatialAveragePooling.lua
+++ b/SpatialAveragePooling.lua
@@ -3,5 +3,5 @@ local SpatialAveragePooling, parent
 function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH)
    parent.__init(self, kW, kH, dW, dH, padW, padH)
-   self.mode = 'CUDNN_POOLING_AVERAGE'
+   self.mode = 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
 end
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 98091cf..5734c47 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -51,6 +51,26 @@ function SpatialConvolution:fastest(mode)
    return self
 end
 
+function SpatialConvolution:setMode(fmode, bdmode, bwmode)
+   if fmode ~= nil then
+      self.fmode = fmode
+   end
+   if bdmode ~= nil then
+      self.bdmode = bdmode
+   end
+   if bwmode ~= nil then
+      self.bwmode = bwmode
+   end
+   return self
+end
+
+function SpatialConvolution:resetMode()
+   self.fmode = nil
+   self.bdmode = nil
+   self.bwmode = nil
+   return self
+end
+
 function SpatialConvolution:createIODescriptors(input)
    local batch = true
    if input:dim() == 3 then
@@ -78,9 +98,10 @@ function SpatialConvolution:createIODescriptors(input)
       local pad = torch.IntTensor({self.padH, self.padW})
       local stride = torch.IntTensor({self.dH, self.dW})
       local upscale = torch.IntTensor({1,1})
-      errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
+      errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
                2, pad:data(),
-               stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION');
+               stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+               'CUDNN_DATA_FLOAT');
       local function destroyConvDesc(d)
          errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
       end
@@ -99,29 +120,79 @@ function SpatialConvolution:createIODescriptors(input)
       self.oDesc = cudnn.toDescriptor(self.output[output_slice])
       self.oDescForBias = cudnn.toDescriptor(self.output)
 
+      -----------------------------------------------------------------------
+      local maxBufSize = 0
       -- create forwardAlgorithm descriptors for
       local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
       local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
-      local algWorkspaceLimit = self.nInputPlane * self.kH * self.kW * 4 -- 4 = sizeof int.
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
       if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end
       errcheck('cudnnGetConvolutionForwardAlgorithm',
                cudnn.getHandle(),
               self.iDesc[0], self.weightDesc[0],
               self.convDesc[0], self.oDesc[0],
               algSearchMode, algWorkspaceLimit, algType)
-      self.algType = algType
+      algType[0] = self.fmode or algType[0]
+      self.fwdAlgType = algType
       local bufSize = torch.LongTensor(1)
       errcheck('cudnnGetConvolutionForwardWorkspaceSize',
               cudnn.getHandle(),
               self.iDesc[0], self.weightDesc[0],
               self.convDesc[0], self.oDesc[0],
               algType[0], bufSize:data())
-      self.extraBuffer = self.extraBuffer or input.new(1)
-      if bufSize[1] ~= 0 or bufSize[1] ~= self.extraBufferSizeInBytes then
-         self.extraBuffer:resize(math.ceil(bufSize[1]/4))
-         self.extraBufferSizeInBytes = bufSize[1]
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      -- create backwardFilterAlgorithm descriptors for
+      local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
+      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+      if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' end
+      errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+               cudnn.getHandle(),
+               self.iDesc[0], self.oDesc[0],
+               self.convDesc[0], self.weightDesc[0],
+               algSearchMode, algWorkspaceLimit, algType)
+      algType[0] = self.bwmode or algType[0]
+      self.bwdFilterAlgType = algType
+      local bufSize = torch.LongTensor(1)
+      errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
+               cudnn.getHandle(),
+               self.iDesc[0], self.oDesc[0],
+               self.convDesc[0], self.weightDesc[0],
+               algType[0], bufSize:data())
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      -- create backwardDataAlgorithm descriptors for
+      local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
+      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+      if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' end
+      errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+               cudnn.getHandle(),
+               self.weightDesc[0], self.oDesc[0],
+               self.convDesc[0], self.iDesc[0],
+               algSearchMode, algWorkspaceLimit, algType)
+      algType[0] = self.bdmode or algType[0]
+      self.bwdDataAlgType = algType
+      local bufSize = torch.LongTensor(1)
+      errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
+               cudnn.getHandle(),
+               self.weightDesc[0], self.oDesc[0],
+               self.convDesc[0], self.iDesc[0],
+               algType[0], bufSize:data())
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      self.extraBuffer = self.extraBuffer or input.new(1)
+      self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
+      if maxBufSize > self.extraBufferSizeInBytes then
+         self.extraBuffer:resize(math.ceil(maxBufSize/4))
+         self.extraBufferSizeInBytes = maxBufSize
       end
+      -----------------------------------------------------------------------
       -- create offsets for groups
       self.input_offset = self.nInputPlane/self.groups*input:size(3)*input:size(4)
       self.output_offset = self.nOutputPlane/self.groups*oSize[3]*oSize[4]
@@ -172,7 +243,7 @@ function SpatialConvolution:updateOutput(input)
                one:data(),
                self.iDesc[0], input:data() + g*self.input_offset,
                self.weightDesc[0], self.weight:data() + g*self.weight_offset,
-               self.convDesc[0], self.algType[0],
+               self.convDesc[0], self.fwdAlgType[0],
                self.extraBuffer:data(), self.extraBufferSizeInBytes,
                zero:data(),
                self.oDesc[0], self.output:data() + g*self.output_offset);
@@ -195,11 +266,13 @@ function SpatialConvolution:updateGradInput(input, gradOutput)
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
    for g=0,self.groups-1 do
-      errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
+      errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
               one:data(),
               self.weightDesc[0], self.weight:data() + g*self.weight_offset,
               self.oDesc[0], gradOutput:data() + g*self.output_offset,
               self.convDesc[0],
+              self.bwdDataAlgType[0],
+              self.extraBuffer:data(), self.extraBufferSizeInBytes,
               zero:data(),
               self.iDesc[0], self.gradInput:data() + g*self.input_offset);
    end
@@ -224,11 +297,13 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
            self.biasDesc[0], self.gradBias:data())
    for g=0,self.groups-1 do
       -- gradWeight
-      errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
+      errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
              self.scaleT:data(),
              self.iDesc[0], input:data() + g*self.input_offset,
              self.oDesc[0], gradOutput:data() + g*self.output_offset,
              self.convDesc[0],
+             self.bwdFilterAlgType[0],
+             self.extraBuffer:data(), self.extraBufferSizeInBytes,
              one:data(),
              self.weightDesc[0], self.gradWeight:data() + g*self.weight_offset);
    end
@@ -242,6 +317,11 @@ function SpatialConvolution:write(f)
    self.oDesc = nil
    self.oDescForBias = nil
    self.algType = nil
+   self.fwdAlgType = nil
+   self.bwdDataAlgType = nil
+   self.bwdFilterAlgType = nil
+   self.extraBuffer = nil
+   self.extraBufferSizeInBytes = nil
    local var = {}
    for k,v in pairs(self) do
       var[k] = v
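
A note on the workspace handling above: instead of sizing a buffer for the forward pass alone, createIODescriptors now queries the workspace requirement of the chosen forward, backward-data, and backward-filter algorithms and sizes one shared float buffer to the maximum of the three, growing it only when a larger requirement appears. A condensed sketch of that pattern, with the descriptor plumbing elided and a hypothetical helper name:

    -- sizes = bytes reported by the three cudnnGet*WorkspaceSize calls
    local function ensureSharedWorkspace(self, input, sizes)
       local maxBufSize = 0
       for _, bytes in ipairs(sizes) do
          maxBufSize = math.max(maxBufSize, bytes)
       end
       self.extraBuffer = self.extraBuffer or input.new(1)
       self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float elements
       if maxBufSize > self.extraBufferSizeInBytes then
          self.extraBuffer:resize(math.ceil(maxBufSize/4)) -- bytes -> float count
          self.extraBufferSizeInBytes = maxBufSize
       end
    end
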
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 58e0a34..3ab7e5b 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -39,6 +39,32 @@ function VolumetricConvolution:resetWeightDescriptors()
                              1, 1))
 end
 
+function VolumetricConvolution:fastest(mode)
+   if mode == nil then mode = true end
+   self.fastest_mode = mode
+   return self
+end
+
+function VolumetricConvolution:setMode(fmode, bdmode, bwmode)
+   if fmode ~= nil then
+      self.fmode = fmode
+   end
+   if bdmode ~= nil then
+      self.bdmode = bdmode
+   end
+   if bwmode ~= nil then
+      self.bwmode = bwmode
+   end
+   return self
+end
+
+function VolumetricConvolution:resetMode()
+   self.fmode = nil
+   self.bdmode = nil
+   self.bwmode = nil
+   return self
+end
+
 function VolumetricConvolution:createIODescriptors(input)
    local batch = true
    if input:dim() == 4 then
@@ -62,9 +88,10 @@ function VolumetricConvolution:createIODescriptors(input)
      local pad = torch.IntTensor({self.padT, self.padH, self.padW})
      local stride = torch.IntTensor({self.dT, self.dH, self.dW})
      local upscale = torch.IntTensor({1,1,1})
-     errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
+     errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
               3, pad:data(),
-              stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION');
+              stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+              'CUDNN_DATA_FLOAT');
      local function destroyConvDesc(d)
         errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
      end
@@ -84,24 +111,79 @@ function VolumetricConvolution:createIODescriptors(input)
                                         self.output:size(2),
                                         self.output:size(3)*self.output:size(4),
                                         self.output:size(5)))
-
+      -----------------------------------------------------------------
+      local maxBufSize = 0
      -- create forwardAlgorithm descriptors for
      local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+      local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+      if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end
      errcheck('cudnnGetConvolutionForwardAlgorithm',
               cudnn.getHandle(),
-              self.iDesc[0], self.weightDesc[0], self.convDesc[0],
-              self.oDesc[0], 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST',
-              -1, algType)
-      self.algType = algType
+              self.iDesc[0], self.weightDesc[0],
+              self.convDesc[0], self.oDesc[0],
+              algSearchMode, algWorkspaceLimit, algType)
+      algType[0] = self.fmode or algType[0]
+      self.fwdAlgType = algType
      local bufSize = torch.LongTensor(1)
      errcheck('cudnnGetConvolutionForwardWorkspaceSize',
              cudnn.getHandle(),
              self.iDesc[0], self.weightDesc[0],
              self.convDesc[0], self.oDesc[0],
              algType[0], bufSize:data())
-      self.extraBuffer = self.extraBuffer or input.new(1)
-      if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      -- create backwardFilterAlgorithm descriptors for
+      local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
+      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+      if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' end
+      errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+               cudnn.getHandle(),
+               self.iDesc[0], self.oDesc[0],
+               self.convDesc[0], self.weightDesc[0],
+               algSearchMode, algWorkspaceLimit, algType)
+      algType[0] = self.bwmode or algType[0]
+      self.bwdFilterAlgType = algType
+      local bufSize = torch.LongTensor(1)
+      errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
+               cudnn.getHandle(),
+               self.iDesc[0], self.oDesc[0],
+               self.convDesc[0], self.weightDesc[0],
+               algType[0], bufSize:data())
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      -- create backwardDataAlgorithm descriptors for
+      local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
+      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+      local algWorkspaceLimit = self.workspace_limit
+         or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+      if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' end
+      errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+               cudnn.getHandle(),
+               self.weightDesc[0], self.oDesc[0],
+               self.convDesc[0], self.iDesc[0],
+               algSearchMode, algWorkspaceLimit, algType)
+      algType[0] = self.bdmode or algType[0]
+      self.bwdDataAlgType = algType
+      local bufSize = torch.LongTensor(1)
+      errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
+               cudnn.getHandle(),
+               self.weightDesc[0], self.oDesc[0],
+               self.convDesc[0], self.iDesc[0],
+               algType[0], bufSize:data())
+      maxBufSize = math.max(maxBufSize, bufSize[1])
+
+      self.extraBuffer = self.extraBuffer or input.new(1)
+      self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
+      if maxBufSize > self.extraBufferSizeInBytes then
+         self.extraBuffer:resize(math.ceil(maxBufSize/4))
+         self.extraBufferSizeInBytes = maxBufSize
+      end
+
+      -----------------------------------------------------------------
      if not batch then
         self.gradInput = self.gradInput:view(self.gradInput:size(2),
                                              self.gradInput:size(3),
@@ -125,8 +207,8 @@ function VolumetricConvolution:updateOutput(input)
            one:data(),
            self.iDesc[0], input:data(),
            self.weightDesc[0], self.weight:data(),
-           self.convDesc[0], self.algType[0],
-           self.extraBuffer:data(), self.extraBuffer:nElement(),
+           self.convDesc[0], self.fwdAlgType[0],
+           self.extraBuffer:data(), self.extraBufferSizeInBytes,
            zero:data(),
            self.oDesc[0], self.output:data());
    errcheck('cudnnAddTensor', cudnn.getHandle(),
@@ -142,13 +224,15 @@ function VolumetricConvolution:updateGradInput(input, gradOutput)
          and gradOutput:isContiguous());
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
-   errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
-            one:data(),
-            self.weightDesc[0], self.weight:data(),
-            self.oDesc[0], gradOutput:data(),
-            self.convDesc[0],
-            zero:data(),
-            self.iDesc[0], self.gradInput:data());
+   errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
+            one:data(),
+            self.weightDesc[0], self.weight:data(),
+            self.oDesc[0], gradOutput:data(),
+            self.convDesc[0],
+            self.bwdDataAlgType[0],
+            self.extraBuffer:data(), self.extraBufferSizeInBytes,
+            zero:data(),
+            self.iDesc[0], self.gradInput:data());
    return self.gradInput
 end
@@ -170,15 +254,15 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
            one:data(),
            self.biasDesc[0], self.gradBias:data());
    -- gradWeight
-   errcheck('cudnnConvolutionBackwardFilter',
-            cudnn.getHandle(),
-            self.scaleT:data(),
-            self.iDesc[0], input:data(),
-            self.oDesc[0], gradOutput:data(),
-            self.convDesc[0],
-            one:data(),
-            self.weightDesc[0], self.gradWeight:data());
-
+   errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
+            self.scaleT:data(),
+            self.iDesc[0], input:data(),
+            self.oDesc[0], gradOutput:data(),
+            self.convDesc[0],
+            self.bwdFilterAlgType[0],
+            self.extraBuffer:data(), self.extraBufferSizeInBytes,
+            one:data(),
+            self.weightDesc[0], self.gradWeight:data());
 end
 
 function VolumetricConvolution:write(f)
@@ -188,7 +272,11 @@
    self.iDesc = nil
    self.oDesc = nil
    self.oDescBias = nil
-   self.algType = nil
+   self.fwdAlgType = nil
+   self.bwdDataAlgType = nil
+   self.bwdFilterAlgType = nil
+   self.extraBuffer = nil
+   self.extraBufferSizeInBytes = nil
    local var = {}
    for k,v in pairs(self) do
      var[k] = v
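
VolumetricConvolution previously hard-coded CUDNN_CONVOLUTION_FWD_PREFER_FASTEST with an unlimited workspace; it now goes through the same selection path as SpatialConvolution, so the same per-layer controls apply. A short sketch (the layer sizes are arbitrary examples):

    local vconv = cudnn.VolumetricConvolution(4, 8, 3, 5, 5, 1, 1, 1)
       :setMode('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM')
       :cuda()
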
diff --git a/ffi.lua b/ffi.lua
--- a/ffi.lua
+++ b/ffi.lua
@@ -31,11 +31,13 @@
 typedef struct cudnnTensorStruct*      cudnnTensorDescriptor_t;
 typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t;
 typedef struct cudnnPoolingStruct*     cudnnPoolingDescriptor_t;
 typedef struct cudnnFilterStruct*      cudnnFilterDescriptor_t;
+typedef struct cudnnLRNStruct*         cudnnLRNDescriptor_t;
 
 typedef enum
 {
     CUDNN_DATA_FLOAT  = 0,
-    CUDNN_DATA_DOUBLE = 1
+    CUDNN_DATA_DOUBLE = 1,
+    CUDNN_DATA_HALF   = 2,
 } cudnnDataType_t;
 
 typedef enum
@@ -108,24 +110,14 @@
 cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnFilterDescriptor_t filterDesc);
 
 cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc );
 
 cudnnStatus_t
-cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
-                                int arrayLength, /* nbDims-2 size */
-                                const int padA[],
-                                const int filterStrideA[],
-                                const int upscaleA[],
-                                cudnnConvolutionMode_t mode
-                                );
-
-cudnnStatus_t
-cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
-                                int arrayLengthRequested,
-                                int *arrayLength,
-                                int padA[],
-                                int strideA[],
-                                int upscaleA[],
-                                cudnnConvolutionMode_t *mode
-                                );
-
+cudnnSetConvolutionNdDescriptor_v3( cudnnConvolutionDescriptor_t convDesc,
+                                    int arrayLength,
+                                    const int padA[],
+                                    const int filterStrideA[],
+                                    const int upscaleA[],
+                                    cudnnConvolutionMode_t mode,
+                                    cudnnDataType_t dataType
+                                    );
 
 cudnnStatus_t
 cudnnGetConvolutionNdForwardOutputDim(
@@ -136,7 +128,6 @@
                 int tensorOuputDimA[]
                 );
 
-/* Destroy an instance of convolution descriptor */
 cudnnStatus_t cudnnDestroyConvolutionDescriptor( cudnnConvolutionDescriptor_t convDesc );
 
 typedef enum
@@ -152,9 +143,29 @@ typedef enum
     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM         = 0,
     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
     CUDNN_CONVOLUTION_FWD_ALGO_GEMM                  = 2,
-    CUDNN_CONVOLUTION_FWD_ALGO_DIRECT                = 3
+    CUDNN_CONVOLUTION_FWD_ALGO_DIRECT                = 3,
+    CUDNN_CONVOLUTION_FWD_ALGO_FFT                   = 4
 } cudnnConvolutionFwdAlgo_t;
 
+typedef struct {
+    cudnnConvolutionFwdAlgo_t algo;
+    cudnnStatus_t status;
+    float time;
+    size_t memory;
+} cudnnConvolutionFwdAlgoPerf_t;
+
+cudnnStatus_t
+cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
+                                     const cudnnTensorDescriptor_t srcDesc,
+                                     const cudnnFilterDescriptor_t filterDesc,
+                                     const cudnnConvolutionDescriptor_t convDesc,
+                                     const cudnnTensorDescriptor_t destDesc,
+                                     const int requestedCount,
+                                     int *returnedCount,
+                                     cudnnConvolutionFwdAlgoPerf_t *perfResults
+                                     );
+
 cudnnStatus_t cudnnGetConvolutionForwardAlgorithm( cudnnHandle_t handle,
                 const cudnnTensorDescriptor_t srcDesc,
                 const cudnnFilterDescriptor_t filterDesc,
@@ -165,10 +176,6 @@ cudnnStatus_t cudnnGetConvolutionForwardAlgorithm( cudnnHandle_t handle,
                 cudnnConvolutionFwdAlgo_t *algo
                 );
 
-/*
- * convolution algorithm (which requires potentially some workspace)
- */
-
 cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize( cudnnHandle_t handle,
                 const cudnnTensorDescriptor_t srcDesc,
                 const cudnnFilterDescriptor_t filterDesc,
@@ -179,7 +186,6 @@ cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize( cudnnHandle_t handle,
                 );
 
-/* Function to perform the forward multiconvolution */
 cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle,
                 const void *alpha,
                 const cudnnTensorDescriptor_t srcDesc,
@@ -195,7 +201,6 @@ cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle,
                 void *destData
                 );
 
-/* Functions to perform the backward multiconvolution */
 cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle,
                 const void *alpha,
                 const cudnnTensorDescriptor_t srcDesc,
@@ -205,41 +210,116 @@ cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle,
                 void *destData
                 );
 
+typedef enum
+{
+    CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE   = 0,
+    CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST = 1
+} cudnnConvolutionBwdFilterPreference_t;
+
+typedef enum
+{
+    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0   = 0,  // non-deterministic
+    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1   = 1,
+    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2
+} cudnnConvolutionBwdFilterAlgo_t;
+
+cudnnStatus_t
+cudnnGetConvolutionBackwardFilterAlgorithm(
+                cudnnHandle_t handle,
+                const cudnnTensorDescriptor_t srcDesc,
+                const cudnnTensorDescriptor_t diffDesc,
+                const cudnnConvolutionDescriptor_t convDesc,
+                const cudnnFilterDescriptor_t gradDesc,
+                cudnnConvolutionBwdFilterPreference_t preference,
+                size_t memoryLimitInbytes,
+                cudnnConvolutionBwdFilterAlgo_t *algo
+                );
+
+cudnnStatus_t
+cudnnGetConvolutionBackwardFilterWorkspaceSize(
+                cudnnHandle_t handle,
+                const cudnnTensorDescriptor_t srcDesc,
+                const cudnnTensorDescriptor_t diffDesc,
+                const cudnnConvolutionDescriptor_t convDesc,
+                const cudnnFilterDescriptor_t gradDesc,
+                cudnnConvolutionBwdFilterAlgo_t algo,
+                size_t *sizeInBytes
+                );
+
+cudnnStatus_t cudnnConvolutionBackwardFilter_v3(
+                cudnnHandle_t handle,
+                const void *alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const cudnnTensorDescriptor_t diffDesc,
+                const void *diffData,
+                const cudnnConvolutionDescriptor_t convDesc,
+                cudnnConvolutionBwdFilterAlgo_t algo,
+                void *workSpace,
+                size_t workSpaceSizeInBytes,
+                const void *beta,
+                const cudnnFilterDescriptor_t gradDesc,
+                void *gradData
+                );
+
+typedef enum
+{
+    CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE   = 0,
+    CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST = 1
+} cudnnConvolutionBwdDataPreference_t;
+
+typedef enum
+{
+    CUDNN_CONVOLUTION_BWD_DATA_ALGO_0   = 0,  // non-deterministic
+    CUDNN_CONVOLUTION_BWD_DATA_ALGO_1   = 1,
+    CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
+} cudnnConvolutionBwdDataAlgo_t;
+
+
+cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(
+                cudnnHandle_t handle,
+                const cudnnFilterDescriptor_t filterDesc,
+                const cudnnTensorDescriptor_t diffDesc,
+                const cudnnConvolutionDescriptor_t convDesc,
+                const cudnnTensorDescriptor_t gradDesc,
+                cudnnConvolutionBwdDataPreference_t preference,
+                size_t memoryLimitInbytes,
+                cudnnConvolutionBwdDataAlgo_t *algo
+                );
+
+cudnnStatus_t cudnnGetConvolutionBackwardDataWorkspaceSize(
+                cudnnHandle_t handle,
+                const cudnnFilterDescriptor_t filterDesc,
+                const cudnnTensorDescriptor_t diffDesc,
+                const cudnnConvolutionDescriptor_t convDesc,
+                const cudnnTensorDescriptor_t gradDesc,
+                cudnnConvolutionBwdDataAlgo_t algo,
+                size_t *sizeInBytes
+                );
+
+
+cudnnStatus_t cudnnConvolutionBackwardData_v3(
+                cudnnHandle_t handle,
+                const void *alpha,
+                const cudnnFilterDescriptor_t filterDesc,
+                const void *filterData,
+                const cudnnTensorDescriptor_t diffDesc,
+                const void *diffData,
+                const cudnnConvolutionDescriptor_t convDesc,
+                cudnnConvolutionBwdDataAlgo_t algo,
+                void *workSpace,
+                size_t workSpaceSizeInBytes,
+                const void *beta,
+                const cudnnTensorDescriptor_t gradDesc,
+                void *gradData
+                );
 
-cudnnStatus_t cudnnConvolutionBackwardFilter( cudnnHandle_t handle,
-                const void *alpha,
-                const cudnnTensorDescriptor_t srcDesc,
-                const void *srcData,
-                const cudnnTensorDescriptor_t diffDesc,
-                const void *diffData,
-                const cudnnConvolutionDescriptor_t convDesc,
-                const void *beta,
-                const cudnnFilterDescriptor_t gradDesc,
-                void *gradData
-                );
-
-cudnnStatus_t cudnnConvolutionBackwardData( cudnnHandle_t handle,
-                const void *alpha,
-                const cudnnFilterDescriptor_t filterDesc,
-                const void *filterData,
-                const cudnnTensorDescriptor_t diffDesc,
-                const void *diffData,
-                const cudnnConvolutionDescriptor_t convDesc,
-                const void *beta,
-                const cudnnTensorDescriptor_t gradDesc,
-                void *gradData
-                );
-
-/*
- * softmax algorithm
- */
 typedef enum
 {
     CUDNN_SOFTMAX_FAST     = 0,
-    CUDNN_SOFTMAX_ACCURATE = 1
+    CUDNN_SOFTMAX_ACCURATE = 1,
+    CUDNN_SOFTMAX_LOG      = 2
 } cudnnSoftmaxAlgorithm_t;
 
 typedef enum
@@ -250,18 +330,19 @@ typedef enum
 /* Function to perform forward softmax */
 cudnnStatus_t cudnnSoftmaxForward( cudnnHandle_t handle,
-                                   cudnnSoftmaxAlgorithm_t algorithm,
-                                   cudnnSoftmaxMode_t mode,
-                                   const void *alpha,
-                                   const cudnnTensorDescriptor_t srcDesc,
-                                   const void *srcData,
-                                   const void *beta,
-                                   const cudnnTensorDescriptor_t destDesc,
-                                   void *destData
-                                   );
+                cudnnSoftmaxAlgorithm_t algorithm,
+                cudnnSoftmaxMode_t mode,
+                const void *alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const void *beta,
+                const cudnnTensorDescriptor_t destDesc,
+                void *destData
+                );
 
 /* Function to perform backward softmax */
-cudnnStatus_t cudnnSoftmaxBackward( cudnnHandle_t handle,
+cudnnStatus_t cudnnSoftmaxBackward(
+                cudnnHandle_t handle,
                 cudnnSoftmaxAlgorithm_t algorithm,
                 cudnnSoftmaxMode_t mode,
                 const void *alpha,
@@ -274,25 +355,23 @@
                 void *destDiffData
                 );
 
-
 typedef enum
 {
     CUDNN_POOLING_MAX     = 0,
-    CUDNN_POOLING_AVERAGE = 1
+    CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1,
+    CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2
 } cudnnPoolingMode_t;
 
-/* Create an instance of pooling descriptor */
 cudnnStatus_t cudnnCreatePoolingDescriptor(
-                                   cudnnPoolingDescriptor_t *poolingDesc);
+                cudnnPoolingDescriptor_t *poolingDesc);
 
 cudnnStatus_t cudnnSetPoolingNdDescriptor(
-                                   cudnnPoolingDescriptor_t poolingDesc,
+                cudnnPoolingDescriptor_t poolingDesc,
                 const cudnnPoolingMode_t mode,
                 int nbDims,
                 const int windowDimA[],
                 const int paddingA[],
                 const int strideA[]
-                                   );
+                );
 
 cudnnStatus_t cudnnGetPoolingNdDescriptor(
                 const cudnnPoolingDescriptor_t poolingDesc,
@@ -305,15 +384,14 @@
                 );
 
 cudnnStatus_t cudnnGetPoolingNdForwardOutputDim(
-                                   const cudnnPoolingDescriptor_t poolingDesc,
-                                   const cudnnTensorDescriptor_t inputTensorDesc,
-                                   int nbDims,
-                                   int outputTensorDimA[]);
-/* Destroy an instance of pooling descriptor */
+                const cudnnPoolingDescriptor_t poolingDesc,
+                const cudnnTensorDescriptor_t inputTensorDesc,
+                int nbDims,
+                int outputTensorDimA[]);
+
 cudnnStatus_t cudnnDestroyPoolingDescriptor(
-                                   cudnnPoolingDescriptor_t poolingDesc );
+                cudnnPoolingDescriptor_t poolingDesc );
 
-/* Function to perform forward pooling */
 cudnnStatus_t cudnnPoolingForward( cudnnHandle_t handle,
                 const cudnnPoolingDescriptor_t poolingDesc,
                 const void *alpha,
@@ -324,8 +402,7 @@
                 void *destData
                 );
 
-/* Function to perform backward pooling */
-cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle,
+cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle,
                 const cudnnPoolingDescriptor_t poolingDesc,
                 const void *alpha,
                 const cudnnTensorDescriptor_t srcDesc,
@@ -346,8 +423,7 @@
     CUDNN_ACTIVATION_TANH    = 2
 } cudnnActivationMode_t;
 
-/* Function to perform forward activation */
-cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle,
+cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle,
                 cudnnActivationMode_t mode,
                 const void *alpha,
                 const cudnnTensorDescriptor_t srcDesc,
@@ -357,8 +433,7 @@
                 void *destData
                 );
 
-/* Function to perform backward activation */
-cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle,
+cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle,
                 cudnnActivationMode_t mode,
                 const void *alpha,
                 const cudnnTensorDescriptor_t srcDesc,
@@ -371,6 +446,93 @@
                 const cudnnTensorDescriptor_t destDiffDesc,
                 void *destDiffData
                 );
+
+cudnnStatus_t cudnnCreateLRNDescriptor( cudnnLRNDescriptor_t* normDesc );
+
+typedef enum
+{
+    CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0,
+} cudnnLRNMode_t;
+
+cudnnStatus_t cudnnSetLRNDescriptor(
+                cudnnLRNDescriptor_t normDesc,
+                unsigned lrnN,
+                double lrnAlpha,
+                double lrnBeta,
+                double lrnK);
+
+cudnnStatus_t cudnnGetLRNDescriptor(
+                cudnnLRNDescriptor_t normDesc,
+                unsigned* lrnN,
+                double* lrnAlpha,
+                double* lrnBeta,
+                double* lrnK);
+
+cudnnStatus_t cudnnDestroyLRNDescriptor( cudnnLRNDescriptor_t lrnDesc );
+
+cudnnStatus_t cudnnLRNCrossChannelForward(
+                cudnnHandle_t handle,
+                cudnnLRNDescriptor_t normDesc,
+                cudnnLRNMode_t lrnMode,
+                const void* alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const void *beta,
+                const cudnnTensorDescriptor_t destDesc,
+                void *destData);
+
+cudnnStatus_t cudnnLRNCrossChannelBackward(
+                cudnnHandle_t handle,
+                cudnnLRNDescriptor_t normDesc,
+                cudnnLRNMode_t lrnMode,
+                const void* alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const cudnnTensorDescriptor_t srcDiffDesc,
+                const void *srcDiffData,
+                const cudnnTensorDescriptor_t destDesc,
+                const void *destData,
+                const void *beta,
+                const cudnnTensorDescriptor_t destDiffDesc,
+                void *destDiffData);
+
+typedef enum
+{
+    CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
+} cudnnDivNormMode_t;
+
+cudnnStatus_t cudnnDivisiveNormalizationForward(
+                cudnnHandle_t handle,
+                cudnnLRNDescriptor_t normDesc,
+                cudnnDivNormMode_t mode,
+                const void *alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const void *srcMeansData,
+                void *tempData,
+                void *tempData2,
+                const void *beta,
+                const cudnnTensorDescriptor_t destDesc,
+                void *destData
+                );
+
+cudnnStatus_t cudnnDivisiveNormalizationBackward(
+                cudnnHandle_t handle,
+                cudnnLRNDescriptor_t normDesc,
+                cudnnDivNormMode_t mode,
+                const void *alpha,
+                const cudnnTensorDescriptor_t srcDesc,
+                const void *srcData,
+                const void *srcMeansData,
+                const void *srcDiffData,
+                void *tempData,
+                void *tempData2,
+                const void *betaData,
+                const cudnnTensorDescriptor_t destDataDesc,
+                void *destDataDiff,
+                void *destMeansDiff
+                );
+
 ]]
 
 local ok,err = pcall(function() cudnn.C = ffi.load('libcudnn') end)
@@ -383,8 +545,8 @@ Then make sure all the files named as libcudnn.so* are placed in your library lo
 end
 
 cudnn.version = tonumber(cudnn.C.cudnnGetVersion())
-if cudnn.version < 20 then
-   error('These bindings are for version 20 or above, '
+if cudnn.version < 3000 then
+   error('These bindings are for version 3000 or above, '
          .. 'while the loaded CuDNN is version: ' .. cudnn.version
         .. ' \nAre you using an older version of CuDNN?')
 end
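
The header above also declares cudnnFindConvolutionForwardAlgorithm, which benchmarks the available forward algorithms and returns their timings and workspace needs; nothing in this commit calls it yet. A sketch of how it could be probed through these bindings (assumes a cudnn.SpatialConvolution layer whose descriptors were created by a prior forward pass; probeForwardAlgos is a hypothetical helper, not part of the commit):

    local ffi = require 'ffi'
    local function probeForwardAlgos(conv, n)
       n = n or 5
       local perf = ffi.new('cudnnConvolutionFwdAlgoPerf_t[?]', n)
       local count = ffi.new('int[1]')
       local status = cudnn.C.cudnnFindConvolutionForwardAlgorithm(
          cudnn.getHandle(),
          conv.iDesc[0], conv.weightDesc[0], conv.convDesc[0], conv.oDesc[0],
          n, count, perf)
       assert(status == 'CUDNN_STATUS_SUCCESS', 'cudnnFindConvolutionForwardAlgorithm failed')
       -- one result per algorithm, fastest first
       for i = 0, count[0] - 1 do
          print(tonumber(perf[i].algo), perf[i].time, tonumber(perf[i].memory))
       end
    end
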
diff --git a/test/benchmark.lua b/test/benchmark.lua
new file mode 100644
index 0000000..08218b9
--- /dev/null
+++ b/test/benchmark.lua
@@ -0,0 +1,72 @@
+require 'cudnn'
+require 'torch'
+
+function bench(title, nInputC, nOutputC, kH, kW, sH, sW, iH, iW, nBatch, ...)
+   local m1 = cudnn.SpatialConvolution(nInputC,nOutputC,kW,kH, sW, sH):setMode(...):fastest():cuda()
+   local i1 = torch.zeros(nBatch, nInputC, iH, iW):cuda()
+   local o1 = m1:forward(i1)
+
+   local t1 = torch.Timer()
+   local o1 = m1:forward(i1)
+   cutorch.synchronize()
+   print(title .. ': ', nInputC, nOutputC, kH, kW, iH, iW, nBatch, t1:time().real)
+end
+
+
+batchSize = 29
+from = 14
+to = 13
+kW = 9
+kH = 15
+sW = 1
+sH = 1
+outW = 10
+outH = 34
+iW = (outW-1)*sW+kW
+iH = (outH-1)*sH+kH
+
+
+print('CUDNN Version: ', tonumber(cudnn.C.cudnnGetVersion()))
+
+bench('Forward implicit gemm        ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
+      'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM',
+      'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
+      'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+
+bench('Forward implicit precomp gemm', from, to, kH, kW, sH, sW, iH, iW, batchSize,
+      'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM',
+      'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
+      'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+
+bench('Forward gemm                 ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
+      'CUDNN_CONVOLUTION_FWD_ALGO_GEMM',
+      'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
+      'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+
+-- just auto-tuned by cudnn with CUDNN_CONVOLUTION_FWD_PREFER_FASTEST mode
+bench('Forward AutoTuned            ', from, to, kH, kW, sH, sW, iH, iW, batchSize)
+
+bench('Forward FFT                  ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
+      'CUDNN_CONVOLUTION_FWD_ALGO_FFT',
+      'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
+      'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+
+
+
+-- For reference, CuDNN Convolution modes
+--[[
+   CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM         = 0,
+   CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
+   CUDNN_CONVOLUTION_FWD_ALGO_GEMM                  = 2,
+   CUDNN_CONVOLUTION_FWD_ALGO_DIRECT                = 3, // Placeholder
+   CUDNN_CONVOLUTION_FWD_ALGO_FFT                   = 4
+
+   CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0   = 0, // non-deterministic
+   CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1   = 1,
+   CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2
+
+   CUDNN_CONVOLUTION_BWD_DATA_ALGO_0   = 0, // non-deterministic
+   CUDNN_CONVOLUTION_BWD_DATA_ALGO_1   = 1,
+   CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
+
+]]--
diff --git a/test/test.lua b/test/test.lua
index dc50d94..ac8a573 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -25,7 +25,7 @@ function cudnntest.SpatialConvolution_forward_batch()
    local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
    local groundtruth = sconv:forward(input)
    cutorch.synchronize()
-   local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
+   local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest()
    gconv.weight:copy(sconv.weight)
    gconv.bias:copy(sconv.bias)
    local rescuda = gconv:forward(input)
@@ -59,7 +59,7 @@ function cudnntest.SpatialConvolution_backward_batch()
    local groundweight = sconv.gradWeight
    local groundbias = sconv.gradBias
 
-   local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
+   local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest()
    gconv.weight:copy(sconv.weight)
    gconv.bias:copy(sconv.bias)
    gconv:forward(input)
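
Both test scripts are plain Torch programs, so they can be run directly with Torch's th interpreter (assuming cudnn, cutorch, and a visible CUDA device), e.g.:

    th test/test.lua
    th test/benchmark.lua

The benchmark prints one timing line per forward algorithm; the 'Forward AutoTuned' row shows what CUDNN_CONVOLUTION_FWD_PREFER_FASTEST selects for the same problem size, for comparison against the pinned modes.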