Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBoris Fomitchev <bfomitchev@nvidia.com>2017-03-12 07:26:15 +0300
committerBoris Fomitchev <bfomitchev@nvidia.com>2017-03-12 07:26:15 +0300
commit412e492caaa33845975140630c9753f40130d9bd (patch)
tree6b64577a12f5cb815b84be322a9715da1c56e0a0
parent744f79b10b20a2c7561ae352f923f2b331583fcd (diff)
parentf2a1e328cf18290df9ca4ea9609d843443f5e571 (diff)
Merge branch '17.03-devel' into 17.04-devel
Conflicts: CMakeLists.txt RNN.lua
-rw-r--r--CMakeLists.txt12
-rw-r--r--README.md5
-rw-r--r--RNN.lua49
-rw-r--r--SpatialConvolution.lua2
-rw-r--r--SpatialDilatedConvolution.lua73
-rw-r--r--VolumetricDilatedConvolution.lua62
-rw-r--r--VolumetricFullConvolution.lua8
-rw-r--r--convert.lua2
-rw-r--r--ffi.lua575
-rw-r--r--init.lua6
-rw-r--r--test/test.lua86
-rw-r--r--test/test_rnn.lua55
12 files changed, 618 insertions, 317 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4cf64b8..9e4648b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,18 +9,18 @@ IF(LUAROCKS_PREFIX)
STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}")
MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}")
ENDIF()
+
FIND_PACKAGE(Torch REQUIRED)
-FIND_PACKAGE(CUDA 7.0 REQUIRED)
+FIND_PACKAGE(CUDA 8.0 REQUIRED)
-FIND_PACKAGE(CUDNN 5 EXACT QUIET)
+FIND_PACKAGE(CUDNN 6 EXACT QUIET)
IF(NOT CUDNN_FOUND)
- CUDNN_INSTALL(5.1 "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "${Torch_INSTALL_BIN}")
- FIND_PACKAGE(CUDNN 5 EXACT REQUIRED)
+ CUDNN_INSTALL(6.0-rc "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "")
+ FIND_PACKAGE(CUDNN 6 EXACT REQUIRED)
+
ENDIF()
FILE(GLOB luasrc *.lua)
SET(src "")
ADD_TORCH_PACKAGE(cudnn "${src}" "${luasrc}" "NVIDIA CuDNN Bindings")
-
-
diff --git a/README.md b/README.md
index ed3e610..c356988 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,11 @@ nn.Sequential {
(2): nn.ReLU
}
```
+### New for cudnn V6
+Persistent mode for RNNs is enabled. RNNs can be run in persistent mode on Pascal family GPUs, and they are expected to be faster
+for small batch size. They can be enabled by calling :setPersist(true) on any RNN module (RNN, LSTM, or GRU).
+Dilated convolutions have been enabled.
+
### Older versions
For version CuDNN R1, checkout the branch **R1**
diff --git a/RNN.lua b/RNN.lua
index 8b1a965..6f2999c 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -37,12 +37,27 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
self.cellOutput = torch.CudaTensor()
self.gradHiddenInput = torch.CudaTensor()
self.gradCellInput = torch.CudaTensor()
+ self.persistent = false -- set true to use persistent RNNs
self:training()
self:reset()
end
function RNN:setSync(sync)
- self.sync = sync
+ if sync==nil then
+ self.sync = true
+ else
+ self.sync = sync
+ end
+end
+
+function RNN:setPersist(persist)
+ --sets persistent mode for any argument except persist=false
+ if persist==nil then
+ self.persistent = true
+ else
+ self.persistent = persist
+ end
+ self:resetRNNDescriptor()
end
function RNN:reset(stdv)
@@ -122,7 +137,14 @@ function RNN:resetRNNDescriptor()
if not self.rnnDesc then
self.rnnDesc = self:createRNNDescriptors(1)
end
- errcheck('cudnnSetRNNDescriptor',
+ local algo
+ if self.persistent then
+ algo = 'CUDNN_RNN_ALGO_PERSIST_STATIC'
+ else
+ algo = 'CUDNN_RNN_ALGO_STANDARD'
+ end
+ local status = cudnn.call('cudnnSetRNNDescriptor_v6',
+ cudnn.getHandle(),
self.rnnDesc[0],
self.hiddenSize,
self.numLayers,
@@ -130,7 +152,30 @@ function RNN:resetRNNDescriptor()
self.inputMode,
self.bidirectional,
self.mode,
+ algo,
self.datatype)
+ if status ~= ffi.C.CUDNN_STATUS_SUCCESS then
+ if algo == 'CUDNN_RNN_ALGO_PERSIST_STATIC' then
+ --try using standard algo
+ print("Warning: persistent RNN is not supported for this configuration. Switching to standard")
+ algo = 'CUDNN_RNN_ALGO_STANDARD'
+ self.persistent = false
+ errcheck('cudnnSetRNNDescriptor_v6',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.hiddenSize,
+ self.numLayers,
+ self.dropoutDesc[0],
+ self.inputMode,
+ self.bidirectional,
+ self.mode,
+ algo,
+ self.datatype)
+ else
+ local str = ffi.string(C.cudnnGetErrorString(status))
+ error('Error in CuDNN: ' .. str .. ' (cudnnSetRNNDescriptor_v6)')
+ end
+ end
end
function RNN:resetWeightDescriptor()
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 830a7e6..7f9dd23 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -129,7 +129,7 @@ function SpatialConvolution:createIODescriptors(input)
self.convDescData = { padA = self.pad,
filterStrideA = self.stride,
- upscaleA = {1,1},
+ dilationA = {1,1},
dataType = cudnn.configmap(torch.type(self.weight))
}
diff --git a/SpatialDilatedConvolution.lua b/SpatialDilatedConvolution.lua
new file mode 100644
index 0000000..5e39c03
--- /dev/null
+++ b/SpatialDilatedConvolution.lua
@@ -0,0 +1,73 @@
+local SpatialDilatedConvolution, parent =
+ torch.class('cudnn.SpatialDilatedConvolution', 'cudnn.SpatialConvolution')
+local ffi = require 'ffi'
+local find = require 'cudnn.find'
+
+function SpatialDilatedConvolution:__init(nInputPlane, nOutputPlane,
+ kW, kH, dW, dH, padW, padH, dilationW, dilationH, groups)
+ local delayedReset = self.reset
+ self.reset = function() end
+ parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, groups)--, dilationW, dilationH)
+ self.dilationW = dilationW
+ self.dilationH = dilationH
+end
+
+function SpatialDilatedConvolution:createIODescriptors(input)
+ local batch = true
+ if input:dim() == 3 then
+ input = input:view(1, input:size(1), input:size(2), input:size(3))
+ batch = false
+ end
+ if parent.checkInputChanged(self, input) then
+ -- create input descriptor
+ local input_slice = input:narrow(2,1,self.nInputPlane/self.groups)
+ self.iDesc = cudnn.toDescriptor(input_slice)
+ -- create conv descriptor
+ self.padH, self.padW = self.padH or 0, self.padW or 0
+ -- those needed to calculate hash
+ self.pad = {self.padH, self.padW}
+ self.stride = {self.dH, self.dW}
+ local t_dataType = cudnn.configmap(torch.type(self.weight))
+ --fallback to fp32 math if half type, fp16 dilated convs not fully implmented in cuDNN 6.0.2
+ if( t_dataType == 'CUDNN_DATA_HALF') then t_dataType = 'CUDNN_DATA_FLOAT' end
+ self.convDescData = {
+ padA = self.pad,
+ filterStrideA = self.stride,
+ dilationA = {self.dilationH, self.dilationW},
+ dataType = t_dataType
+ }
+ self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData)
+
+ -- get output shape, resize output
+ local oSize = torch.IntTensor(4)
+ cudnn.errcheck('cudnnGetConvolutionNdForwardOutputDim',
+ self.convDesc[0], self.iDesc[0],
+ self.weightDesc[0], 4, oSize:data())
+ oSize[2] = oSize[2] * self.groups
+ self.output:resize(oSize:long():storage())
+ self.oSize = self.output:size()
+
+ local output_slice = self.output:narrow(2,1,self.nOutputPlane/self.groups)
+ -- create descriptor for output
+ self.oDesc = cudnn.toDescriptor(output_slice)
+ self.oDescForBias = cudnn.toDescriptor(self.output)
+
+ find:prepare(self, input_slice, output_slice)
+
+ -- create offsets for groups
+ local iH, iW = input:size(3), input:size(4)
+ local kH, kW = self.kH, self.kW
+ local oH, oW = oSize[3], oSize[4]
+ self.input_offset = self.nInputPlane / self.groups * iH * iW
+ self.output_offset = self.nOutputPlane / self.groups * oH * oW
+ self.weight_offset = self.nInputPlane / self.groups * self.nOutputPlane / self.groups * kH * kW
+
+ if not batch then
+ self.output = self.output:view(self.output:size(2),
+ self.output:size(3),
+ self.output:size(4))
+ end
+
+ end
+ return self
+end
diff --git a/VolumetricDilatedConvolution.lua b/VolumetricDilatedConvolution.lua
new file mode 100644
index 0000000..652b646
--- /dev/null
+++ b/VolumetricDilatedConvolution.lua
@@ -0,0 +1,62 @@
+local VolumetricDilatedConvolution, parent
+ = torch.class('cudnn.VolumetricDilatedConvolution', 'cudnn.VolumetricConvolution')
+local ffi = require 'ffi'
+local find = require 'cudnn.find'
+
+local Convolution = cudnn.SpatialConvolution
+
+function VolumetricDilatedConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, dilationT, dilationW, dilationH)
+ parent.__init(self, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH)
+
+ self.dilationT = dilationT or 1
+ self.dilationW = dilationW or 1
+ self.dilationH = dilationH or 1
+end
+
+
+function VolumetricDilatedConvolution:createIODescriptors(input)
+ if input:dim() == 4 then
+ input = input:view(1, input:size(1), input:size(2),
+ input:size(3), input:size(4))
+ batch = false
+ end
+ if Convolution.checkInputChanged(self, input) then
+ -- create input descriptor
+ self.iDesc = cudnn.toDescriptor(input)
+ -- create conv descriptor
+ self.pad = {self.padT, self.padH, self.padW}
+ self.stride = {self.dT, self.dH, self.dW}
+ self.dilation = {self.dilationT, self.dilationH, self.dilationW}
+
+ local mathtype=cudnn.configmap(torch.type(self.weight))
+ -- 3D convolutions do not work in 16 bits
+ if mathtype == 'CUDNN_DATA_HALF' then
+ mathtype = 'CUDNN_DATA_FLOAT'
+ end
+ self.convDescData = {
+ padA = self.pad,
+ filterStrideA = self.stride,
+ dilationA = self.dilation,
+ dataType = mathtype
+ }
+ self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData)
+
+ local oSize = torch.IntTensor(5)
+ cudnn.errcheck('cudnnGetConvolutionNdForwardOutputDim',
+ self.convDesc[0], self.iDesc[0],
+ self.weightDesc[0], 5, oSize:data())
+ self.output:resize(oSize:long():storage())
+ -- create descriptor for output
+ self.oDesc = cudnn.toDescriptor(self.output)
+ self.oDescForBias = cudnn.toDescriptor(
+ self.output:view(self.output:size(1),
+ self.output:size(2),
+ self.output:size(3)*self.output:size(4),
+ self.output:size(5)))
+ self.input_offset = 0
+ self.output_offset = 0
+ self.weight_offset = 0
+ find:prepare(self, input, self.output)
+
+ end
+end
diff --git a/VolumetricFullConvolution.lua b/VolumetricFullConvolution.lua
index a662429..a51f09b 100644
--- a/VolumetricFullConvolution.lua
+++ b/VolumetricFullConvolution.lua
@@ -43,10 +43,16 @@ function VolumetricFullConvolution:createIODescriptors(input)
local input_slice = input[{{},{1,self.nInputPlane},{},{}}]
self.iDesc = cudnn.toDescriptor(input_slice)
-- create conv descriptor
+ local mathtype = cudnn.configmap(torch.type(self.weight))
+ -- 3D convolutions do not work in 16 bits
+ if mathtype == 'CUDNN_DATA_HALF' then
+ mathtype = 'CUDNN_DATA_FLOAT'
+ end
+
self.pad = {self.padT, self.padH, self.padW}
self.stride = {self.dT, self.dH, self.dW}
self.convDescData = { padA = self.pad, filterStrideA = self.stride,
- dataType = cudnn.configmap(torch.type(self.weight))}
+ dataType = mathtype }
self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData)
-- get output shape, resize output
diff --git a/convert.lua b/convert.lua
index 4075122..7b4de1c 100644
--- a/convert.lua
+++ b/convert.lua
@@ -7,6 +7,7 @@ local layer_list = {
'SpatialFullConvolution',
'SpatialMaxPooling',
'SpatialAveragePooling',
+ 'SpatialDilatedConvolution',
'ReLU',
'Tanh',
'Sigmoid',
@@ -17,6 +18,7 @@ local layer_list = {
'VolumetricFullConvolution',
'VolumetricMaxPooling',
'VolumetricAveragePooling',
+ 'VolumetricDilatedConvolution',
}
-- goes over a given net and converts all layers to dst backend
diff --git a/ffi.lua b/ffi.lua
index 3fba048..f72baea 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -3,11 +3,17 @@ local ffi = require 'ffi'
ffi.cdef[[
+typedef enum
+{
+ MAJOR_VERSION,
+ MINOR_VERSION,
+ PATCH_LEVEL
+} libraryPropertyType;
typedef enum {
- CUDNN_MAJOR = 5,
+ CUDNN_MAJOR = 6,
CUDNN_MINOR = 0,
- CUDNN_PATCHLEVEL = 4,
+ CUDNN_PATCHLEVEL = 2,
CUDNN_VERSION = (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
} cudnnVerFakeEnum;
@@ -21,22 +27,25 @@ size_t cudnnGetVersion(void);
*/
typedef enum
{
- CUDNN_STATUS_SUCCESS = 0,
- CUDNN_STATUS_NOT_INITIALIZED = 1,
- CUDNN_STATUS_ALLOC_FAILED = 2,
- CUDNN_STATUS_BAD_PARAM = 3,
- CUDNN_STATUS_INTERNAL_ERROR = 4,
- CUDNN_STATUS_INVALID_VALUE = 5,
- CUDNN_STATUS_ARCH_MISMATCH = 6,
- CUDNN_STATUS_MAPPING_ERROR = 7,
- CUDNN_STATUS_EXECUTION_FAILED = 8,
- CUDNN_STATUS_NOT_SUPPORTED = 9,
- CUDNN_STATUS_LICENSE_ERROR = 10
+ CUDNN_STATUS_SUCCESS = 0,
+ CUDNN_STATUS_NOT_INITIALIZED = 1,
+ CUDNN_STATUS_ALLOC_FAILED = 2,
+ CUDNN_STATUS_BAD_PARAM = 3,
+ CUDNN_STATUS_INTERNAL_ERROR = 4,
+ CUDNN_STATUS_INVALID_VALUE = 5,
+ CUDNN_STATUS_ARCH_MISMATCH = 6,
+ CUDNN_STATUS_MAPPING_ERROR = 7,
+ CUDNN_STATUS_EXECUTION_FAILED = 8,
+ CUDNN_STATUS_NOT_SUPPORTED = 9,
+ CUDNN_STATUS_LICENSE_ERROR = 10,
+ CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING = 11
} cudnnStatus_t;
/* human-readable error messages*/
const char * cudnnGetErrorString(cudnnStatus_t status);
+cudnnStatus_t cudnnGetProperty(libraryPropertyType type, int *value);
+
cudnnStatus_t cudnnCreate (cudnnHandle_t *handle);
cudnnStatus_t cudnnDestroy (cudnnHandle_t handle);
cudnnStatus_t cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId);
@@ -52,6 +61,7 @@ typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t;
typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t;
typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t;
typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t;
+typedef struct cudnnReduceTensorStruct* cudnnReduceTensorDescriptor_t;
/*
* CUDNN data type
*/
@@ -60,16 +70,27 @@ typedef enum
CUDNN_DATA_FLOAT = 0,
CUDNN_DATA_DOUBLE = 1,
CUDNN_DATA_HALF = 2,
+ CUDNN_DATA_INT8 = 3,
+ CUDNN_DATA_INT32 = 4,
+ CUDNN_DATA_INT8x4 = 5
} cudnnDataType_t;
/*
* CUDNN propagate Nan
*/
-typedef enum{
+typedef enum {
CUDNN_NOT_PROPAGATE_NAN = 0,
CUDNN_PROPAGATE_NAN = 1,
} cudnnNanPropagation_t;
+/*
+ * CUDNN Determinism
+ */
+typedef enum {
+ CUDNN_NON_DETERMINISTIC = 0,
+ CUDNN_DETERMINISTIC = 1,
+} cudnnDeterminism_t;
+
/* Maximum supported number of tensor dimensions */
typedef enum { CUDNN_DIM_MAX = 8 } cudnnDimMaxFakeEnum;
@@ -79,8 +100,9 @@ cudnnStatus_t cudnnCreateTensorDescriptor(
typedef enum
{
- CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
- CUDNN_TENSOR_NHWC = 1 /* feature maps interleaved ( cStride = 1 )*/
+ CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
+ CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/
+ CUDNN_TENSOR_NCHW_VECT_C = 2 /* each image point is vector of element of C : the length of the vector is carried by the data type*/
} cudnnTensorFormat_t;
cudnnStatus_t cudnnSetTensor4dDescriptor(
@@ -124,7 +146,14 @@ cudnnStatus_t cudnnSetTensorNdDescriptor(
const int dimA[],
const int strideA[] );
-cudnnStatus_t cudnnGetTensorNdDescriptor(
+cudnnStatus_t cudnnSetTensorNdDescriptorEx(
+ cudnnTensorDescriptor_t tensorDesc,
+ cudnnTensorFormat_t format,
+ cudnnDataType_t dataType,
+ int nbDims,
+ const int dimA[] );
+
+cudnnStatus_t cudnnGetTensorNdDescriptor(
const cudnnTensorDescriptor_t tensorDesc,
int nbDimsRequested,
cudnnDataType_t *dataType,
@@ -132,6 +161,11 @@ cudnnStatus_t cudnnGetTensorNdDescriptor(
int dimA[],
int strideA[] );
+
+cudnnStatus_t cudnnGetTensorSizeInBytes(
+ const cudnnTensorDescriptor_t tensorDesc,
+ size_t *size);
+
/* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride
1)Example of all images in row major order one batch of features after the other (with an optional padding on row)
@@ -186,10 +220,11 @@ cudnnStatus_t cudnnAddTensor(
*/
typedef enum
{
- CUDNN_OP_TENSOR_ADD = 0,
- CUDNN_OP_TENSOR_MUL = 1,
- CUDNN_OP_TENSOR_MIN = 2,
- CUDNN_OP_TENSOR_MAX = 3,
+ CUDNN_OP_TENSOR_ADD = 0,
+ CUDNN_OP_TENSOR_MUL = 1,
+ CUDNN_OP_TENSOR_MIN = 2,
+ CUDNN_OP_TENSOR_MAX = 3,
+ CUDNN_OP_TENSOR_SQRT = 4,
} cudnnOpTensorOp_t;
cudnnStatus_t cudnnCreateOpTensorDescriptor(
@@ -224,6 +259,97 @@ cudnnStatus_t cudnnOpTensor(
const cudnnTensorDescriptor_t cDesc,
void *C );
+/*
+* CUDNN ReduceTensor op type
+*/
+typedef enum
+{
+ CUDNN_REDUCE_TENSOR_ADD = 0,
+ CUDNN_REDUCE_TENSOR_MUL = 1,
+ CUDNN_REDUCE_TENSOR_MIN = 2,
+ CUDNN_REDUCE_TENSOR_MAX = 3,
+ CUDNN_REDUCE_TENSOR_AMAX = 4,
+ CUDNN_REDUCE_TENSOR_AVG = 5,
+ CUDNN_REDUCE_TENSOR_NORM1 = 6,
+ CUDNN_REDUCE_TENSOR_NORM2 = 7,
+} cudnnReduceTensorOp_t;
+
+/*
+* CUDNN ReduceTensor indices type
+*/
+typedef enum
+{
+ CUDNN_REDUCE_TENSOR_NO_INDICES = 0,
+ CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1,
+} cudnnReduceTensorIndices_t;
+
+/*
+* CUDNN tensor indices type size (all unsigned)
+* Currently not supported, default is 32 bit unsigned.
+*/
+typedef enum
+{
+ CUDNN_32BIT_INDICES = 0,
+ CUDNN_64BIT_INDICES = 1,
+ CUDNN_16BIT_INDICES = 2,
+ CUDNN_8BIT_INDICES = 3,
+} cudnnIndicesType_t;
+
+cudnnStatus_t cudnnCreateReduceTensorDescriptor(
+ cudnnReduceTensorDescriptor_t *reduceTensorDesc );
+
+cudnnStatus_t cudnnSetReduceTensorDescriptor(
+ cudnnReduceTensorDescriptor_t reduceTensorDesc,
+ cudnnReduceTensorOp_t reduceTensorOp,
+ cudnnDataType_t reduceTensorCompType,
+ cudnnNanPropagation_t reduceTensorNanOpt,
+ cudnnReduceTensorIndices_t reduceTensorIndices,
+ cudnnIndicesType_t reduceTensorIndicesType );
+
+cudnnStatus_t cudnnGetReduceTensorDescriptor(
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+ cudnnReduceTensorOp_t *reduceTensorOp,
+ cudnnDataType_t *reduceTensorCompType,
+ cudnnNanPropagation_t *reduceTensorNanOpt,
+ cudnnReduceTensorIndices_t *reduceTensorIndices,
+ cudnnIndicesType_t *reduceTensorIndicesType );
+
+cudnnStatus_t cudnnDestroyReduceTensorDescriptor(
+ cudnnReduceTensorDescriptor_t reduceTensorDesc );
+
+ /* Helper function to return the minimum size of the index space to be passed to the reduction given the input and output tensors */
+cudnnStatus_t cudnnGetReductionIndicesSize(
+ cudnnHandle_t handle,
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+ const cudnnTensorDescriptor_t aDesc,
+ const cudnnTensorDescriptor_t cDesc,
+ size_t *sizeInBytes );
+
+ /* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output tensors */
+cudnnStatus_t cudnnGetReductionWorkspaceSize(
+ cudnnHandle_t handle,
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+ const cudnnTensorDescriptor_t aDesc,
+ const cudnnTensorDescriptor_t cDesc,
+ size_t *sizeInBytes );
+
+/* Tensor operation : C = reduce op( alpha * A ) + beta * C */
+/* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */
+/* The indices space is ignored for reduce ops other than min or max. */
+cudnnStatus_t cudnnReduceTensor(
+ cudnnHandle_t handle,
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
+ void *indices,
+ size_t indicesSizeInBytes,
+ void *workspace,
+ size_t workspaceSizeInBytes,
+ const void *alpha,
+ const cudnnTensorDescriptor_t aDesc,
+ const void *A,
+ const void *beta,
+ const cudnnTensorDescriptor_t cDesc,
+ void *C );
+
/* Set all values of a tensor to a given value : y[i] = value[0] */
cudnnStatus_t cudnnSetTensor(
cudnnHandle_t handle,
@@ -296,46 +422,26 @@ cudnnStatus_t cudnnDestroyFilterDescriptor(
cudnnStatus_t cudnnCreateConvolutionDescriptor(
cudnnConvolutionDescriptor_t *convDesc );
-cudnnStatus_t cudnnSetConvolution2dDescriptor(
- cudnnConvolutionDescriptor_t convDesc,
- int pad_h, /* zero-padding height*/
- int pad_w, /* zero-padding width*/
- int u, /* vertical filter stride*/
- int v, /* horizontal filter stride*/
- int upscalex, /* upscale the input in x-direction*/
- int upscaley, /* upscale the input in y-direction*/
- cudnnConvolutionMode_t mode );
-
-cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc,
- int pad_h, /* zero-padding height*/
- int pad_w, /* zero-padding width*/
- int u, /* vertical filter stride*/
- int v, /* horizontal filter stride*/
- int upscalex, /* upscale the input in x-direction*/
- int upscaley, /* upscale the input in y-direction*/
+cudnnStatus_t cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc,
+ int pad_h, // zero-padding height
+ int pad_w, // zero-padding width
+ int u, // vertical filter stride
+ int v, // horizontal filter stride
+ int dilation_h, // filter dilation in the vertical dimension
+ int dilation_w, // filter dilation in the horizontal dimension
cudnnConvolutionMode_t mode,
- cudnnDataType_t dataType
+ cudnnDataType_t computeType
);
-cudnnStatus_t cudnnGetConvolution2dDescriptor(
- const cudnnConvolutionDescriptor_t convDesc,
- int *pad_h, /* zero-padding height*/
- int *pad_w, /* zero-padding width*/
- int *u, /* vertical filter stride*/
- int *v, /* horizontal filter stride*/
- int *upscalex, /* upscale the input in x-direction*/
- int *upscaley, /* upscale the input in y-direction*/
- cudnnConvolutionMode_t *mode );
-
-cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc,
- int* pad_h, /* zero-padding height*/
- int* pad_w, /* zero-padding width*/
- int* u, /* vertical filter stride*/
- int* v, /* horizontal filter stride*/
- int* upscalex, /* upscale the input in x-direction*/
- int* upscaley, /* upscale the input in y-direction*/
+cudnnStatus_t cudnnGetConvolution2dDescriptor( const cudnnConvolutionDescriptor_t convDesc,
+ int* pad_h, // zero-padding height
+ int* pad_w, // zero-padding width
+ int* u, // vertical filter stride
+ int* v, // horizontal filter stride
+ int* dilation_h, // filter dilation in the vertical dimension
+ int* dilation_w, // filter dilation in the horizontal dimension
cudnnConvolutionMode_t* mode,
- cudnnDataType_t *dataType
+ cudnnDataType_t *computeType
);
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
@@ -354,19 +460,19 @@ cudnnStatus_t cudnnSetConvolutionNdDescriptor(
int arrayLength, /* nbDims-2 size */
const int padA[],
const int filterStrideA[],
- const int upscaleA[],
+ const int dilationA[],
cudnnConvolutionMode_t mode,
- cudnnDataType_t dataType ); /* convolution data type*/
+ cudnnDataType_t computeType ); // convolution data type
-cudnnStatus_t cudnnGetConvolutionNdDescriptor(
+cudnnStatus_t cudnnGetConvolutionNdDescriptor(
const cudnnConvolutionDescriptor_t convDesc,
int arrayLengthRequested,
int *arrayLength,
int padA[],
int strideA[],
- int upscaleA[],
+ int dilationA[],
cudnnConvolutionMode_t *mode,
- cudnnDataType_t *dataType ); /* convolution data type*/
+ cudnnDataType_t *computeType ); // convolution data type
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
@@ -400,7 +506,8 @@ typedef enum
CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6,
- CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7,
+ CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8,
} cudnnConvolutionFwdAlgo_t;
typedef struct {
@@ -408,6 +515,8 @@ typedef struct {
cudnnStatus_t status;
float time;
size_t memory;
+ cudnnDeterminism_t determinism;
+ int reserved[4];
} cudnnConvolutionFwdAlgoPerf_t;
cudnnStatus_t cudnnFindConvolutionForwardAlgorithm(
@@ -479,8 +588,29 @@ cudnnStatus_t cudnnConvolutionForward(
const cudnnTensorDescriptor_t yDesc,
void *y );
+/* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
+cudnnStatus_t cudnnConvolutionBiasActivationForward(
+ cudnnHandle_t handle,
+ const void *alpha1,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
+ const cudnnFilterDescriptor_t wDesc,
+ const void *w,
+ const cudnnConvolutionDescriptor_t convDesc,
+ cudnnConvolutionFwdAlgo_t algo,
+ void *workSpace,
+ size_t workSpaceSizeInBytes,
+ const void *alpha2,
+ const cudnnTensorDescriptor_t zDesc,
+ const void *z,
+ const cudnnTensorDescriptor_t biasDesc,
+ const void *bias,
+ const cudnnActivationDescriptor_t activationDesc,
+ const cudnnTensorDescriptor_t yDesc,
+ void *y );
+
/* Function to compute the bias gradient for batch convolution */
-cudnnStatus_t cudnnConvolutionBackwardBias(
+cudnnStatus_t cudnnConvolutionBackwardBias(
cudnnHandle_t handle,
const void *alpha,
const cudnnTensorDescriptor_t dyDesc,
@@ -504,16 +634,19 @@ typedef enum
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic, algo0 with workspace*/
- /* CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, not implemented */
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING= 6,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7,
} cudnnConvolutionBwdFilterAlgo_t;
-
typedef struct {
cudnnConvolutionBwdFilterAlgo_t algo;
cudnnStatus_t status;
float time;
size_t memory;
+ cudnnDeterminism_t determinism;
+ int reserved[4];
} cudnnConvolutionBwdFilterAlgoPerf_t;
cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm(
@@ -596,7 +729,8 @@ typedef enum
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6,
} cudnnConvolutionBwdDataAlgo_t;
typedef struct {
@@ -604,6 +738,8 @@ typedef struct {
cudnnStatus_t status;
float time;
size_t memory;
+ cudnnDeterminism_t determinism;
+ int reserved[4];
} cudnnConvolutionBwdDataAlgoPerf_t;
@@ -728,9 +864,10 @@ cudnnStatus_t cudnnSoftmaxBackward(
typedef enum
{
CUDNN_POOLING_MAX = 0,
- CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values*/
- CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values*/
- CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility
+ CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, // count for average includes padded values
+ CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, // count for average does not include padded values
+ CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
+ CUDNN_POOLING_MAX_DETERMINISTIC = 3
} cudnnPoolingMode_t;
/* Create an instance of pooling descriptor */
@@ -833,7 +970,8 @@ typedef enum
CUDNN_ACTIVATION_SIGMOID = 0,
CUDNN_ACTIVATION_RELU = 1,
CUDNN_ACTIVATION_TANH = 2,
- CUDNN_ACTIVATION_CLIPPED_RELU = 3
+ CUDNN_ACTIVATION_CLIPPED_RELU = 3,
+ CUDNN_ACTIVATION_ELU = 4
} cudnnActivationMode_t;
/* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */
@@ -844,13 +982,13 @@ cudnnStatus_t cudnnSetActivationDescriptor(
cudnnActivationDescriptor_t activationDesc,
cudnnActivationMode_t mode,
cudnnNanPropagation_t reluNanOpt,
- double reluCeiling );
+ double coef ); /* ceiling for clipped RELU, alpha for ELU */
cudnnStatus_t cudnnGetActivationDescriptor(
const cudnnActivationDescriptor_t activationDesc,
cudnnActivationMode_t *mode,
cudnnNanPropagation_t *reluNanOpt,
- double* reluCeiling );
+ double* coef ); /* ceiling for clipped RELU, alpha for ELU */
cudnnStatus_t cudnnDestroyActivationDescriptor(
cudnnActivationDescriptor_t activationDesc);
@@ -1234,13 +1372,50 @@ typedef enum
} cudnnRNNInputMode_t;
+typedef enum
+ {
+ CUDNN_RNN_ALGO_STANDARD = 0,
+ CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
+ CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2
+ } cudnnRNNAlgo_t;
+
struct cudnnRNNStruct;
typedef struct cudnnRNNStruct* cudnnRNNDescriptor_t;
-cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc);
-cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
+cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc);
+cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
+
+struct cudnnPersistentRNNPlan;
+typedef struct cudnnPersistentRNNPlan *cudnnPersistentRNNPlan_t;
+
-cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc,
+// Expensive. Creates the plan for the specific settings.
+cudnnStatus_t cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
+ const int minibatch,
+ const cudnnDataType_t dataType,
+ cudnnPersistentRNNPlan_t * plan);
+
+// Attaches the plan to the descriptor.
+cudnnStatus_t cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
+ cudnnPersistentRNNPlan_t plan);
+
+cudnnStatus_t cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan);
+
+
+
+cudnnStatus_t cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
+ cudnnRNNDescriptor_t rnnDesc,
+ const int hiddenSize,
+ const int numLayers,
+ cudnnDropoutDescriptor_t dropoutDesc, // Between layers, not between recurrent steps.
+ cudnnRNNInputMode_t inputMode,
+ cudnnDirectionMode_t direction,
+ cudnnRNNMode_t mode,
+ cudnnRNNAlgo_t algo,
+ cudnnDataType_t dataType);
+
+
+cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc,
int hiddenSize,
int numLayers,
cudnnDropoutDescriptor_t dropoutDesc,
@@ -1250,6 +1425,7 @@ cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc,
cudnnDataType_t dataType);
+
// dataType in the RNN descriptor is used to determine math precision
// dataType in weight descriptors and input descriptors is used to describe storage
@@ -1388,202 +1564,52 @@ cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle,
size_t reserveSpaceSizeInBytes );
-
-
-
/* DEPRECATED routines to be removed next release :
- User should use the non-suffixed version (which has the API and functionality of _v4 version)
- Routines with _v3 suffix has the functionality of the non-suffixed routines in the CUDNN V4
+ User should use the non-suffixed version (which has the API and functionality of _v5 version)
+ Routines with _v4 suffix has the functionality of the non-suffixed routines in the CUDNN V5
*/
-cudnnStatus_t cudnnSetFilter4dDescriptor_v3(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type*/
- int k, /* number of output feature maps*/
- int c, /* number of input feature maps*/
- int h, /* height of each input filter*/
- int w ); /* width of each input filter*/
-
-cudnnStatus_t cudnnSetFilter4dDescriptor_v4(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type*/
- cudnnTensorFormat_t format,
- int k, /* number of output feature maps*/
- int c, /* number of input feature maps*/
- int h, /* height of each input filter*/
- int w ); /* width of each input filter*/
-
-cudnnStatus_t cudnnGetFilter4dDescriptor_v3(
- const cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t *dataType, /* image data type*/
- int *k, /* number of output feature maps*/
- int *c, /* number of input feature maps*/
- int *h, /* height of each input filter*/
- int *w ); /* width of each input filter*/
-
-cudnnStatus_t cudnnGetFilter4dDescriptor_v4(
- const cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t *dataType, /* image data type*/
- cudnnTensorFormat_t *format,
- int *k, /* number of output feature maps*/
- int *c, /* number of input feature maps*/
- int *h, /* height of each input filter*/
- int *w ); /* width of each input filter */
-
-cudnnStatus_t cudnnSetFilterNdDescriptor_v3(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type*/
- int nbDims,
- const int filterDimA[] );
-
-
-cudnnStatus_t cudnnSetFilterNdDescriptor_v4(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type*/
- cudnnTensorFormat_t format,
- int nbDims,
- const int filterDimA[] );
-
-cudnnStatus_t cudnnGetFilterNdDescriptor_v3(
- const cudnnFilterDescriptor_t filterDesc,
- int nbDimsRequested,
- cudnnDataType_t *dataType, /* image data type*/
- int *nbDims,
- int filterDimA[] );
-
-cudnnStatus_t cudnnGetFilterNdDescriptor_v4(
- const cudnnFilterDescriptor_t filterDesc,
- int nbDimsRequested,
- cudnnDataType_t *dataType, /* image data type*/
- cudnnTensorFormat_t *format,
- int *nbDims,
- int filterDimA[] );
-
-cudnnStatus_t cudnnSetPooling2dDescriptor_v3(
- cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t mode,
- int windowHeight,
- int windowWidth,
- int verticalPadding,
- int horizontalPadding,
- int verticalStride,
- int horizontalStride );
-
-cudnnStatus_t cudnnSetPooling2dDescriptor_v4(
- cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t mode,
- cudnnNanPropagation_t maxpoolingNanOpt,
- int windowHeight,
- int windowWidth,
- int verticalPadding,
- int horizontalPadding,
- int verticalStride,
- int horizontalStride );
-cudnnStatus_t cudnnGetPooling2dDescriptor_v3(
- const cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t *mode,
- int *windowHeight,
- int *windowWidth,
- int *verticalPadding,
- int *horizontalPadding,
- int *verticalStride,
- int *horizontalStride );
-
-cudnnStatus_t cudnnGetPooling2dDescriptor_v4(
- const cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t *mode,
- cudnnNanPropagation_t *maxpoolingNanOpt,
- int *windowHeight,
- int *windowWidth,
- int *verticalPadding,
- int *horizontalPadding,
- int *verticalStride,
- int *horizontalStride );
-
-cudnnStatus_t cudnnSetPoolingNdDescriptor_v3(
- cudnnPoolingDescriptor_t poolingDesc,
- const cudnnPoolingMode_t mode,
- int nbDims,
- const int windowDimA[],
- const int paddingA[],
- const int strideA[] );
-
-cudnnStatus_t cudnnSetPoolingNdDescriptor_v4(
- cudnnPoolingDescriptor_t poolingDesc,
- const cudnnPoolingMode_t mode,
- const cudnnNanPropagation_t maxpoolingNanOpt,
- int nbDims,
- const int windowDimA[],
- const int paddingA[],
- const int strideA[] );
-
-cudnnStatus_t cudnnGetPoolingNdDescriptor_v3(
- const cudnnPoolingDescriptor_t poolingDesc,
- const int nbDimsRequested,
- cudnnPoolingMode_t *mode,
- int *nbDims,
- int windowDimA[],
- int paddingA[],
- int strideA[] );
-
-cudnnStatus_t cudnnGetPoolingNdDescriptor_v4(
- const cudnnPoolingDescriptor_t poolingDesc,
- int nbDimsRequested,
- cudnnPoolingMode_t *mode,
- cudnnNanPropagation_t *maxpoolingNanOpt,
- int *nbDims,
- int windowDimA[],
- int paddingA[],
- int strideA[] );
-
-cudnnStatus_t cudnnActivationForward_v3(
- cudnnHandle_t handle,
- cudnnActivationMode_t mode,
- const void *alpha,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t yDesc,
- void *y );
-
-cudnnStatus_t cudnnActivationForward_v4(
- cudnnHandle_t handle,
- cudnnActivationDescriptor_t activationDesc,
- const void *alpha,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t yDesc,
- void *y );
+cudnnStatus_t cudnnSetConvolution2dDescriptor_v4(
+ cudnnConvolutionDescriptor_t convDesc,
+ int pad_h, // zero-padding height
+ int pad_w, // zero-padding width
+ int u, // vertical filter stride
+ int v, // horizontal filter stride
+ int dilation_h, // filter dilation in the vertical dimension
+ int dilation_w, // filter dilation in the horizontal dimension
+ cudnnConvolutionMode_t mode );
-cudnnStatus_t cudnnActivationBackward_v3(
- cudnnHandle_t handle,
- cudnnActivationMode_t mode,
- const void *alpha,
- const cudnnTensorDescriptor_t yDesc,
- const void *y,
- const cudnnTensorDescriptor_t dyDesc,
- const void *dy,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t dxDesc,
- void *dx );
+cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc,
+ int pad_h, // zero-padding height
+ int pad_w, // zero-padding width
+ int u, // vertical filter stride
+ int v, // horizontal filter stride
+ int dilation_h, // filter dilation in the vertical dimension
+ int dilation_w, // filter dilation in the horizontal dimension
+ cudnnConvolutionMode_t mode,
+ cudnnDataType_t computeType
+ );
-cudnnStatus_t cudnnActivationBackward_v4(
- cudnnHandle_t handle,
- cudnnActivationDescriptor_t activationDesc,
- const void *alpha,
- const cudnnTensorDescriptor_t yDesc,
- const void *y,
- const cudnnTensorDescriptor_t dyDesc,
- const void *dy,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t dxDesc,
- void *dx );
+cudnnStatus_t cudnnGetConvolution2dDescriptor_v4(
+ const cudnnConvolutionDescriptor_t convDesc,
+ int *pad_h, // zero-padding height
+ int *pad_w, // zero-padding width
+ int *u, // vertical filter stride
+ int *v, // horizontal filter stride
+ int *dilation_h, // filter dilation in the vertical dimension
+ int *dilation_w, // filter dilation in the horizontal dimension
+ cudnnConvolutionMode_t *mode );
+cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc,
+ int* pad_h, // zero-padding height
+ int* pad_w, // zero-padding width
+ int* u, // vertical filter stride
+ int* v, // horizontal filter stride
+ int* dilation_h, // filter dilation in the vertical dimension
+ int* dilation_w, // filter dilation in the horizontal dimension
+ cudnnConvolutionMode_t* mode,
+ cudnnDataType_t *computeType
+ );
]]
local CUDNN_PATH = os.getenv('CUDNN_PATH')
@@ -1591,8 +1617,7 @@ if CUDNN_PATH then
print('Found Environment variable CUDNN_PATH = ' .. CUDNN_PATH)
cudnn.C = ffi.load(CUDNN_PATH)
else
-
- local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib', 'cudnn64_5.dll'}
+ local libnames = {'libcudnn.so.6', 'libcudnn.6.dylib', 'cudnn64_6.dll'}
local ok = false
for i=1,#libnames do
ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end)
@@ -1600,22 +1625,22 @@ else
end
if not ok then
- error([['libcudnn (R5) not found in library path.
+ error([['libcudnn (R6\) not found in library path.
Please install CuDNN from https://developer.nvidia.com/cuDNN
-Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in
+Then make sure files named as libcudnn.so.6 or libcudnn.6.dylib are placed in
your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
-Alternatively, set the path to libcudnn.so.5 or libcudnn.5.dylib
+Alternatively, set the path to libcudnn.so.6 or libcudnn.6.dylib
to the environment variable CUDNN_PATH and rerun torch.
-For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.5"
+For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.6"
]])
end
end
-- check cuDNN version
cudnn.version = tonumber(cudnn.C.cudnnGetVersion())
-if cudnn.version < 5005 or cudnn.version >= 6000 then
- error('These bindings are for CUDNN 5.x (5005 <= cudnn.version > 6000) , '
+if cudnn.version < 6002 then
+ error('These bindings are for version 6002 or above, '
.. 'while the loaded CuDNN is version: ' .. cudnn.version
.. ' \nAre you using an older or newer version of CuDNN?')
end
diff --git a/init.lua b/init.lua
index 246583b..13a77fd 100644
--- a/init.lua
+++ b/init.lua
@@ -210,7 +210,7 @@ end
function cudnn.setConvolutionDescriptor(data, desc)
if not data.arrayLength then data.arrayLength = #data.padA end
- if not data.upscaleA then data.upscaleA = torch.IntStorage(data.arrayLength):fill(1) end
+ if not data.dilationA then data.dilationA = {1,1,1 } end -- assume maximum length==3
if not data.mode then data.mode = 'CUDNN_CROSS_CORRELATION' end
local myDesc = desc or cudnn.createDescriptors(
@@ -219,7 +219,7 @@ function cudnn.setConvolutionDescriptor(data, desc)
-- make sure we have references to these tensors so gc doesn't clean them up
local padATensor = torch.IntTensor(data.padA)
local filterStrideATensor = torch.IntTensor(data.filterStrideA)
- local upscaleATensor = torch.IntTensor(data.upscaleA)
+ local upscaleATensor = torch.IntTensor(data.dilationA)
errcheck('cudnnSetConvolutionNdDescriptor', myDesc[0],
data.arrayLength,
padATensor:data(),
@@ -328,6 +328,8 @@ cudnn.find = require('cudnn.find')
require('cudnn.SpatialConvolution')
require('cudnn.VolumetricConvolution')
+require('cudnn.SpatialDilatedConvolution')
+require('cudnn.VolumetricDilatedConvolution')
require('cudnn.SpatialFullConvolution')
require('cudnn.VolumetricFullConvolution')
require('cudnn.Pooling')
diff --git a/test/test.lua b/test/test.lua
index d958041..ffe6cc3 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -13,7 +13,7 @@ local testparams_half = {
precision_forward = 2e-1,
precision_backward = 10,
precision_jac = 1e-3,
- precision_io = 1e-1,
+ precision_io = 3e-1,
}
local testparams_float = {
@@ -146,6 +146,84 @@ function cudnntest.SpatialConvolution()
testLayer(sconv, gconv, input, gradOutput, scale, true, false)
end
+function cudnntest.SpatialDilatedConvolution()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local to = math.random(1,64)
+ local ki = math.random(1,7)
+ local kj = math.random(1,7)
+ local di = math.random(1,7)
+ local dj = math.random(1,7)
+ local wi = (ki-1)*di+1
+ local wj = (kj-1)*dj+1
+ local si = math.random(1, wi )
+ local sj = math.random(1, wj)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+wi
+ local inj = (outj-1)*sj+wj
+
+ local scale = math.random()
+
+ local input = torch.randn(bs,from,inj,ini)
+ local gradOutput = torch.randn(bs,to,outj,outi)
+ local sconv = nn.SpatialDilatedConvolution(from,to,ki,kj,si,sj,0,0,di,dj)
+ local gconv = cast(cudnn.SpatialDilatedConvolution(from,to,ki,kj,si,sj,0,0,di,dj)):fastest()
+ gconv.weight:copy(sconv.weight)
+ gconv.bias:copy(sconv.bias)
+
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true) -- batch
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false) -- non-batch
+ local originalTypename = torch.typename(gconv)
+ local gconv = cast(cudnn.convert(sconv, cudnn))
+ mytester:asserteq(torch.typename(gconv),
+ originalTypename, 'conversion type check')
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true)
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false)
+end
+
+function cudnntest.VolumetricDilatedConvolution()
+ local bs = math.random(1,32)
+ local from = math.random(1,16)
+ local to = math.random(1,16)
+ local ki = math.random(3,5)
+ local kj = math.random(3,5)
+ local kk = math.random(3,5)
+ local di = math.random(1,5)
+ local dj = math.random(1,5)
+ local dk = math.random(1,5)
+ local wi = (ki-1)*di+1
+ local wj = (kj-1)*dj+1
+ local wk = (kk-1)*dk+1
+ local si = math.random(1, wi )
+ local sj = math.random(1, wj)
+ local sk = math.random(1, wk)
+ local outi = math.random(1,17)
+ local outj = math.random(1,17)
+ local outk = math.random(1,5)
+ local ini = (outi-1)*si+wi
+ local inj = (outj-1)*sj+wj
+ local ink = (outk-1)*sk+wk
+
+ local scale = math.random()
+
+ local input = torch.randn(bs,from,ini,ink,inj)
+ local gradOutput = torch.randn(bs,to,outi,outk,outj)
+ local sconv = nn.VolumetricDilatedConvolution(from,to,ki,kj,kk,si,sj,sk,0,0,0,di,dj,dk)
+ local gconv = cast(cudnn.VolumetricDilatedConvolution(from,to,ki,kj,kk,si,sj,sk,0,0,0,di,dj,dk))
+ gconv.weight:copy(sconv.weight)
+ gconv.bias:copy(sconv.bias)
+
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true) -- batch
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false) -- non-batch
+ local originalTypename = torch.typename(gconv)
+ local gconv = cast(cudnn.convert(sconv, cudnn))
+ mytester:asserteq(torch.typename(gconv),
+ originalTypename, 'conversion type check')
+ testLayer(sconv, gconv, input, gradOutput, scale, true, true)
+ testLayer(sconv, gconv, input, gradOutput, scale, true, false)
+end
+
function cudnntest.SpatialFullConvolution()
local bs = math.random(1,32)
local from = math.random(1,32)
@@ -813,7 +891,7 @@ function cudnntest.functional_bias2D()
local inj = (outj-1)*sj+kj
local scale = torch.uniform()
local input = cast(torch.zeros(bs,from,inj,ini))
- local mod = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj))
+ local mod = cast(cudnn.SpatialConvolution(from,to,ki,kj,si,sj):fastest())
mod.weight:zero()
local groundtruth = mod:forward(input)
local result = groundtruth:clone():zero()
@@ -943,8 +1021,8 @@ mytester:add(cudnntest)
cudnn.verbose=false
cudnn.find.verbose=false
--- this is the default, keep it for demo of 16->32 bit float fallback
-cudnn.find.verboseFallback=true
+-- todo: put it back for release to demo 16->32 bit float fallback
+cudnn.find.verboseFallback=false
cudnn.useFindEx=false
cudnn.configureMath({ ['torch.CudaHalfTensor'] = 'CUDNN_DATA_FLOAT'} )
for i = 1, 1 do -- cutorch.getDeviceCount() do
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index 0d0b37b..63520b6 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -12,6 +12,9 @@ local cudnntest = torch.TestSuite()
local mytester
local tolerance = 1000
+local btol = tolerance * 20
+local stol = tolerance/100
+local tinytol = .1
function cudnntest.testRNNRELU()
local miniBatch = 64
@@ -26,9 +29,9 @@ function cudnntest.testRNNRELU()
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 1.315793E+06, tolerance, 'checkSum with reference for localsumi failed')
mytester:assertalmosteq(checkSums.localSumh, 1.315212E+05, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, tolerance, 'checkSum with reference for localSumdh failed')
- mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, tolerance, 'checkSum with reference for localSumdw failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, stol, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, stol, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, btol, 'checkSum with reference for localSumdw failed')
end
function cudnntest.testRNNBatchFirst()
@@ -45,9 +48,9 @@ function cudnntest.testRNNBatchFirst()
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 1.315793E+06, tolerance, 'checkSum with reference for localsumi failed')
mytester:assertalmosteq(checkSums.localSumh, 1.315212E+05, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, tolerance, 'checkSum with reference for localSumdh failed')
- mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, tolerance, 'checkSum with reference for localSumdw failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 6.676003E+01, stol, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 6.425067E+01, stol, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumdw, 1.453750E+09, btol, 'checkSum with reference for localSumdw failed')
end
function cudnntest.testRNNTANH()
@@ -63,8 +66,8 @@ function cudnntest.testRNNTANH()
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 6.319591E+05, tolerance, 'checkSum with reference for localsumi failed')
mytester:assertalmosteq(checkSums.localSumh, 6.319605E+04, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 4.501830E+00, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdh, 4.489546E+00, tolerance, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 4.501830E+00, 1, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 4.489546E+00, 1, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 5.012598E+07, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -81,9 +84,9 @@ function cudnntest.testRNNLSTM()
mytester:assertalmosteq(checkSums.localSumi, 5.749536E+05, tolerance, 'checkSum with reference for localsumi failed')
mytester:assertalmosteq(checkSums.localSumc, 4.365091E+05, tolerance, 'checkSum with reference for localSumc failed')
mytester:assertalmosteq(checkSums.localSumh, 5.774818E+04, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 3.842206E+02, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdc, 9.323785E+03, tolerance, 'checkSum with reference for localSumdc failed')
- mytester:assertalmosteq(checkSums.localSumdh, 1.182566E+01, tolerance, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 3.842206E+02, stol, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdc, 9.323785E+03, stol, 'checkSum with reference for localSumdc failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 1.182566E+01, tinytol, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 4.313461E+08, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -98,7 +101,7 @@ function cudnntest.testRNNGRU()
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 6.358978E+05, tolerance, 'checkSum with reference for localsumi failed')
mytester:assertalmosteq(checkSums.localSumh, 6.281680E+04, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 6.296622E+00, tolerance, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 6.296622E+00, tinytol, 'checkSum with reference for localSumdi failed')
mytester:assertalmosteq(checkSums.localSumdh, 2.289960E+05, tolerance, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 5.397419E+07, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -118,10 +121,10 @@ function cudnntest.testBiDirectionalRELURNN()
local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections)
-- Checksums to check against are retrieved from cudnn RNN sample.
- mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tolerance, 'checkSum with reference for localsumi failed')
- mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tolerance, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tinytol, 'checkSum with reference for localsumi failed')
+ mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tinytol, 'checkSum with reference for localSumh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tinytol, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tinytol, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 7.061081E+07, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -140,10 +143,10 @@ function cudnntest.testBiDirectionalTANHRNN()
local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections)
-- Checksums to check against are retrieved from cudnn RNN sample.
- mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tolerance, 'checkSum with reference for localsumi failed')
- mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tolerance, 'checkSum with reference for localSumdi failed')
- mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tolerance, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumi, 1.388634E+01, tinytol, 'checkSum with reference for localsumi failed')
+ mytester:assertalmosteq(checkSums.localSumh, 1.288997E+01, tinytol, 'checkSum with reference for localSumh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 1.288729E+01, tinytol, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 1.279004E+01, tinytol, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 7.061081E+07, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -160,11 +163,11 @@ function cudnntest.testBiDirectionalLSTMRNN()
local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections)
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 3.134097E+04, tolerance, 'checkSum with reference for localsumi failed')
- mytester:assertalmosteq(checkSums.localSumc, 3.845626E+00, tolerance, 'checkSum with reference for localSumc failed')
- mytester:assertalmosteq(checkSums.localSumh, 1.922855E+00, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 4.794993E+00, tolerance, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumc, 3.845626E+00, tinytol, 'checkSum with reference for localSumc failed')
+ mytester:assertalmosteq(checkSums.localSumh, 1.922855E+00, tinytol, 'checkSum with reference for localSumh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 4.794993E+00, tinytol, 'checkSum with reference for localSumdi failed')
mytester:assertalmosteq(checkSums.localSumdc, 2.870925E+04, tolerance, 'checkSum with reference for localSumdc failed')
- mytester:assertalmosteq(checkSums.localSumdh, 2.468645E+00, tolerance, 'checkSum with reference for localSumdh failed')
+ mytester:assertalmosteq(checkSums.localSumdh, 2.468645E+00, tinytol, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 1.121568E+08, tolerance, 'checkSum with reference for localSumdw failed')
end
@@ -184,8 +187,8 @@ function cudnntest.testBiDirectionalGRURNN()
local checkSums = getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numberOfLinearLayers, rnn, batchFirst, nbDirections)
-- Checksums to check against are retrieved from cudnn RNN sample.
mytester:assertalmosteq(checkSums.localSumi, 6.555183E+04, tolerance, 'checkSum with reference for localsumi failed')
- mytester:assertalmosteq(checkSums.localSumh, 5.830924E+00, tolerance, 'checkSum with reference for localSumh failed')
- mytester:assertalmosteq(checkSums.localSumdi, 4.271801E+00, tolerance, 'checkSum with reference for localSumdi failed')
+ mytester:assertalmosteq(checkSums.localSumh, 5.830924E+00, tinytol, 'checkSum with reference for localSumh failed')
+ mytester:assertalmosteq(checkSums.localSumdi, 4.271801E+00, tinytol, 'checkSum with reference for localSumdi failed')
mytester:assertalmosteq(checkSums.localSumdh, 6.555744E+04, tolerance, 'checkSum with reference for localSumdh failed')
mytester:assertalmosteq(checkSums.localSumdw, 1.701796E+08, tolerance, 'checkSum with reference for localSumdw failed')
end