diff options
-rw-r--r-- | Pointwise.lua | 7 | ||||
-rw-r--r-- | Pooling.lua | 21 | ||||
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | SpatialAveragePooling.lua | 4 | ||||
-rw-r--r-- | SpatialConvolution.lua | 78 | ||||
-rw-r--r-- | SpatialMaxPooling.lua | 4 | ||||
-rw-r--r-- | SpatialSoftMax.lua | 7 | ||||
-rw-r--r-- | ffi.lua | 461 | ||||
-rw-r--r-- | init.lua | 18 | ||||
-rw-r--r-- | test/test.lua | 13 |
10 files changed, 403 insertions, 214 deletions
diff --git a/Pointwise.lua b/Pointwise.lua index aa4cb34..aecfe08 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -33,11 +33,16 @@ function Pointwise:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function Pointwise:updateOutput(input) self:createIODescriptors(input) errcheck('cudnnActivationForward', cudnn.handle[cutorch.getDevice()-1], self.mode, + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -52,9 +57,11 @@ function Pointwise:updateGradInput(input, gradOutput) self:createIODescriptors(input) errcheck('cudnnActivationBackward', cudnn.handle[cutorch.getDevice()-1], self.mode, + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), self.iDesc[0], input:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end diff --git a/Pooling.lua b/Pooling.lua index 0ab7cb4..87d56bf 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -2,12 +2,14 @@ local Pooling, parent = torch.class('cudnn._Pooling', 'nn.Module') local ffi = require 'ffi' local errcheck = cudnn.errcheck -function Pooling:__init(kW, kH, dW, dH) +function Pooling:__init(kW, kH, dW, dH, padW, padH) parent.__init(self) self.kW = kW self.kH = kH self.dW = dW or kW self.dH = dH or kW + self.padW = padW or 0 + self.padH = padH or 0 self.iSize = torch.LongStorage(4):fill(0) self.ceil_mode = false end @@ -26,8 +28,11 @@ function Pooling:resetPoolDescriptors() -- create pooling descriptor self.poolDesc = ffi.new('struct cudnnPoolingStruct*[1]') errcheck('cudnnCreatePoolingDescriptor', self.poolDesc) - errcheck('cudnnSetPoolingDescriptor', self.poolDesc[0], self.mode, - self.kH, self.kW, self.dH, self.dW); + local ker = torch.IntTensor({self.kH, self.kW}) + local str = torch.IntTensor({self.dH, self.dW}) + local pad = torch.IntTensor({self.padH, self.padW}) + errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 2, + ker:data(), pad:data(), str:data()); local function destroyPoolDesc(d) errcheck('cudnnDestroyPoolingDescriptor', d[0]); end @@ -73,11 +78,17 @@ function Pooling:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function Pooling:updateOutput(input) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) - errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], + errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], + self.poolDesc[0], + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -92,9 +103,11 @@ function Pooling:updateGradInput(input, gradOutput) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) errcheck('cudnnPoolingBackward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), self.iDesc[0], input:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end @@ -16,8 +16,8 @@ Modules are API compatible their [`nn`](https://github.com/torch/nn) equivalents ```lua -- All inputs have to be 3D or 4D(batch-mode), even for ReLU, SoftMax etc. cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) -cudnn.SpatialMaxPooling(kW, kH, dW, dH) -cudnn.SpatialAveragePooling(kW, kH, dW, dH) +cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH) +cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH) cudnn.ReLU() cudnn.Tanh() cudnn.Sigmoid() diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua index ec5bcca..5da3db2 100644 --- a/SpatialAveragePooling.lua +++ b/SpatialAveragePooling.lua @@ -1,7 +1,7 @@ local SpatialAveragePooling, parent = torch.class('cudnn.SpatialAveragePooling', 'cudnn._Pooling') -function SpatialAveragePooling:__init(kW, kH, dW, dH) - parent.__init(self, kW, kH, dW, dH) +function SpatialAveragePooling:__init(kW, kH, dW, dH, padW, padH) + parent.__init(self, kW, kH, dW, dH, padW, padH) self.mode = 'CUDNN_POOLING_AVERAGE' end diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index e939592..794653c 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -17,8 +17,10 @@ function SpatialConvolution:resetWeightDescriptors() -- create filterDescriptor for weight self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]') errcheck('cudnnCreateFilterDescriptor', self.weightDesc) - errcheck('cudnnSetFilterDescriptor', self.weightDesc[0], 'CUDNN_DATA_FLOAT', - self.nOutputPlane, self.nInputPlane, self.kH, self.kW); + local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane, self.kH, self.kW}) + errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], + 'CUDNN_DATA_FLOAT', 4, + desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); end @@ -46,22 +48,40 @@ function SpatialConvolution:createIODescriptors(input) -- create conv descriptor self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]') errcheck('cudnnCreateConvolutionDescriptor', self.convDesc) - errcheck('cudnnSetConvolutionDescriptor', self.convDesc[0], self.iDesc[0], - self.weightDesc[0], self.padH, self.padW, - self.dH, self.dW, 1, 1, 'CUDNN_CROSS_CORRELATION'); + local pad = torch.IntTensor({self.padH, self.padW}) + local stride = torch.IntTensor({self.dH, self.dW}) + local upscale = torch.IntTensor({1,1}) + errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 2, pad:data(), + stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION'); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end ffi.gc(self.convDesc, destroyConvDesc) -- create output descriptor and resize output - local oSize = torch.IntTensor(4):fill(0) + local oSize = torch.IntTensor(4) local oSizeD = oSize:data() - errcheck('cudnnGetOutputTensor4dDim', self.convDesc[0], 'CUDNN_CONVOLUTION_FWD', - oSizeD, oSizeD+1, oSizeD+2, oSizeD+3) + errcheck('cudnnGetConvolutionNdForwardOutputDim', self.convDesc[0], self.iDesc[0], + self.weightDesc[0], 4, oSizeD) self.output:resize(oSize:long():storage()) -- create descriptor for output self.oDesc = cudnn.toDescriptor(self.output) + + -- create forwardAlgorithm descriptors for + local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) + errcheck('cudnnGetConvolutionForwardAlgorithm', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType) + self.algType = algType + local bufSize = torch.LongTensor(1) + errcheck('cudnnGetConvolutionForwardWorkspaceSize', + cudnn.handle[cutorch.getDevice()-1], + self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], + algType[0], bufSize:data()) + self.extraBuffer = self.extraBuffer or input.new(1) + if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end + if not batch then self.gradInput = self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3), @@ -73,17 +93,22 @@ function SpatialConvolution:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function SpatialConvolution:updateOutput(input) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], + one:data(), self.iDesc[0], input:data(), self.weightDesc[0], self.weight:data(), - self.convDesc[0], self.oDesc[0], self.output:data(), - 'CUDNN_RESULT_NO_ACCUMULATE'); - local alpha = torch.FloatTensor({1}); - errcheck('cudnnAddTensor4d', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', - alpha:data(), self.biasDesc[0], self.bias:data(), + self.convDesc[0], self.algType[0], + self.extraBuffer:data(), self.extraBuffer:nElement(), + zero:data(), + self.oDesc[0], self.output:data()); + errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', + one:data(), self.biasDesc[0], self.bias:data(), one:data(), self.oDesc[0], self.output:data()); return self.output end @@ -95,39 +120,46 @@ function SpatialConvolution:updateGradInput(input, gradOutput) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], + one:data(), self.weightDesc[0], self.weight:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], - self.iDesc[0], self.gradInput:data(), - 'CUDNN_RESULT_NO_ACCUMULATE'); + zero:data(), + self.iDesc[0], self.gradInput:data()); return self.gradInput end +local scaleT = torch.FloatTensor(1):fill(1.0) function SpatialConvolution:accGradParameters(input, gradOutput, scale) - assert(scale == nil or scale == 1) + scale = scale or 1.0 + scaleT[1] = scale assert((gradOutput:dim() == 3 or gradOutput:dim() == 4) and gradOutput:isContiguous()); self:createIODescriptors(input) if not self.weightDesc then self:resetWeightDescriptors() end -- gradBias errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], + scaleT:data(), self.oDesc[0], gradOutput:data(), - self.biasDesc[0], self.gradBias:data(), - 'CUDNN_RESULT_ACCUMULATE'); + one:data(), + self.biasDesc[0], self.gradBias:data()); -- gradWeight errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1], + scaleT:data(), self.iDesc[0], input:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], - self.weightDesc[0], self.gradWeight:data(), - 'CUDNN_RESULT_ACCUMULATE'); + one:data(), + self.weightDesc[0], self.gradWeight:data()); end + --[[ function SpatialConvolution:zeroGradParameters() -- gradWeight, gradBias to zero - local alpha = torch.FloatTensor({0}); - errcheck('cudnnSetTensor4d', self.weightDesc, self.gradWeight:data(), alpha:data()); - errcheck('cudnnSetTensor4d', self.biasDesc, self.gradBias:data(), alpha:data()); + errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + self.weightDesc, self.gradWeight:data(), zero:data()); + errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + self.biasDesc, self.gradBias:data(), zero:data()); end ]]-- diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index 108a055..e9a7b89 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -1,6 +1,6 @@ local SpatialMaxPooling, parent = torch.class('cudnn.SpatialMaxPooling', 'cudnn._Pooling') -function SpatialMaxPooling:__init(kW, kH, dW, dH) - parent.__init(self, kW, kH, dW, dH) +function SpatialMaxPooling:__init(kW, kH, dW, dH, padW, padH) + parent.__init(self, kW, kH, dW, dH, padW, padH) self.mode = 'CUDNN_POOLING_MAX' end diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 3a4106d..87af4d5 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -38,12 +38,17 @@ function SpatialSoftMax:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function SpatialSoftMax:updateOutput(input) self:createIODescriptors(input) errcheck('cudnnSoftmaxForward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -55,8 +60,10 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end @@ -18,31 +18,40 @@ typedef enum CUDNN_STATUS_LICENSE_ERROR = 10 } cudnnStatus_t; +const char * cudnnGetErrorString(cudnnStatus_t status); + typedef struct CUstream_st *cudaStream_t; cudnnStatus_t cudnnCreate(cudnnHandle_t *handle); cudnnStatus_t cudnnDestroy(cudnnHandle_t handle); -typedef struct cudnnTensor4dStruct* cudnnTensor4dDescriptor_t; +cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId); +cudnnStatus_t cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId); + +typedef struct cudnnTensorStruct* cudnnTensorDescriptor_t; typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t; typedef struct cudnnPoolingStruct* cudnnPoolingDescriptor_t; typedef struct cudnnFilterStruct* cudnnFilterDescriptor_t; + typedef enum { CUDNN_DATA_FLOAT = 0, CUDNN_DATA_DOUBLE = 1 } cudnnDataType_t; -cudnnStatus_t cudnnCreateTensor4dDescriptor( cudnnTensor4dDescriptor_t *tensorDesc ); -cudnnStatus_t cudnnSetTensor4dDescriptorEx( cudnnTensor4dDescriptor_t tensorDesc, - cudnnDataType_t dataType, // image data type - int n, // number of inputs (batch size) - int c, // number of input feature maps - int h, // height of input section - int w, // width of input section - int nStride, - int cStride, - int hStride, - int wStride - ); -cudnnStatus_t cudnnDestroyTensor4dDescriptor( cudnnTensor4dDescriptor_t tensorDesc ); + +typedef enum +{ + CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */ + CUDNN_TENSOR_NHWC = 1 /* feature maps interleaved ( cStride = 1 )*/ +} cudnnTensorFormat_t; + +cudnnStatus_t cudnnCreateTensorDescriptor( cudnnTensorDescriptor_t *tensorDesc ); +cudnnStatus_t cudnnSetTensorNdDescriptor( cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, + int nbDims, + const int dimA[], + const int strideA[] + ); +cudnnStatus_t cudnnDestroyTensorDescriptor( cudnnTensorDescriptor_t tensorDesc ); + typedef enum { CUDNN_ADD_IMAGE = 0, @@ -52,19 +61,37 @@ typedef enum CUDNN_ADD_SAME_C = 2, CUDNN_ADD_FULL_TENSOR = 3 } cudnnAddMode_t; -cudnnStatus_t cudnnAddTensor4d( cudnnHandle_t handle, - cudnnAddMode_t mode, - const void *alpha, - cudnnTensor4dDescriptor_t biasDesc, - const void *biasData, - cudnnTensor4dDescriptor_t srcDestDesc, - void *srcDestData +/* Tensor Bias addition : srcDest = alpha * bias + beta * srcDestDesc */ +cudnnStatus_t cudnnAddTensor( cudnnHandle_t handle, + cudnnAddMode_t mode, + const void *alpha, + const cudnnTensorDescriptor_t biasDesc, + const void *biasData, + const void *beta, + cudnnTensorDescriptor_t srcDestDesc, + void *srcDestData ); + +/* Set all data points of a tensor to a given value : srcDest = value */ +cudnnStatus_t cudnnSetTensor( cudnnHandle_t handle, + const cudnnTensorDescriptor_t srcDestDesc, + void *srcDestData, + const void *value + ); + +/* Set all data points of a tensor to a given value : srcDest = alpha * srcDest */ +cudnnStatus_t cudnnScaleTensor( cudnnHandle_t handle, + const cudnnTensorDescriptor_t srcDestDesc, + void *srcDestData, + const void *alpha + ); + typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t; + typedef enum { CUDNN_CONVOLUTION_FWD = 0, /* Tensor Convolution function */ @@ -72,178 +99,282 @@ typedef enum CUDNN_CONVOLUTION_DATA_GRAD = 2 /* Data Gradient update function */ } cudnnConvolutionPath_t; cudnnStatus_t cudnnCreateFilterDescriptor( cudnnFilterDescriptor_t *filterDesc ); -cudnnStatus_t cudnnSetFilterDescriptor( cudnnFilterDescriptor_t filterDesc, +cudnnStatus_t cudnnSetFilterNdDescriptor( cudnnFilterDescriptor_t filterDesc, cudnnDataType_t dataType, // image data type - int k, // number of output feature maps - int c, // number of input feature maps - int h, // height of each input filter - int w // width of each input fitler - ); -cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnFilterDescriptor_t filterDesc ); + int nbDims, + const int filterDimA[] + ); + +cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnFilterDescriptor_t filterDesc ); + cudnnStatus_t cudnnCreateConvolutionDescriptor( cudnnConvolutionDescriptor_t *convDesc ); -cudnnStatus_t cudnnSetConvolutionDescriptor( cudnnConvolutionDescriptor_t convDesc, - cudnnTensor4dDescriptor_t inputTensorDesc, - cudnnFilterDescriptor_t filterDesc, - int pad_h, // zero-padding height - int pad_w, // zero-padding width - int u, // vertical filter stride - int v, // horizontal filter stride - int upscalex, // upscale the input in x-direction - int upscaley, // upscale the input in y-direction - cudnnConvolutionMode_t mode - ); -cudnnStatus_t cudnnGetOutputTensor4dDim( const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionPath_t path, - int *n, - int *c, - int *h, - int *w - ); -cudnnStatus_t cudnnDestroyConvolutionDescriptor( cudnnConvolutionDescriptor_t convDesc ); +cudnnStatus_t cudnnSetConvolutionNdDescriptor( cudnnConvolutionDescriptor_t convDesc, + int arrayLength, /* nbDims-2 size */ + const int padA[], + const int filterStrideA[], + const int upscaleA[], + cudnnConvolutionMode_t mode + ); + +cudnnStatus_t cudnnGetConvolutionNdDescriptor( const cudnnConvolutionDescriptor_t convDesc, + int arrayLengthRequested, + int *arrayLength, + int padA[], + int strideA[], + int upscaleA[], + cudnnConvolutionMode_t *mode + ); + + +/* Helper function to return the dimensions of the output tensor given a convolution descriptor */ +cudnnStatus_t cudnnGetConvolutionNdForwardOutputDim( const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, + int nbDims, + int tensorOuputDimA[] + ); + +/* Destroy an instance of convolution descriptor */ +cudnnStatus_t cudnnDestroyConvolutionDescriptor( cudnnConvolutionDescriptor_t convDesc ); + +typedef enum +{ + CUDNN_CONVOLUTION_FWD_NO_WORKSPACE = 0, + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST = 1, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT = 2 +} cudnnConvolutionFwdPreference_t; + typedef enum { - CUDNN_RESULT_ACCUMULATE = 0, /* Evaluate O += I * F */ - CUDNN_RESULT_NO_ACCUMULATE = 1 /* Evaluate O = I * F */ -} cudnnAccumulateResult_t; -cudnnStatus_t cudnnConvolutionForward( cudnnHandle_t handle, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnFilterDescriptor_t filterDesc, - const void *filterData, - cudnnConvolutionDescriptor_t convDesc, - cudnnTensor4dDescriptor_t destDesc, - void *destData, - cudnnAccumulateResult_t accumulate + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1, + CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3 +} cudnnConvolutionFwdAlgo_t; + +cudnnStatus_t cudnnGetConvolutionForwardAlgorithm( cudnnHandle_t handle, + const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, + cudnnConvolutionFwdPreference_t preference, + size_t memoryLimitInbytes, + cudnnConvolutionFwdAlgo_t *algo + ); + +/* + * convolution algorithm (which requires potentially some workspace) + */ + + /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/ +cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize( cudnnHandle_t handle, + const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, + cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes + ); + + +/* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */ + +/* Function to perform the forward multiconvolution */ +cudnnStatus_t cudnnConvolutionForward( cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnFilterDescriptor_t filterDesc, + const void *filterData, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionFwdAlgo_t algo, + void *workSpace, + size_t workSpaceSizeInBytes, + const void *beta, + const cudnnTensorDescriptor_t destDesc, + void *destData ); -cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t destDesc, - void *destData, - cudnnAccumulateResult_t accumulate + +/* Functions to perform the backward multiconvolution */ +cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const void *beta, + const cudnnTensorDescriptor_t destDesc, + void *destData ); -cudnnStatus_t cudnnConvolutionBackwardFilter( cudnnHandle_t handle, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t diffDesc, - const void *diffData, - cudnnConvolutionDescriptor_t convDesc, - cudnnFilterDescriptor_t gradDesc, - void *gradData, - cudnnAccumulateResult_t accumulate + + + +cudnnStatus_t cudnnConvolutionBackwardFilter( cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnTensorDescriptor_t diffDesc, + const void *diffData, + const cudnnConvolutionDescriptor_t convDesc, + const void *beta, + const cudnnFilterDescriptor_t gradDesc, + void *gradData ); -cudnnStatus_t cudnnConvolutionBackwardData( cudnnHandle_t handle, - cudnnFilterDescriptor_t filterDesc, - const void *filterData, - cudnnTensor4dDescriptor_t diffDesc, - const void *diffData, - cudnnConvolutionDescriptor_t convDesc, - cudnnTensor4dDescriptor_t gradDesc, - void *gradData, - cudnnAccumulateResult_t accumulate - ); + + +cudnnStatus_t cudnnConvolutionBackwardData( cudnnHandle_t handle, + const void *alpha, + const cudnnFilterDescriptor_t filterDesc, + const void *filterData, + const cudnnTensorDescriptor_t diffDesc, + const void *diffData, + const cudnnConvolutionDescriptor_t convDesc, + const void *beta, + const cudnnTensorDescriptor_t gradDesc, + void *gradData + ); + + +/* + * softmax algorithm + */ +typedef enum +{ + CUDNN_SOFTMAX_FAST = 0, /* straightforward implementation */ + CUDNN_SOFTMAX_ACCURATE = 1 /* subtract max from every point to avoid overflow */ +} cudnnSoftmaxAlgorithm_t; + +typedef enum +{ + CUDNN_SOFTMAX_MODE_INSTANCE = 0, /* compute the softmax over all C, H, W for each N */ + CUDNN_SOFTMAX_MODE_CHANNEL = 1 /* compute the softmax over all C for each H, W, N */ +} cudnnSoftmaxMode_t; + +/* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */ + +/* Function to perform forward softmax */ +cudnnStatus_t cudnnSoftmaxForward( cudnnHandle_t handle, + cudnnSoftmaxAlgorithm_t algorithm, + cudnnSoftmaxMode_t mode, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const void *beta, + const cudnnTensorDescriptor_t destDesc, + void *destData + ); + +/* Function to perform backward softmax */ +cudnnStatus_t cudnnSoftmaxBackward( cudnnHandle_t handle, + cudnnSoftmaxAlgorithm_t algorithm, + cudnnSoftmaxMode_t mode, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnTensorDescriptor_t srcDiffDesc, + const void *srcDiffData, + const void *beta, + const cudnnTensorDescriptor_t destDiffDesc, + void *destDiffData + ); + + + typedef enum { CUDNN_POOLING_MAX = 0, CUDNN_POOLING_AVERAGE = 1 } cudnnPoolingMode_t; -cudnnStatus_t cudnnCreatePoolingDescriptor( cudnnPoolingDescriptor_t *poolingDesc); -cudnnStatus_t cudnnSetPoolingDescriptor( cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - int windowHeight, - int windowWidth, - int verticalStride, - int horizontalStride + +/* Create an instance of pooling descriptor */ +cudnnStatus_t cudnnCreatePoolingDescriptor( cudnnPoolingDescriptor_t *poolingDesc); +cudnnStatus_t cudnnSetPoolingNdDescriptor( cudnnPoolingDescriptor_t poolingDesc, + const cudnnPoolingMode_t mode, + int nbDims, + const int windowDimA[], + const int paddingA[], + const int strideA[] ); -cudnnStatus_t cudnnGetPoolingDescriptor( const cudnnPoolingDescriptor_t poolingDesc, + +cudnnStatus_t cudnnGetPoolingNdDescriptor( const cudnnPoolingDescriptor_t poolingDesc, + const int nbDimsRequested, cudnnPoolingMode_t *mode, - int *windowHeight, - int *windowWidth, - int *verticalStride, - int *horizontalStride - ); -cudnnStatus_t cudnnDestroyPoolingDescriptor( cudnnPoolingDescriptor_t poolingDesc ); -cudnnStatus_t cudnnPoolingForward( cudnnHandle_t handle, - cudnnPoolingDescriptor_t poolingDesc, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t destDesc, - void *destData + int *nbDims, + int windowDimA[], + int paddingA[], + int strideA[] + ); + +cudnnStatus_t cudnnGetPoolingNdForwardOutputDim( const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int nbDims, + int outputTensorDimA[]); +/* Destroy an instance of pooling descriptor */ +cudnnStatus_t cudnnDestroyPoolingDescriptor( cudnnPoolingDescriptor_t poolingDesc ); +/* Pooling functions: All of the form "output = alpha * Op(inputs) + beta * output" */ + +/* Function to perform forward pooling */ +cudnnStatus_t cudnnPoolingForward( cudnnHandle_t handle, + const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const void *beta, + const cudnnTensorDescriptor_t destDesc, + void *destData ); -cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle, - cudnnPoolingDescriptor_t poolingDesc, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t srcDiffDesc, - const void *srcDiffData, - cudnnTensor4dDescriptor_t destDesc, - const void *destData, - cudnnTensor4dDescriptor_t destDiffDesc, - void *destDiffData + +/* Function to perform backward pooling */ +cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle, + const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnTensorDescriptor_t srcDiffDesc, + const void *srcDiffData, + const cudnnTensorDescriptor_t destDesc, + const void *destData, + const void *beta, + const cudnnTensorDescriptor_t destDiffDesc, + void *destDiffData ); + typedef enum { CUDNN_ACTIVATION_SIGMOID = 0, CUDNN_ACTIVATION_RELU = 1, CUDNN_ACTIVATION_TANH = 2 } cudnnActivationMode_t; -cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle, - cudnnActivationMode_t mode, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t destDesc, - void *destData - ); -cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle, - cudnnActivationMode_t mode, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t srcDiffDesc, - const void *srcDiffData, - cudnnTensor4dDescriptor_t destDesc, - const void *destData, - cudnnTensor4dDescriptor_t destDiffDesc, - void *destDiffData - ); - - -typedef enum -{ - CUDNN_SOFTMAX_FAST = 0, CUDNN_SOFTMAX_ACCURATE = 1 -} cudnnSoftmaxAlgorithm_t; - -typedef enum -{ - CUDNN_SOFTMAX_MODE_INSTANCE = 0, CUDNN_SOFTMAX_MODE_CHANNEL = 1 -} cudnnSoftmaxMode_t; - - -cudnnStatus_t cudnnSoftmaxForward( cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algorithm, - cudnnSoftmaxMode_t mode, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t destDesc, - void *destData - ); - -cudnnStatus_t cudnnSoftmaxBackward( cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algorithm, - cudnnSoftmaxMode_t mode, - cudnnTensor4dDescriptor_t srcDesc, - const void *srcData, - cudnnTensor4dDescriptor_t srcDiffDesc, - const void *srcDiffData, - cudnnTensor4dDescriptor_t destDiffDesc, - void *destDiffData - ); +/* Function to perform forward activation */ +cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle, + cudnnActivationMode_t mode, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const void *beta, + const cudnnTensorDescriptor_t destDesc, + void *destData + ); +/* Function to perform backward activation */ +cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle, + cudnnActivationMode_t mode, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnTensorDescriptor_t srcDiffDesc, + const void *srcDiffData, + const cudnnTensorDescriptor_t destDesc, + const void *destData, + const void *beta, + const cudnnTensorDescriptor_t destDiffDesc, + void *destDiffData + ); ]] local ok -ok = pcall(function() cudnn.C = ffi.load('libcudnn') end) +ok,err = pcall(function() cudnn.C = ffi.load('libcudnn') end) if not ok then + print(err) error([['libcudnn.so not found in library path. Please install CuDNN from https://developer.nvidia.com/cuDNN Then make sure all the files named as libcudnn.so* are placed in your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH) @@ -8,7 +8,8 @@ local ffi = require 'ffi' local errcheck = function(f, ...) local status = C[f](...) if status ~= 'CUDNN_STATUS_SUCCESS' then - error("Error in CuDNN. Status Code: " .. tonumber(status)) + local str = ffi.string(C.cudnnGetErrorString(status)) + error('Error in CuDNN: ' .. str) end end cudnn.errcheck = errcheck @@ -34,21 +35,20 @@ end ffi.gc(cudnn.handle, destroy) function cudnn.toDescriptor(t) - if t:dim() == 3 then t = t:view(1, t:size(1), t:size(2), t:size(3)) end - assert(t:dim() == 4, 'Expecting 4D input, but got: ' .. t:dim()); assert(torch.typename(t) == 'torch.CudaTensor') - local descriptor = ffi.new('struct cudnnTensor4dStruct*[1]') + local descriptor = ffi.new('struct cudnnTensorStruct*[1]') -- create descriptor - errcheck('cudnnCreateTensor4dDescriptor', descriptor) + errcheck('cudnnCreateTensorDescriptor', descriptor) -- set gc hook local function destroy(d) - errcheck('cudnnDestroyTensor4dDescriptor', d[0]); + errcheck('cudnnDestroyTensorDescriptor', d[0]); end ffi.gc(descriptor, destroy) -- set descriptor - errcheck('cudnnSetTensor4dDescriptorEx', descriptor[0], 'CUDNN_DATA_FLOAT', - t:size(1), t:size(2), t:size(3), t:size(4), - t:stride(1), t:stride(2), t:stride(3), t:stride(4)) + local size = torch.LongTensor(t:size()):int() + local stride = torch.LongTensor(t:stride()):int() + errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_FLOAT', + t:dim(), size:data(), stride:data()) return descriptor end diff --git a/test/test.lua b/test/test.lua index 855ea71..435ba5f 100644 --- a/test/test.lua +++ b/test/test.lua @@ -9,7 +9,6 @@ local nloop = 1 local times = {} local mytester - function cudnntest.SpatialConvolution_forward_batch() local bs = math.random(1,32) local from = math.random(1,32) @@ -22,7 +21,6 @@ function cudnntest.SpatialConvolution_forward_batch() local outj = math.random(1,64) local ini = (outi-1)*si+ki local inj = (outj-1)*sj+kj - local input = torch.randn(bs,from,inj,ini):cuda() local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() local groundtruth = sconv:forward(input) @@ -43,19 +41,20 @@ function cudnntest.SpatialConvolution_backward_batch() local to = math.random(1,64) local ki = math.random(3,15) local kj = math.random(3,15) - local si = 1 -- not supported by CPU version yet - local sj = si + local si = math.random(1,ki-1) + local sj = math.random(1,kj-1) local outi = math.random(1,64) local outj = math.random(1,64) local ini = (outi-1)*si+ki local inj = (outj-1)*sj+kj + local scale = math.random() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,to,outj,outi):cuda() local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() sconv:forward(input) sconv:zeroGradParameters() - local groundgrad = sconv:backward(input, gradOutput) + local groundgrad = sconv:backward(input, gradOutput, scale) cutorch.synchronize() local groundweight = sconv.gradWeight local groundbias = sconv.gradBias @@ -71,7 +70,7 @@ function cudnntest.SpatialConvolution_backward_batch() gconv:forward(input) gconv:zeroGradParameters() - local rescuda = gconv:backward(input, gradOutput) + local rescuda = gconv:backward(input, gradOutput, scale) cutorch.synchronize() local weightcuda = gconv.gradWeight local biascuda = gconv.gradBias @@ -514,7 +513,7 @@ mytester:add(cudnntest) for i=1,cutorch.getDeviceCount() do print('Running test on device: ' .. i) cutorch.setDevice(i) - mytester:run(tests) + mytester:run() end os.execute('rm -f modelTemp.t7') |