diff options
author | Boris Fomitchev <bfomitchev@nvidia.com> | 2017-03-12 07:26:15 +0300 |
---|---|---|
committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2017-03-12 07:26:15 +0300 |
commit | 412e492caaa33845975140630c9753f40130d9bd (patch) | |
tree | 6b64577a12f5cb815b84be322a9715da1c56e0a0 | |
parent | 744f79b10b20a2c7561ae352f923f2b331583fcd (diff) | |
parent | f2a1e328cf18290df9ca4ea9609d843443f5e571 (diff) |
Merge branch '17.03-devel' into 17.04-devel
Conflicts:
CMakeLists.txt
RNN.lua
-rw-r--r-- | CMakeLists.txt | 12 | ||||
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | RNN.lua | 49 | ||||
-rw-r--r-- | SpatialConvolution.lua | 2 | ||||
-rw-r--r-- | SpatialDilatedConvolution.lua | 73 | ||||
-rw-r--r-- | VolumetricDilatedConvolution.lua | 62 | ||||
-rw-r--r-- | VolumetricFullConvolution.lua | 8 | ||||
-rw-r--r-- | convert.lua | 2 | ||||
-rw-r--r-- | ffi.lua | 575 | ||||
-rw-r--r-- | init.lua | 6 | ||||
-rw-r--r-- | test/test.lua | 86 | ||||
-rw-r--r-- | test/test_rnn.lua | 55 |
12 files changed, 618 insertions, 317 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cf64b8..9e4648b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,18 +9,18 @@ IF(LUAROCKS_PREFIX) STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}") MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}") ENDIF() + FIND_PACKAGE(Torch REQUIRED) -FIND_PACKAGE(CUDA 7.0 REQUIRED) +FIND_PACKAGE(CUDA 8.0 REQUIRED) -FIND_PACKAGE(CUDNN 5 EXACT QUIET) +FIND_PACKAGE(CUDNN 6 EXACT QUIET) IF(NOT CUDNN_FOUND) - CUDNN_INSTALL(5.1 "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "${Torch_INSTALL_BIN}") - FIND_PACKAGE(CUDNN 5 EXACT REQUIRED) + CUDNN_INSTALL(6.0-rc "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "") + FIND_PACKAGE(CUDNN 6 EXACT REQUIRED) + ENDIF() FILE(GLOB luasrc *.lua) SET(src "") ADD_TORCH_PACKAGE(cudnn "${src}" "${luasrc}" "NVIDIA CuDNN Bindings") - - @@ -122,6 +122,11 @@ nn.Sequential { (2): nn.ReLU } ``` +### New for cudnn V6 +Persistent mode for RNNs is enabled. RNNs can be run in persistent mode on Pascal family GPUs, and they are expected to be faster +for small batch size. They can be enabled by calling :setPersist(true) on any RNN module (RNN, LSTM, or GRU). +Dilated convolutions have been enabled. + ### Older versions For version CuDNN R1, checkout the branch **R1** @@ -37,12 +37,27 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem self.cellOutput = torch.CudaTensor() self.gradHiddenInput = torch.CudaTensor() self.gradCellInput = torch.CudaTensor() + self.persistent = false -- set true to use persistent RNNs self:training() self:reset() end function RNN:setSync(sync) - self.sync = sync + if sync==nil then + self.sync = true + else + self.sync = sync + end +end + +function RNN:setPersist(persist) + --sets persistent mode for any argument except persist=false + if persist==nil then + self.persistent = true + else + self.persistent = persist + end + self:resetRNNDescriptor() end function RNN:reset(stdv) @@ -122,7 +137,14 @@ function RNN:resetRNNDescriptor() if not self.rnnDesc then self.rnnDesc = self:createRNNDescriptors(1) end - errcheck('cudnnSetRNNDescriptor', + local algo + if self.persistent then + algo = 'CUDNN_RNN_ALGO_PERSIST_STATIC' + else + algo = 'CUDNN_RNN_ALGO_STANDARD' + end + local status = cudnn.call('cudnnSetRNNDescriptor_v6', + cudnn.getHandle(), self.rnnDesc[0], self.hiddenSize, self.numLayers, @@ -130,7 +152,30 @@ function RNN:resetRNNDescriptor() self.inputMode, self.bidirectional, self.mode, + algo, self.datatype) + if status ~= ffi.C.CUDNN_STATUS_SUCCESS then + if algo == 'CUDNN_RNN_ALGO_PERSIST_STATIC' then + --try using standard algo + print("Warning: persistent RNN is not supported for this configuration. Switching to standard") + algo = 'CUDNN_RNN_ALGO_STANDARD' + self.persistent = false + errcheck('cudnnSetRNNDescriptor_v6', + cudnn.getHandle(), + self.rnnDesc[0], + self.hiddenSize, + self.numLayers, + self.dropoutDesc[0], + self.inputMode, + self.bidirectional, + self.mode, + algo, + self.datatype) + else + local str = ffi.string(C.cudnnGetErrorString(status)) + error('Error in CuDNN: ' .. str .. ' (cudnnSetRNNDescriptor_v6)') + end + end end function RNN:resetWeightDescriptor() diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 830a7e6..7f9dd23 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -129,7 +129,7 @@ function SpatialConvolution:createIODescriptors(input) self.convDescData = { padA = self.pad, filterStrideA = self.stride, - upscaleA = {1,1}, + dilationA = {1,1}, dataType = cudnn.configmap(torch.type(self.weight)) } diff --git a/SpatialDilatedConvolution.lua b/SpatialDilatedConvolution.lua new file mode 100644 index 0000000..5e39c03 --- /dev/null +++ b/SpatialDilatedConvolution.lua @@ -0,0 +1,73 @@ +local SpatialDilatedConvolution, parent = + torch.class('cudnn.SpatialDilatedConvolution', 'cudnn.SpatialConvolution') +local ffi = require 'ffi' +local find = require 'cudnn.find' + +function SpatialDilatedConvolution:__init(nInputPlane, nOutputPlane, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, groups) + local delayedReset = self.reset + self.reset = function() end + parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, groups)--, dilationW, dilationH) + self.dilationW = dilationW + self.dilationH = dilationH +end + +function SpatialDilatedConvolution:createIODescriptors(input) + local batch = true + if input:dim() == 3 then + input = input:view(1, input:size(1), input:size(2), input:size(3)) + batch = false + end + if parent.checkInputChanged(self, input) then + -- create input descriptor + local input_slice = input:narrow(2,1,self.nInputPlane/self.groups) + self.iDesc = cudnn.toDescriptor(input_slice) + -- create conv descriptor + self.padH, self.padW = self.padH or 0, self.padW or 0 + -- those needed to calculate hash + self.pad = {self.padH, self.padW} + self.stride = {self.dH, self.dW} + local t_dataType = cudnn.configmap(torch.type(self.weight)) + --fallback to fp32 math if half type, fp16 dilated convs not fully implmented in cuDNN 6.0.2 + if( t_dataType == 'CUDNN_DATA_HALF') then t_dataType = 'CUDNN_DATA_FLOAT' end + self.convDescData = { + padA = self.pad, + filterStrideA = self.stride, + dilationA = {self.dilationH, self.dilationW}, + dataType = t_dataType + } + self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData) + + -- get output shape, resize output + local oSize = torch.IntTensor(4) + cudnn.errcheck('cudnnGetConvolutionNdForwardOutputDim', + self.convDesc[0], self.iDesc[0], + self.weightDesc[0], 4, oSize:data()) + oSize[2] = oSize[2] * self.groups + self.output:resize(oSize:long():storage()) + self.oSize = self.output:size() + + local output_slice = self.output:narrow(2,1,self.nOutputPlane/self.groups) + -- create descriptor for output + self.oDesc = cudnn.toDescriptor(output_slice) + self.oDescForBias = cudnn.toDescriptor(self.output) + + find:prepare(self, input_slice, output_slice) + + -- create offsets for groups + local iH, iW = input:size(3), input:size(4) + local kH, kW = self.kH, self.kW + local oH, oW = oSize[3], oSize[4] + self.input_offset = self.nInputPlane / self.groups * iH * iW + self.output_offset = self.nOutputPlane / self.groups * oH * oW + self.weight_offset = self.nInputPlane / self.groups * self.nOutputPlane / self.groups * kH * kW + + if not batch then + self.output = self.output:view(self.output:size(2), + self.output:size(3), + self.output:size(4)) + end + + end + return self +end diff --git a/VolumetricDilatedConvolution.lua b/VolumetricDilatedConvolution.lua new file mode 100644 index 0000000..652b646 --- /dev/null +++ b/VolumetricDilatedConvolution.lua @@ -0,0 +1,62 @@ +local VolumetricDilatedConvolution, parent + = torch.class('cudnn.VolumetricDilatedConvolution', 'cudnn.VolumetricConvolution') +local ffi = require 'ffi' +local find = require 'cudnn.find' + +local Convolution = cudnn.SpatialConvolution + +function VolumetricDilatedConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH) + parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH) + + self.dilationT = dilationT or 1 + self.dilationW = dilationW or 1 + self.dilationH = dilationH or 1 +end + + +function VolumetricDilatedConvolution:createIODescriptors(input) + if input:dim() == 4 then + input = input:view(1, input:size(1), input:size(2), + input:size(3), input:size(4)) + batch = false + end + if Convolution.checkInputChanged(self, input) then + -- create input descriptor + self.iDesc = cudnn.toDescriptor(input) + -- create conv descriptor + self.pad = {self.padT, self.padH, self.padW} + self.stride = {self.dT, self.dH, self.dW} + self.dilation = {self.dilationT, self.dilationH, self.dilationW} + + local mathtype=cudnn.configmap(torch.type(self.weight)) + -- 3D convolutions do not work in 16 bits + if mathtype == 'CUDNN_DATA_HALF' then + mathtype = 'CUDNN_DATA_FLOAT' + end + self.convDescData = { + padA = self.pad, + filterStrideA = self.stride, + dilationA = self.dilation, + dataType = mathtype + } + self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData) + + local oSize = torch.IntTensor(5) + cudnn.errcheck('cudnnGetConvolutionNdForwardOutputDim', + self.convDesc[0], self.iDesc[0], + self.weightDesc[0], 5, oSize:data()) + self.output:resize(oSize:long():storage()) + -- create descriptor for output + self.oDesc = cudnn.toDescriptor(self.output) + self.oDescForBias = cudnn.toDescriptor( + self.output:view(self.output:size(1), + self.output:size(2), + self.output:size(3)*self.output:size(4), + self.output:size(5))) + self.input_offset = 0 + self.output_offset = 0 + self.weight_offset = 0 + find:prepare(self, input, self.output) + + end +end diff --git a/VolumetricFullConvolution.lua b/VolumetricFullConvolution.lua index a662429..a51f09b 100644 --- a/VolumetricFullConvolution.lua +++ b/VolumetricFullConvolution.lua @@ -43,10 +43,16 @@ function VolumetricFullConvolution:createIODescriptors(input) local input_slice = input[{{},{1,self.nInputPlane},{},{}}] self.iDesc = cudnn.toDescriptor(input_slice) -- create conv descriptor + local mathtype = cudnn.configmap(torch.type(self.weight)) + -- 3D convolutions do not work in 16 bits + if mathtype == 'CUDNN_DATA_HALF' then + mathtype = 'CUDNN_DATA_FLOAT' + end + self.pad = {self.padT, self.padH, self.padW} self.stride = {self.dT, self.dH, self.dW} self.convDescData = { padA = self.pad, filterStrideA = self.stride, - dataType = cudnn.configmap(torch.type(self.weight))} + dataType = mathtype } self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData) -- get output shape, resize output diff --git a/convert.lua b/convert.lua index 4075122..7b4de1c 100644 --- a/convert.lua +++ b/convert.lua @@ -7,6 +7,7 @@ local layer_list = { 'SpatialFullConvolution', 'SpatialMaxPooling', 'SpatialAveragePooling', + 'SpatialDilatedConvolution', 'ReLU', 'Tanh', 'Sigmoid', @@ -17,6 +18,7 @@ local layer_list = { 'VolumetricFullConvolution', 'VolumetricMaxPooling', 'VolumetricAveragePooling', + 'VolumetricDilatedConvolution', } -- goes over a given net and converts all layers to dst backend @@ -3,11 +3,17 @@ local ffi = require 'ffi' ffi.cdef[[ +typedef enum +{ + MAJOR_VERSION, + MINOR_VERSION, + PATCH_LEVEL +} libraryPropertyType; typedef enum { - CUDNN_MAJOR = 5, + CUDNN_MAJOR = 6, CUDNN_MINOR = 0, - CUDNN_PATCHLEVEL = 4, + CUDNN_PATCHLEVEL = 2, CUDNN_VERSION = (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL) } cudnnVerFakeEnum; @@ -21,22 +27,25 @@ size_t cudnnGetVersion(void); */ typedef enum { - CUDNN_STATUS_SUCCESS = 0, - CUDNN_STATUS_NOT_INITIALIZED = 1, - CUDNN_STATUS_ALLOC_FAILED = 2, - CUDNN_STATUS_BAD_PARAM = 3, - CUDNN_STATUS_INTERNAL_ERROR = 4, - CUDNN_STATUS_INVALID_VALUE = 5, - CUDNN_STATUS_ARCH_MISMATCH = 6, - CUDNN_STATUS_MAPPING_ERROR = 7, - CUDNN_STATUS_EXECUTION_FAILED = 8, - CUDNN_STATUS_NOT_SUPPORTED = 9, - CUDNN_STATUS_LICENSE_ERROR = 10 + CUDNN_STATUS_SUCCESS = 0, + CUDNN_STATUS_NOT_INITIALIZED = 1, + CUDNN_STATUS_ALLOC_FAILED = 2, + CUDNN_STATUS_BAD_PARAM = 3, + CUDNN_STATUS_INTERNAL_ERROR = 4, + CUDNN_STATUS_INVALID_VALUE = 5, + CUDNN_STATUS_ARCH_MISMATCH = 6, + CUDNN_STATUS_MAPPING_ERROR = 7, + CUDNN_STATUS_EXECUTION_FAILED = 8, + CUDNN_STATUS_NOT_SUPPORTED = 9, + CUDNN_STATUS_LICENSE_ERROR = 10, + CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING = 11 } cudnnStatus_t; /* human-readable error messages*/ const char * cudnnGetErrorString(cudnnStatus_t status); +cudnnStatus_t cudnnGetProperty(libraryPropertyType type, int *value); + cudnnStatus_t cudnnCreate (cudnnHandle_t *handle); cudnnStatus_t cudnnDestroy (cudnnHandle_t handle); cudnnStatus_t cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId); @@ -52,6 +61,7 @@ typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t; typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t; typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t; typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t; +typedef struct cudnnReduceTensorStruct* cudnnReduceTensorDescriptor_t; /* * CUDNN data type */ @@ -60,16 +70,27 @@ typedef enum CUDNN_DATA_FLOAT = 0, CUDNN_DATA_DOUBLE = 1, CUDNN_DATA_HALF = 2, + CUDNN_DATA_INT8 = 3, + CUDNN_DATA_INT32 = 4, + CUDNN_DATA_INT8x4 = 5 } cudnnDataType_t; /* * CUDNN propagate Nan */ -typedef enum{ +typedef enum { CUDNN_NOT_PROPAGATE_NAN = 0, CUDNN_PROPAGATE_NAN = 1, } cudnnNanPropagation_t; +/* + * CUDNN Determinism + */ +typedef enum { + CUDNN_NON_DETERMINISTIC = 0, + CUDNN_DETERMINISTIC = 1, +} cudnnDeterminism_t; + /* Maximum supported number of tensor dimensions */ typedef enum { CUDNN_DIM_MAX = 8 } cudnnDimMaxFakeEnum; @@ -79,8 +100,9 @@ cudnnStatus_t cudnnCreateTensorDescriptor( typedef enum { - CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */ - CUDNN_TENSOR_NHWC = 1 /* feature maps interleaved ( cStride = 1 )*/ + CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */ + CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/ + CUDNN_TENSOR_NCHW_VECT_C = 2 /* each image point is vector of element of C : the length of the vector is carried by the data type*/ } cudnnTensorFormat_t; cudnnStatus_t cudnnSetTensor4dDescriptor( @@ -124,7 +146,14 @@ cudnnStatus_t cudnnSetTensorNdDescriptor( const int dimA[], const int strideA[] ); -cudnnStatus_t cudnnGetTensorNdDescriptor( +cudnnStatus_t cudnnSetTensorNdDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, + cudnnTensorFormat_t format, + cudnnDataType_t dataType, + int nbDims, + const int dimA[] ); + +cudnnStatus_t cudnnGetTensorNdDescriptor( const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, cudnnDataType_t *dataType, @@ -132,6 +161,11 @@ cudnnStatus_t cudnnGetTensorNdDescriptor( int dimA[], int strideA[] ); + +cudnnStatus_t cudnnGetTensorSizeInBytes( + const cudnnTensorDescriptor_t tensorDesc, + size_t *size); + /* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride 1)Example of all images in row major order one batch of features after the other (with an optional padding on row) @@ -186,10 +220,11 @@ cudnnStatus_t cudnnAddTensor( */ typedef enum { - CUDNN_OP_TENSOR_ADD = 0, - CUDNN_OP_TENSOR_MUL = 1, - CUDNN_OP_TENSOR_MIN = 2, - CUDNN_OP_TENSOR_MAX = 3, + CUDNN_OP_TENSOR_ADD = 0, + CUDNN_OP_TENSOR_MUL = 1, + CUDNN_OP_TENSOR_MIN = 2, + CUDNN_OP_TENSOR_MAX = 3, + CUDNN_OP_TENSOR_SQRT = 4, } cudnnOpTensorOp_t; cudnnStatus_t cudnnCreateOpTensorDescriptor( @@ -224,6 +259,97 @@ cudnnStatus_t cudnnOpTensor( const cudnnTensorDescriptor_t cDesc, void *C ); +/* +* CUDNN ReduceTensor op type +*/ +typedef enum +{ + CUDNN_REDUCE_TENSOR_ADD = 0, + CUDNN_REDUCE_TENSOR_MUL = 1, + CUDNN_REDUCE_TENSOR_MIN = 2, + CUDNN_REDUCE_TENSOR_MAX = 3, + CUDNN_REDUCE_TENSOR_AMAX = 4, + CUDNN_REDUCE_TENSOR_AVG = 5, + CUDNN_REDUCE_TENSOR_NORM1 = 6, + CUDNN_REDUCE_TENSOR_NORM2 = 7, +} cudnnReduceTensorOp_t; + +/* +* CUDNN ReduceTensor indices type +*/ +typedef enum +{ + CUDNN_REDUCE_TENSOR_NO_INDICES = 0, + CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1, +} cudnnReduceTensorIndices_t; + +/* +* CUDNN tensor indices type size (all unsigned) +* Currently not supported, default is 32 bit unsigned. +*/ +typedef enum +{ + CUDNN_32BIT_INDICES = 0, + CUDNN_64BIT_INDICES = 1, + CUDNN_16BIT_INDICES = 2, + CUDNN_8BIT_INDICES = 3, +} cudnnIndicesType_t; + +cudnnStatus_t cudnnCreateReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t *reduceTensorDesc ); + +cudnnStatus_t cudnnSetReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, + cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType ); + +cudnnStatus_t cudnnGetReduceTensorDescriptor( + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType ); + +cudnnStatus_t cudnnDestroyReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc ); + + /* Helper function to return the minimum size of the index space to be passed to the reduction given the input and output tensors */ +cudnnStatus_t cudnnGetReductionIndicesSize( + cudnnHandle_t handle, + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, + const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes ); + + /* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output tensors */ +cudnnStatus_t cudnnGetReductionWorkspaceSize( + cudnnHandle_t handle, + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, + const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes ); + +/* Tensor operation : C = reduce op( alpha * A ) + beta * C */ +/* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */ +/* The indices space is ignored for reduce ops other than min or max. */ +cudnnStatus_t cudnnReduceTensor( + cudnnHandle_t handle, + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, + size_t indicesSizeInBytes, + void *workspace, + size_t workspaceSizeInBytes, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, + const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C ); + /* Set all values of a tensor to a given value : y[i] = value[0] */ cudnnStatus_t cudnnSetTensor( cudnnHandle_t handle, @@ -296,46 +422,26 @@ cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnStatus_t cudnnCreateConvolutionDescriptor( cudnnConvolutionDescriptor_t *convDesc ); -cudnnStatus_t cudnnSetConvolution2dDescriptor( - cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height*/ - int pad_w, /* zero-padding width*/ - int u, /* vertical filter stride*/ - int v, /* horizontal filter stride*/ - int upscalex, /* upscale the input in x-direction*/ - int upscaley, /* upscale the input in y-direction*/ - cudnnConvolutionMode_t mode ); - -cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height*/ - int pad_w, /* zero-padding width*/ - int u, /* vertical filter stride*/ - int v, /* horizontal filter stride*/ - int upscalex, /* upscale the input in x-direction*/ - int upscaley, /* upscale the input in y-direction*/ +cudnnStatus_t cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension cudnnConvolutionMode_t mode, - cudnnDataType_t dataType + cudnnDataType_t computeType ); -cudnnStatus_t cudnnGetConvolution2dDescriptor( - const cudnnConvolutionDescriptor_t convDesc, - int *pad_h, /* zero-padding height*/ - int *pad_w, /* zero-padding width*/ - int *u, /* vertical filter stride*/ - int *v, /* horizontal filter stride*/ - int *upscalex, /* upscale the input in x-direction*/ - int *upscaley, /* upscale the input in y-direction*/ - cudnnConvolutionMode_t *mode ); - -cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc, - int* pad_h, /* zero-padding height*/ - int* pad_w, /* zero-padding width*/ - int* u, /* vertical filter stride*/ - int* v, /* horizontal filter stride*/ - int* upscalex, /* upscale the input in x-direction*/ - int* upscaley, /* upscale the input in y-direction*/ +cudnnStatus_t cudnnGetConvolution2dDescriptor( const cudnnConvolutionDescriptor_t convDesc, + int* pad_h, // zero-padding height + int* pad_w, // zero-padding width + int* u, // vertical filter stride + int* v, // horizontal filter stride + int* dilation_h, // filter dilation in the vertical dimension + int* dilation_w, // filter dilation in the horizontal dimension cudnnConvolutionMode_t* mode, - cudnnDataType_t *dataType + cudnnDataType_t *computeType ); /* Helper function to return the dimensions of the output tensor given a convolution descriptor */ @@ -354,19 +460,19 @@ cudnnStatus_t cudnnSetConvolutionNdDescriptor( int arrayLength, /* nbDims-2 size */ const int padA[], const int filterStrideA[], - const int upscaleA[], + const int dilationA[], cudnnConvolutionMode_t mode, - cudnnDataType_t dataType ); /* convolution data type*/ + cudnnDataType_t computeType ); // convolution data type -cudnnStatus_t cudnnGetConvolutionNdDescriptor( +cudnnStatus_t cudnnGetConvolutionNdDescriptor( const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, int *arrayLength, int padA[], int strideA[], - int upscaleA[], + int dilationA[], cudnnConvolutionMode_t *mode, - cudnnDataType_t *dataType ); /* convolution data type*/ + cudnnDataType_t *computeType ); // convolution data type /* Helper function to return the dimensions of the output tensor given a convolution descriptor */ @@ -400,7 +506,8 @@ typedef enum CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4, CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5, CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7 + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8, } cudnnConvolutionFwdAlgo_t; typedef struct { @@ -408,6 +515,8 @@ typedef struct { cudnnStatus_t status; float time; size_t memory; + cudnnDeterminism_t determinism; + int reserved[4]; } cudnnConvolutionFwdAlgoPerf_t; cudnnStatus_t cudnnFindConvolutionForwardAlgorithm( @@ -479,8 +588,29 @@ cudnnStatus_t cudnnConvolutionForward( const cudnnTensorDescriptor_t yDesc, void *y ); +/* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */ +cudnnStatus_t cudnnConvolutionBiasActivationForward( + cudnnHandle_t handle, + const void *alpha1, + const cudnnTensorDescriptor_t xDesc, + const void *x, + const cudnnFilterDescriptor_t wDesc, + const void *w, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionFwdAlgo_t algo, + void *workSpace, + size_t workSpaceSizeInBytes, + const void *alpha2, + const cudnnTensorDescriptor_t zDesc, + const void *z, + const cudnnTensorDescriptor_t biasDesc, + const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, + void *y ); + /* Function to compute the bias gradient for batch convolution */ -cudnnStatus_t cudnnConvolutionBackwardBias( +cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t dyDesc, @@ -504,16 +634,19 @@ typedef enum CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic, algo0 with workspace*/ - /* CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, not implemented */ - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5 + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING= 6, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7, } cudnnConvolutionBwdFilterAlgo_t; - typedef struct { cudnnConvolutionBwdFilterAlgo_t algo; cudnnStatus_t status; float time; size_t memory; + cudnnDeterminism_t determinism; + int reserved[4]; } cudnnConvolutionBwdFilterAlgoPerf_t; cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm( @@ -596,7 +729,8 @@ typedef enum CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5 + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6, } cudnnConvolutionBwdDataAlgo_t; typedef struct { @@ -604,6 +738,8 @@ typedef struct { cudnnStatus_t status; float time; size_t memory; + cudnnDeterminism_t determinism; + int reserved[4]; } cudnnConvolutionBwdDataAlgoPerf_t; @@ -728,9 +864,10 @@ cudnnStatus_t cudnnSoftmaxBackward( typedef enum { CUDNN_POOLING_MAX = 0, - CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values*/ - CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values*/ - CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility + CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, // count for average includes padded values + CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, // count for average does not include padded values + CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING, + CUDNN_POOLING_MAX_DETERMINISTIC = 3 } cudnnPoolingMode_t; /* Create an instance of pooling descriptor */ @@ -833,7 +970,8 @@ typedef enum CUDNN_ACTIVATION_SIGMOID = 0, CUDNN_ACTIVATION_RELU = 1, CUDNN_ACTIVATION_TANH = 2, - CUDNN_ACTIVATION_CLIPPED_RELU = 3 + CUDNN_ACTIVATION_CLIPPED_RELU = 3, + CUDNN_ACTIVATION_ELU = 4 } cudnnActivationMode_t; /* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */ @@ -844,13 +982,13 @@ cudnnStatus_t cudnnSetActivationDescriptor( cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, cudnnNanPropagation_t reluNanOpt, - double reluCeiling ); + double coef ); /* ceiling for clipped RELU, alpha for ELU */ cudnnStatus_t cudnnGetActivationDescriptor( const cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t *mode, cudnnNanPropagation_t *reluNanOpt, - double* reluCeiling ); + double* coef ); /* ceiling for clipped RELU, alpha for ELU */ cudnnStatus_t cudnnDestroyActivationDescriptor( cudnnActivationDescriptor_t activationDesc); @@ -1234,13 +1372,50 @@ typedef enum } cudnnRNNInputMode_t; +typedef enum + { + CUDNN_RNN_ALGO_STANDARD = 0, + CUDNN_RNN_ALGO_PERSIST_STATIC = 1, + CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2 + } cudnnRNNAlgo_t; + struct cudnnRNNStruct; typedef struct cudnnRNNStruct* cudnnRNNDescriptor_t; -cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc); -cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc); +cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc); +cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc); + +struct cudnnPersistentRNNPlan; +typedef struct cudnnPersistentRNNPlan *cudnnPersistentRNNPlan_t; + -cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc, +// Expensive. Creates the plan for the specific settings. +cudnnStatus_t cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, + const int minibatch, + const cudnnDataType_t dataType, + cudnnPersistentRNNPlan_t * plan); + +// Attaches the plan to the descriptor. +cudnnStatus_t cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, + cudnnPersistentRNNPlan_t plan); + +cudnnStatus_t cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan); + + + +cudnnStatus_t cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, // Between layers, not between recurrent steps. + cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, + cudnnRNNAlgo_t algo, + cudnnDataType_t dataType); + + +cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, cudnnDropoutDescriptor_t dropoutDesc, @@ -1250,6 +1425,7 @@ cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc, cudnnDataType_t dataType); + // dataType in the RNN descriptor is used to determine math precision // dataType in weight descriptors and input descriptors is used to describe storage @@ -1388,202 +1564,52 @@ cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle, size_t reserveSpaceSizeInBytes ); - - - /* DEPRECATED routines to be removed next release : - User should use the non-suffixed version (which has the API and functionality of _v4 version) - Routines with _v3 suffix has the functionality of the non-suffixed routines in the CUDNN V4 + User should use the non-suffixed version (which has the API and functionality of _v5 version) + Routines with _v4 suffix has the functionality of the non-suffixed routines in the CUDNN V5 */ -cudnnStatus_t cudnnSetFilter4dDescriptor_v3( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type*/ - int k, /* number of output feature maps*/ - int c, /* number of input feature maps*/ - int h, /* height of each input filter*/ - int w ); /* width of each input filter*/ - -cudnnStatus_t cudnnSetFilter4dDescriptor_v4( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type*/ - cudnnTensorFormat_t format, - int k, /* number of output feature maps*/ - int c, /* number of input feature maps*/ - int h, /* height of each input filter*/ - int w ); /* width of each input filter*/ - -cudnnStatus_t cudnnGetFilter4dDescriptor_v3( - const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type*/ - int *k, /* number of output feature maps*/ - int *c, /* number of input feature maps*/ - int *h, /* height of each input filter*/ - int *w ); /* width of each input filter*/ - -cudnnStatus_t cudnnGetFilter4dDescriptor_v4( - const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type*/ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps*/ - int *c, /* number of input feature maps*/ - int *h, /* height of each input filter*/ - int *w ); /* width of each input filter */ - -cudnnStatus_t cudnnSetFilterNdDescriptor_v3( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type*/ - int nbDims, - const int filterDimA[] ); - - -cudnnStatus_t cudnnSetFilterNdDescriptor_v4( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type*/ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[] ); - -cudnnStatus_t cudnnGetFilterNdDescriptor_v3( - const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type*/ - int *nbDims, - int filterDimA[] ); - -cudnnStatus_t cudnnGetFilterNdDescriptor_v4( - const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type*/ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[] ); - -cudnnStatus_t cudnnSetPooling2dDescriptor_v3( - cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride ); - -cudnnStatus_t cudnnSetPooling2dDescriptor_v4( - cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride ); -cudnnStatus_t cudnnGetPooling2dDescriptor_v3( - const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride ); - -cudnnStatus_t cudnnGetPooling2dDescriptor_v4( - const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride ); - -cudnnStatus_t cudnnSetPoolingNdDescriptor_v3( - cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[] ); - -cudnnStatus_t cudnnSetPoolingNdDescriptor_v4( - cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[] ); - -cudnnStatus_t cudnnGetPoolingNdDescriptor_v3( - const cudnnPoolingDescriptor_t poolingDesc, - const int nbDimsRequested, - cudnnPoolingMode_t *mode, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[] ); - -cudnnStatus_t cudnnGetPoolingNdDescriptor_v4( - const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[] ); - -cudnnStatus_t cudnnActivationForward_v3( - cudnnHandle_t handle, - cudnnActivationMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ); - -cudnnStatus_t cudnnActivationForward_v4( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ); +cudnnStatus_t cudnnSetConvolution2dDescriptor_v4( + cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t mode ); -cudnnStatus_t cudnnActivationBackward_v3( - cudnnHandle_t handle, - cudnnActivationMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ); +cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t mode, + cudnnDataType_t computeType + ); -cudnnStatus_t cudnnActivationBackward_v4( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ); +cudnnStatus_t cudnnGetConvolution2dDescriptor_v4( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, // zero-padding height + int *pad_w, // zero-padding width + int *u, // vertical filter stride + int *v, // horizontal filter stride + int *dilation_h, // filter dilation in the vertical dimension + int *dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t *mode ); +cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc, + int* pad_h, // zero-padding height + int* pad_w, // zero-padding width + int* u, // vertical filter stride + int* v, // horizontal filter stride + int* dilation_h, // filter dilation in the vertical dimension + int* dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t* mode, + cudnnDataType_t *computeType + ); ]] local CUDNN_PATH = os.getenv('CUDNN_PATH') @@ -1591,8 +1617,7 @@ if CUDNN_PATH then print('Found Environment variable CUDNN_PATH = ' .. CUDNN_PATH) cudnn.C = ffi.load(CUDNN_PATH) else - - local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib', 'cudnn64_5.dll'} + local libnames = {'libcudnn.so.6', 'libcudnn.6.dylib', 'cudnn64_6.dll'} local ok = false for i=1,#libnames do ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end) @@ -1600,22 +1625,22 @@ else end if not ok then - error([['libcudnn (R5) not found in library path. + error([['libcudnn (R6\) not found in library path. Please install CuDNN from https://developer.nvidia.com/cuDNN -Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in +Then make sure files named as libcudnn.so.6 or libcudnn.6.dylib are placed in your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH) -Alternatively, set the path to libcudnn.so.5 or libcudnn.5.dylib +Alternatively, set the path to libcudnn.so.6 or libcudnn.6.dylib to the environment variable CUDNN_PATH and rerun torch. -For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.5" +For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.6" ]]) end end -- check cuDNN version cudnn.version = tonumber(cudnn.C.cudnnGetVersion()) -if cudnn.version < 5005 or cudnn.version >= 6000 then - error('These bindings are for CUDNN 5.x (5005 <= cudnn.version > 6000) , ' +if cudnn.version < 6002 then + error('These bindings are for version 6002 or above, ' .. 'while the loaded CuDNN is version: ' .. cudnn.version .. ' \nAre you using an older or newer version of CuDNN?') end @@ -210,7 +210,7 @@ end function cudnn.setConvolutionDescriptor(data, desc) if not data.arrayLength then data.arrayLength = #data.padA end - if not data.upscaleA then data.upscaleA = torch.IntStorage(data.arrayLength):fill(1) end + if not data.dilationA then data.dilationA = {1,1,1 } end -- assume maximum length==3 if not data.mode then data.mode = 'CUDNN_CROSS_CORRELATION' end local myDesc = desc or cudnn.createDescriptors( @@ -219,7 +219,7 @@ function cudnn.setConvolutionDescriptor(data, desc) -- make sure we have references to these tensors so gc doesn't clean them up local padATensor = torch.IntTensor(data.padA) local filterStrideATensor = torch.IntTensor(data.filterStrideA) - local upscaleATensor = torch.IntTensor(data.upscaleA) + local upscaleATensor = torch.IntTensor(data.dilationA) errcheck('cudnnSetConvolutionNdDescriptor', myDesc[0], data.arrayLength, padATensor:data(), @@ -328,6 +328,8 @@ cudnn.find = require('cudnn.find') require('cudnn.SpatialConvolution') require('cudnn.VolumetricConvolution') +require('cudnn.SpatialDilatedConvolution') +require('cudnn.VolumetricDilatedConvolution') require('cudnn.SpatialFullConvolution') require('cudnn.VolumetricFullConvolution') require('cudnn.Pooling') diff --git a/test/test.lua b/test/test.lua index d958041..ffe6cc3 100644 --- a/test/test.lua +++ b/test/test.lua @@ -13,7 +13,7 @@ local testparams_half = { precision_forward = 2e-1, precision_backward = 10, precision_jac = 1e-3, - precision_io = 1e-1, + precision_io = 3e-1, } local testparams_float = { @@ -146,6 +146,84 @@ function cudnntest.SpatialConvolution() testLayer(sconv, gconv, input, gradOutput, scale, true, false) end +function cudnntest.SpatialDilatedConvolution() + local bs = math.random(1,32) + local from = math.random(1,32) + local to = math.random(1,64) + local ki = math.random(1,7) + local kj = math.random(1,7) + local di = math.random(1,7) + local dj = math.random(1,7) + local wi = (ki-1)*di+1 + local wj = (kj-1)*dj+1 + local si = math.random(1, wi ) + local sj = math.random(1, wj) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = (outi-1)*si+wi + local inj = (outj-1)*sj+wj + + local scale = math.random() + + local input = torch.randn(bs,from,inj,ini) + local gradOutput = torch.randn(bs,to,outj,outi) + local sconv = nn.SpatialDilatedConvolution(from,to,ki,kj,si,sj,0,0,di,dj) + local gconv = cast(cudnn.SpatialDilatedConvolution(from,to,ki,kj,si,sj,0,0,di,dj)):fastest() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + + testLayer(sconv, gconv, input, gradOutput, scale, true, true) -- batch + testLayer(sconv, gconv, input, gradOutput, scale, true, false) -- non-batch + local originalTypename = torch.typename(gconv) + local gconv = cast(cudnn.convert(sconv, cudnn)) + mytester:asserteq(torch.typename(gconv), + originalTypename, 'conversion type check') + testLayer(sconv, gconv, input, gradOutput, scale, true, true) + testLayer(sconv, gconv, input, gradOutput, scale, true, false) +end + +function cudnntest.VolumetricDilatedConvolution() + local bs = math.random(1,32) + local from = math.random(1,16) + local to = math.random(1,16) + local ki = math.random(3,5) + local kj = math.random(3,5) + local kk = math.random(3,5) + local di = math.random(1,5) + local dj = math.random(1,5) + local dk = math.random(1,5) + local wi = (ki-1)*di+1 + local wj = (kj-1)*dj+1 + local wk = (kk-1)*dk+1 + local si = math.random(1, wi ) + local sj = math.random(1, wj) + local sk = math.random(1, wk) + local outi = math.random(1,17) + local outj = math.random(1,17) + local outk = math.random(1,5) + local ini = (outi-1)*si+wi + local inj = (outj-1)*sj+wj + local ink = (outk-1)*sk+wk + + local scale = math.random() + + local input = torch.randn(bs,from,ini,ink,inj) + local gradOutput = torch.randn(bs,to,outi,outk,outj) + local sconv = nn.VolumetricDilatedConvolution(from,to,ki,kj,kk,si,sj,sk,0,0,0,di,dj,dk) + local gconv = cast(cudnn.VolumetricDilatedConvolution(from,to,ki,kj,kk,si,sj,sk,0,0,0,di,dj,dk)) + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + + testLayer(sconv, gconv, input, gradOutput, scale, true, true) -- batch + testLayer(sconv, gconv, input, gradOutput, scale, true, false) -- non-batch + local originalTypename = torch.typename(gconv) + local gconv = cast(cudnn.convert(sconv, cudnn)) + mytester:asserteq(torch.typename(gconv), + originalTypename, 'conversion type check') + testLayer(sconv, gconv, input, gradOutput, scale, true, true) + testLayer(sconv, gconv, input, gradOutput, scale, true, false) +end + function cudnntest.SpatialFullConvolution() local bs = math.random(1,32) local from = math.random(1,32) @@ -813,7 +891,7 @@ function cudnntest.functional_bias2D() local inj = (outj-1)*sj+kj local scale = torch.uniform() local input = cast(torch.zeros(bs,from,inj,ini)) - local mod = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj)) + local mod = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj):fastest()) mod.weight:zero() local groundtruth = mod:forward(input) local result = groundtruth:clone():zero() @@ -943,8 +1021,8 @@ mytester:add(cudnntest) cudnn.verbose=false cudnn.find.verbose=false --- this is the default, keep it for demo of 16->32 bit float fallback -cudnn.find.verboseFallback=true +-- todo: put it back for release to demo 16->32 bit float fallback +cudnn.find.verboseFallback=false cudnn.useFindEx=false cudnn.configureMath({ ['torch.CudaHalfTensor'] = 'CUDNN_DATA_FLOAT'} ) for i = 1, 1 do -- cutorch.getDeviceCount() do diff --git a/test/test_rnn.lua b/test/test_rnn.lua index 0d0b37b..63520b6 100644 --- a/test/test_rnn.lua +++ b/test/test_rnn.lua @@ -12,6 +12,9 @@ local cudnntest = torch.TestSuite() local mytester local tolerance = 1000 +local btol = tolerance * 20 +local stol = tolerance/100 +local tinytol = .1 function cudnntest.testRNNRELU() local miniBatch = 64 @@ -26,9 +29,9 @@ function cudnntest.testRNNRELU() -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 1.315793E+06, tolerance, 'checkSum with reference for localsumi failed') mytester:assertalmosteq(checkSums.localSumh, 1.315212E+05, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, tolerance, 'checkSum with reference for localSumdh failed') - mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, tolerance, 'checkSum with reference for localSumdw failed') + mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, stol, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, stol, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, btol, 'checkSum with reference for localSumdw failed') end function cudnntest.testRNNBatchFirst() @@ -45,9 +48,9 @@ function cudnntest.testRNNBatchFirst() -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 1.315793E+06, tolerance, 'checkSum with reference for localsumi failed') mytester:assertalmosteq(checkSums.localSumh, 1.315212E+05, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, tolerance, 'checkSum with reference for localSumdh failed') - mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, tolerance, 'checkSum with reference for localSumdw failed') + mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, stol, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, stol, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, btol, 'checkSum with reference for localSumdw failed') end function cudnntest.testRNNTANH() @@ -63,8 +66,8 @@ function cudnntest.testRNNTANH() -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 6.319591E+05, tolerance, 'checkSum with reference for localsumi failed') mytester:assertalmosteq(checkSums.localSumh, 6.319605E+04, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 4.501830E+00, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdh, 4.489546E+00, tolerance, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumdi, 4.501830E+00, 1, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdh, 4.489546E+00, 1, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 5.012598E+07, tolerance, 'checkSum with reference for localSumdw failed') end @@ -81,9 +84,9 @@ function cudnntest.testRNNLSTM() mytester:assertalmosteq(checkSums.localSumi, 5.749536E+05, tolerance, 'checkSum with reference for localsumi failed') mytester:assertalmosteq(checkSums.localSumc, 4.365091E+05, tolerance, 'checkSum with reference for localSumc failed') mytester:assertalmosteq(checkSums.localSumh, 5.774818E+04, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 3.842206E+02, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdc, 9.323785E+03, tolerance, 'checkSum with reference for localSumdc failed') - mytester:assertalmosteq(checkSums.localSumdh, 1.182566E+01, tolerance, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumdi, 3.842206E+02, stol, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdc, 9.323785E+03, stol, 'checkSum with reference for localSumdc failed') + mytester:assertalmosteq(checkSums.localSumdh, 1.182566E+01, tinytol, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 4.313461E+08, tolerance, 'checkSum with reference for localSumdw failed') end @@ -98,7 +101,7 @@ function cudnntest.testRNNGRU() -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 6.358978E+05, tolerance, 'checkSum with reference for localsumi failed') mytester:assertalmosteq(checkSums.localSumh, 6.281680E+04, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 6.296622E+00, tolerance, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdi, 6.296622E+00, tinytol, 'checkSum with reference for localSumdi failed') mytester:assertalmosteq(checkSums.localSumdh, 2.289960E+05, tolerance, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 5.397419E+07, tolerance, 'checkSum with reference for localSumdw failed') end @@ -118,10 +121,10 @@ function cudnntest.testBiDirectionalRELURNN() local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections) -- Checksums to check against are retrieved from cudnn RNN sample. - mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tolerance, 'checkSum with reference for localsumi failed') - mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tolerance, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tinytol, 'checkSum with reference for localsumi failed') + mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tinytol, 'checkSum with reference for localSumh failed') + mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tinytol, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tinytol, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 7.061081E+07, tolerance, 'checkSum with reference for localSumdw failed') end @@ -140,10 +143,10 @@ function cudnntest.testBiDirectionalTANHRNN() local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections) -- Checksums to check against are retrieved from cudnn RNN sample. - mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tolerance, 'checkSum with reference for localsumi failed') - mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tolerance, 'checkSum with reference for localSumdi failed') - mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tolerance, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tinytol, 'checkSum with reference for localsumi failed') + mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tinytol, 'checkSum with reference for localSumh failed') + mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tinytol, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tinytol, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 7.061081E+07, tolerance, 'checkSum with reference for localSumdw failed') end @@ -160,11 +163,11 @@ function cudnntest.testBiDirectionalLSTMRNN() local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections) -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 3.134097E+04, tolerance, 'checkSum with reference for localsumi failed') - mytester:assertalmosteq(checkSums.localSumc, 3.845626E+00, tolerance, 'checkSum with reference for localSumc failed') - mytester:assertalmosteq(checkSums.localSumh, 1.922855E+00, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 4.794993E+00, tolerance, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumc, 3.845626E+00, tinytol, 'checkSum with reference for localSumc failed') + mytester:assertalmosteq(checkSums.localSumh, 1.922855E+00, tinytol, 'checkSum with reference for localSumh failed') + mytester:assertalmosteq(checkSums.localSumdi, 4.794993E+00, tinytol, 'checkSum with reference for localSumdi failed') mytester:assertalmosteq(checkSums.localSumdc, 2.870925E+04, tolerance, 'checkSum with reference for localSumdc failed') - mytester:assertalmosteq(checkSums.localSumdh, 2.468645E+00, tolerance, 'checkSum with reference for localSumdh failed') + mytester:assertalmosteq(checkSums.localSumdh, 2.468645E+00, tinytol, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 1.121568E+08, tolerance, 'checkSum with reference for localSumdw failed') end @@ -184,8 +187,8 @@ function cudnntest.testBiDirectionalGRURNN() local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections) -- Checksums to check against are retrieved from cudnn RNN sample. mytester:assertalmosteq(checkSums.localSumi, 6.555183E+04, tolerance, 'checkSum with reference for localsumi failed') - mytester:assertalmosteq(checkSums.localSumh, 5.830924E+00, tolerance, 'checkSum with reference for localSumh failed') - mytester:assertalmosteq(checkSums.localSumdi, 4.271801E+00, tolerance, 'checkSum with reference for localSumdi failed') + mytester:assertalmosteq(checkSums.localSumh, 5.830924E+00, tinytol, 'checkSum with reference for localSumh failed') + mytester:assertalmosteq(checkSums.localSumdi, 4.271801E+00, tinytol, 'checkSum with reference for localSumdi failed') mytester:assertalmosteq(checkSums.localSumdh, 6.555744E+04, tolerance, 'checkSum with reference for localSumdh failed') mytester:assertalmosteq(checkSums.localSumdw, 1.701796E+08, tolerance, 'checkSum with reference for localSumdw failed') end |