github.com/soumith/cudnn.torch.git
author     Soumith Chintala <soumith@gmail.com>  2017-08-10 08:36:10 +0300
committer  GitHub <noreply@github.com>           2017-08-10 08:36:10 +0300
commit     f53fc2b365421155c102066db60674fb9b802ec2 (patch)
tree       4599ce598e2f5585a7a97e3824a4960810e94d32
parent     b76500ec0291172bb95a08a03f4bdc835711ea6c (diff)
parent     84811a509c2bef04909d1b36cfe1375aaef3181f (diff)

Merge pull request #380 from ngimel/R7

R7
-rw-r--r--  CMakeLists.txt             |  13
-rw-r--r--  SpatialConvolution.lua     |  65
-rw-r--r--  SpatialFullConvolution.lua |  30
-rw-r--r--  convert.lua                |   3
-rw-r--r--  ffi.lua                    | 495
-rw-r--r--  init.lua                   |  12
-rw-r--r--  test/bench_groups.lua      |   2
-rw-r--r--  test/test.lua              |  32
-rw-r--r--  test/test_rnn.lua          |   6

9 files changed, 413 insertions(+), 245 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1a2e4e3..231fd48 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,15 +13,12 @@ ENDIF()
FIND_PACKAGE(Torch REQUIRED)
FIND_PACKAGE(CUDA 8.0 REQUIRED)
-# Set to TRUE if you want automatic install of CUDNN
-SET(CUDNN_AUTO_INSTALL )
+#FIND_PACKAGE(CUDNN 6 EXACT QUIET)
+#IF(NOT CUDNN_FOUND)
+# CUDNN_INSTALL(6.0-rc "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "")
+# FIND_PACKAGE(CUDNN 6 EXACT REQUIRED)
+#ENDIF()
-FIND_PACKAGE(CUDNN 6.0 EXACT)
-
-IF(CUDNN_AUTO_INSTALL AND NOT CUDNN_FOUND)
- CUDNN_INSTALL(6.0 "${Torch_INSTALL_LIB}" "${Torch_INSTALL_INCLUDE}" "")
- FIND_PACKAGE(CUDNN 6.0 EXACT REQUIRED)
-ENDIF()
FILE(GLOB luasrc *.lua)
SET(src "")
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 7f9dd23..23e92b9 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -27,8 +27,6 @@ end
-- if you change the configuration of the module manually, call this
function SpatialConvolution:resetWeightDescriptors(desc)
- -- for compatibility
- self.groups = self.groups or 1
assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!')
assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!')
@@ -40,7 +38,7 @@ function SpatialConvolution:resetWeightDescriptors(desc)
self.weightDesc = cudnn.setFilterDescriptor(
{ dataType = cudnn.typemap[torch.typename(self.weight)],
filterDimA = desc or
- {self.nOutputPlane/self.groups,
+ {self.nOutputPlane,
self.nInputPlane/self.groups,
self.kH, self.kW}
}
@@ -92,7 +90,6 @@ function SpatialConvolution:checkInputChanged(input)
if not self.iSize or self.iSize:size() ~= input:dim() then
self.iSize = torch.LongStorage(input:dim()):fill(0)
end
- self.groups = self.groups or 1
if not self.weightDesc then self:resetWeightDescriptors() end
if not self.weightDesc then error "Weights not assigned!" end
@@ -119,18 +116,20 @@ function SpatialConvolution:createIODescriptors(input)
end
if SpatialConvolution.checkInputChanged(self, input) then
-- create input descriptor
- local input_slice = input:narrow(2,1,self.nInputPlane/self.groups)
+ local input_slice = input:narrow(2,1,self.nInputPlane)
self.iDesc = cudnn.toDescriptor(input_slice)
-- create conv descriptor
self.padH, self.padW = self.padH or 0, self.padW or 0
-- those needed to calculate hash
self.pad = {self.padH, self.padW}
self.stride = {self.dH, self.dW}
-
+
self.convDescData = { padA = self.pad,
filterStrideA = self.stride,
dilationA = {1,1},
- dataType = cudnn.configmap(torch.type(self.weight))
+ dataType = cudnn.configmap(torch.type(self.weight)),
+ mathType = 'CUDNN_DEFAULT_MATH',
+ groupCount = self.groups
}
self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData)
@@ -140,24 +139,14 @@ function SpatialConvolution:createIODescriptors(input)
errcheck('cudnnGetConvolutionNdForwardOutputDim',
self.convDesc[0], self.iDesc[0],
self.weightDesc[0], 4, oSize:data())
- oSize[2] = oSize[2] * self.groups
+
self.output:resize(oSize:long():storage())
self.oSize = self.output:size()
-
- local output_slice = self.output:narrow(2,1,self.nOutputPlane/self.groups)
+ local output_slice = self.output:narrow(2,1,self.nOutputPlane)
-- create descriptor for output
self.oDesc = cudnn.toDescriptor(output_slice)
self.oDescForBias = cudnn.toDescriptor(self.output)
-
- find:prepare(self, input_slice, output_slice)
-
- -- create offsets for groups
- local iH, iW = input:size(3), input:size(4)
- local kH, kW = self.kH, self.kW
- local oH, oW = oSize[3], oSize[4]
- self.input_offset = self.nInputPlane / self.groups * iH * iW
- self.output_offset = self.nOutputPlane / self.groups * oH * oW
- self.weight_offset = self.nInputPlane / self.groups * self.nOutputPlane / self.groups * kH * kW
+ find:prepare(self, input_slice, output_slice)
if not batch then
self.output = self.output:view(self.output:size(2),
@@ -189,19 +178,17 @@ function SpatialConvolution:updateOutput(input)
local finder = find.get()
local fwdAlgo = finder:forwardAlgorithm(self, { self.iDesc[0], self.input_slice, self.weightDesc[0],
self.weight, self.convDesc[0], self.oDesc[0], self.output_slice})
+
local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace()
- for g = 0, self.groups - 1 do
- checkedCall(self,'cudnnConvolutionForward', cudnn.getHandle(),
+ checkedCall(self,'cudnnConvolutionForward', cudnn.getHandle(),
cudnn.scalar(input, 1),
- self.iDesc[0], input:data() + g*self.input_offset,
- self.weightDesc[0], self.weight:data() + g*self.weight_offset,
+ self.iDesc[0], input:data(),
+ self.weightDesc[0], self.weight:data(),
self.convDesc[0], fwdAlgo,
extraBuffer, extraBufferSize,
cudnn.scalar(input, 0),
- self.oDesc[0], self.output:data() + g*self.output_offset);
- end
-
- -- add bias
+ self.oDesc[0], self.output:data());
+ -- add bias
if self.bias then
errcheck('cudnnAddTensor', cudnn.getHandle(),
cudnn.scalar(input, 1), self.biasDesc[0], self.bias:data(),
@@ -222,17 +209,15 @@ function SpatialConvolution:updateGradInput(input, gradOutput)
local bwdDataAlgo = finder:backwardDataAlgorithm(self, { self.weightDesc[0], self.weight, self.oDesc[0],
self.output_slice, self.convDesc[0], self.iDesc[0], self.input_slice })
local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace()
- for g = 0,self.groups - 1 do
- checkedCall(self,'cudnnConvolutionBackwardData', cudnn.getHandle(),
+ checkedCall(self,'cudnnConvolutionBackwardData', cudnn.getHandle(),
cudnn.scalar(input, 1),
- self.weightDesc[0], self.weight:data() + g*self.weight_offset,
- self.oDesc[0], gradOutput:data() + g*self.output_offset,
+ self.weightDesc[0], self.weight:data(),
+ self.oDesc[0], gradOutput:data(),
self.convDesc[0],
bwdDataAlgo,
extraBuffer, extraBufferSize,
cudnn.scalar(input, 0),
- self.iDesc[0], self.gradInput:data() + g*self.input_offset)
- end
+ self.iDesc[0], self.gradInput:data())
return self.gradInput
end
@@ -259,18 +244,16 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
end
local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace()
- for g = 0, self.groups - 1 do
- -- gradWeight
- checkedCall(self,'cudnnConvolutionBackwardFilter', cudnn.getHandle(),
+ -- gradWeight
+ checkedCall(self,'cudnnConvolutionBackwardFilter', cudnn.getHandle(),
self.scaleT:data(),
- self.iDesc[0], input:data() + g*self.input_offset,
- self.oDesc[0], gradOutput:data() + g*self.output_offset,
+ self.iDesc[0], input:data(),
+ self.oDesc[0], gradOutput:data(),
self.convDesc[0],
bwdFilterAlgo,
extraBuffer, extraBufferSize,
cudnn.scalar(input, 1),
- self.weightDesc[0], self.gradWeight:data() + g*self.weight_offset);
- end
+ self.weightDesc[0], self.gradWeight:data());
return self.gradOutput
end
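
With cuDNN R7 handling groups natively, the per-group loop and the manual
input/output/weight pointer offsets are gone from all three passes; a single
cudnnConvolutionForward call now covers every group through the descriptor's
groupCount. A minimal usage sketch (shapes are illustrative; as before, the
trailing constructor argument is the group count):

require 'cudnn'

-- 64 -> 128 planes split into 4 groups; both plane counts must be
-- divisible by the group count
local conv = cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 1, 1, 4):cuda()
local input = torch.randn(8, 64, 32, 32):cuda()
local output = conv:forward(input)  -- 8 x 128 x 32 x 32
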
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua
index 7c9520e..9093ce4 100644
--- a/SpatialFullConvolution.lua
+++ b/SpatialFullConvolution.lua
@@ -7,9 +7,28 @@ local checkedCall = find.checkedCall
local Convolution = cudnn.SpatialConvolution
+function SpatialFullConvolution:__init(nInputPlane, nOutputPlane,
+ kW, kH, dW, dH, padW, padH, adjW, adjH, groups)
+ local delayedReset = self.reset
+ self.reset = function() end
+ parent.__init(self, nInputPlane, nOutputPlane,
+ kW, kH, dW, dH, padW, padH, adjW, adjH)
+ self.reset = delayedReset
+ self.groups = groups or 1
+ assert(nInputPlane % self.groups == 0,
+ 'nInputPlane should be divisible by nGroups')
+ assert(nOutputPlane % self.groups == 0,
+ 'nOutputPlane should be divisible by nGroups')
+ self.weight = torch.Tensor(nInputPlane, nOutputPlane/self.groups, kH, kW)
+ self.gradWeight = torch.Tensor(nInputPlane, nOutputPlane/self.groups, kH, kW)
+ self:reset()
+ -- should be nil for serialization; the reset will still work
+ self.reset = nil
+end
+
function SpatialFullConvolution:resetWeightDescriptors()
return Convolution.resetWeightDescriptors(self, {self.nInputPlane,
- self.nOutputPlane,
+ self.nOutputPlane/self.groups,
self.kH, self.kW})
end
@@ -47,11 +66,12 @@ function SpatialFullConvolution:createIODescriptors(input)
self.pad = {self.padH, self.padW}
self.stride = {self.dH, self.dW}
- self.convDescData = { padA = self.pad,
+ self.convDesc = cudnn.setConvolutionDescriptor({
+ padA = self.pad,
filterStrideA = self.stride,
- dataType = cudnn.configmap(torch.type(self.weight))
- }
- self.convDesc = cudnn.setConvolutionDescriptor(self.convDescData)
+ dataType = cudnn.configmap(torch.type(self.weight)),
+ groupCount = self.groups
+ })
-- get output shape, resize output
local iwidth = input:size(4)
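
The new __init mirrors cudnn.SpatialConvolution but takes groups as the
eleventh argument, after adjW/adjH, and lays the weight out as
(nInputPlane, nOutputPlane/groups, kH, kW). A hedged construction sketch
with illustrative shapes:

-- grouped transposed convolution, groups = 4, adjW/adjH left at 0
local deconv = cudnn.SpatialFullConvolution(64, 64, 4, 4, 2, 2, 1, 1, 0, 0, 4)
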
diff --git a/convert.lua b/convert.lua
index 7b4de1c..d299141 100644
--- a/convert.lua
+++ b/convert.lua
@@ -44,6 +44,9 @@ function cudnn.convert(net, dst, exclusion_fn)
y.divide = true
y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
end
+ if src == nn and string.find(v, 'Convolution') then
+ y.groups = 1
+ end
return y
end
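
This branch gives modules converted from plain nn an explicit groups field,
since nn convolutions carry none and the descriptor code now always sets
groupCount. A sketch of the path it covers:

require 'cudnn'

local net = nn.Sequential():add(nn.SpatialConvolution(3, 16, 3, 3))
cudnn.convert(net, cudnn)
print(net:get(1).groups)  -- 1
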
diff --git a/ffi.lua b/ffi.lua
index f72baea..788e2b0 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -11,9 +11,9 @@ typedef enum
} libraryPropertyType;
typedef enum {
- CUDNN_MAJOR = 6,
+ CUDNN_MAJOR = 7,
CUDNN_MINOR = 0,
- CUDNN_PATCHLEVEL = 2,
+ CUDNN_PATCHLEVEL = 15,
CUDNN_VERSION = (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
} cudnnVerFakeEnum;
@@ -62,6 +62,8 @@ typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t;
typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t;
typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t;
typedef struct cudnnReduceTensorStruct* cudnnReduceTensorDescriptor_t;
+typedef struct cudnnCTCLossStruct* cudnnCTCLossDescriptor_t;
+
/*
* CUDNN data type
*/
@@ -76,9 +78,17 @@ typedef enum
} cudnnDataType_t;
/*
+* CUDNN math type
+*/
+typedef enum {
+ CUDNN_DEFAULT_MATH = 0,
+ CUDNN_TENSOR_OP_MATH = 1,
+} cudnnMathType_t;
+
+/*
* CUDNN propagate Nan
*/
-typedef enum {
+typedef enum{
CUDNN_NOT_PROPAGATE_NAN = 0,
CUDNN_PROPAGATE_NAN = 1,
} cudnnNanPropagation_t;
@@ -225,6 +235,7 @@ typedef enum
CUDNN_OP_TENSOR_MIN = 2,
CUDNN_OP_TENSOR_MAX = 3,
CUDNN_OP_TENSOR_SQRT = 4,
+ CUDNN_OP_TENSOR_NOT = 5,
} cudnnOpTensorOp_t;
cudnnStatus_t cudnnCreateOpTensorDescriptor(
@@ -245,7 +256,8 @@ cudnnStatus_t cudnnGetOpTensorDescriptor(
cudnnStatus_t cudnnDestroyOpTensorDescriptor(
cudnnOpTensorDescriptor_t opTensorDesc );
-/* Tensor Bias operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
+/* Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
+/* B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT. */
cudnnStatus_t cudnnOpTensor(
cudnnHandle_t handle,
const cudnnOpTensorDescriptor_t opTensorDesc,
@@ -264,14 +276,15 @@ cudnnStatus_t cudnnOpTensor(
*/
typedef enum
{
- CUDNN_REDUCE_TENSOR_ADD = 0,
- CUDNN_REDUCE_TENSOR_MUL = 1,
- CUDNN_REDUCE_TENSOR_MIN = 2,
- CUDNN_REDUCE_TENSOR_MAX = 3,
- CUDNN_REDUCE_TENSOR_AMAX = 4,
- CUDNN_REDUCE_TENSOR_AVG = 5,
- CUDNN_REDUCE_TENSOR_NORM1 = 6,
- CUDNN_REDUCE_TENSOR_NORM2 = 7,
+ CUDNN_REDUCE_TENSOR_ADD = 0,
+ CUDNN_REDUCE_TENSOR_MUL = 1,
+ CUDNN_REDUCE_TENSOR_MIN = 2,
+ CUDNN_REDUCE_TENSOR_MAX = 3,
+ CUDNN_REDUCE_TENSOR_AMAX = 4,
+ CUDNN_REDUCE_TENSOR_AVG = 5,
+ CUDNN_REDUCE_TENSOR_NORM1 = 6,
+ CUDNN_REDUCE_TENSOR_NORM2 = 7,
+ CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
} cudnnReduceTensorOp_t;
/*
@@ -422,13 +435,25 @@ cudnnStatus_t cudnnDestroyFilterDescriptor(
cudnnStatus_t cudnnCreateConvolutionDescriptor(
cudnnConvolutionDescriptor_t *convDesc );
-cudnnStatus_t cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc,
- int pad_h, // zero-padding height
- int pad_w, // zero-padding width
- int u, // vertical filter stride
- int v, // horizontal filter stride
- int dilation_h, // filter dilation in the vertical dimension
- int dilation_w, // filter dilation in the horizontal dimension
+cudnnStatus_t cudnnSetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc,
+ cudnnMathType_t mathType );
+
+cudnnStatus_t cudnnGetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc,
+ cudnnMathType_t *mathType );
+
+cudnnStatus_t cudnnSetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc,
+ int groupCount );
+
+cudnnStatus_t cudnnGetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc,
+ int *groupCount );
+
+cudnnStatus_t cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc,
+ int pad_h, /* zero-padding height */
+ int pad_w, /* zero-padding width */
+ int u, /* vertical filter stride */
+ int v, /* horizontal filter stride */
+ int dilation_h, /* filter dilation in the vertical dimension */
+ int dilation_w, /* filter dilation in the horizontal dimension */
cudnnConvolutionMode_t mode,
cudnnDataType_t computeType
);
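
These setters and getters are the C-level hooks behind the Lua changes in
this commit: math type (the Tensor Core opt-in on Volta) and group count now
live on the convolution descriptor. A hedged sketch of calling them through
the bindings, where desc stands in for a descriptor created with
cudnn.createDescriptors and errcheck is cudnn.errcheck; LuaJIT's ffi accepts
the enum names as strings, which is how init.lua below passes them too:

errcheck('cudnnSetConvolutionMathType', desc[0], 'CUDNN_TENSOR_OP_MATH')
errcheck('cudnnSetConvolutionGroupCount', desc[0], 2)
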
@@ -516,9 +541,13 @@ typedef struct {
float time;
size_t memory;
cudnnDeterminism_t determinism;
- int reserved[4];
+ cudnnMathType_t mathType;
+ int reserved[3];
} cudnnConvolutionFwdAlgoPerf_t;
+cudnnStatus_t cudnnGetConvolutionForwardAlgorithmMaxCount( cudnnHandle_t handle,
+ int *count);
+
cudnnStatus_t cudnnFindConvolutionForwardAlgorithm(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t xDesc,
@@ -555,6 +584,17 @@ cudnnStatus_t cudnnGetConvolutionForwardAlgorithm(
size_t memoryLimitInBytes,
cudnnConvolutionFwdAlgo_t *algo );
+
+cudnnStatus_t cudnnGetConvolutionForwardAlgorithm_v7(
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t srcDesc,
+ const cudnnFilterDescriptor_t filterDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t destDesc,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
+
/*
* convolution algorithm (which requires potentially some workspace)
*/
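
The _v7 queries return a ranked list of cudnnConvolutionFwdAlgoPerf_t entries
(which now also report a mathType) instead of a single algorithm. A hedged
sketch of the call pattern; iDesc, wDesc, convDesc and oDesc stand in for
descriptors set up as in SpatialConvolution.lua, and errcheck is
cudnn.errcheck:

local ffi = require 'ffi'
local maxCount = ffi.new('int[1]')
errcheck('cudnnGetConvolutionForwardAlgorithmMaxCount', cudnn.getHandle(), maxCount)
local perf = ffi.new('cudnnConvolutionFwdAlgoPerf_t[?]', maxCount[0])
local returned = ffi.new('int[1]')
errcheck('cudnnGetConvolutionForwardAlgorithm_v7', cudnn.getHandle(),
         iDesc[0], wDesc[0], convDesc[0], oDesc[0],
         maxCount[0], returned, perf)
-- perf[0] is the highest-ranked algorithm for these descriptors
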
@@ -630,25 +670,30 @@ typedef enum
typedef enum
{
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic*/
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic, algo0 with workspace*/
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic */
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, /* not implemented */
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING= 6,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING = 6,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7
} cudnnConvolutionBwdFilterAlgo_t;
+
typedef struct {
cudnnConvolutionBwdFilterAlgo_t algo;
- cudnnStatus_t status;
- float time;
- size_t memory;
+ cudnnStatus_t status;
+ float time;
+ size_t memory;
cudnnDeterminism_t determinism;
- int reserved[4];
+ cudnnMathType_t mathType;
+ int reserved[3];
} cudnnConvolutionBwdFilterAlgoPerf_t;
+cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( cudnnHandle_t handle,
+ int *count);
+
cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t xDesc,
@@ -684,6 +729,16 @@ cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm(
size_t memoryLimitInBytes,
cudnnConvolutionBwdFilterAlgo_t *algo );
+cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm_v7(
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t srcDesc,
+ const cudnnTensorDescriptor_t diffDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnFilterDescriptor_t gradDesc,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
+
/*
* convolution algorithm (which requires potentially some workspace)
*/
@@ -712,7 +767,7 @@ cudnnStatus_t cudnnConvolutionBackwardFilter(
const void *beta,
const cudnnFilterDescriptor_t dwDesc,
void *dw );
-
+
/*********************************************************/
/* helper function to provide the convolution algo that fit best the requirement */
typedef enum
@@ -724,13 +779,13 @@ typedef enum
typedef enum
{
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic*/
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6
} cudnnConvolutionBwdDataAlgo_t;
typedef struct {
@@ -739,9 +794,12 @@ typedef struct {
float time;
size_t memory;
cudnnDeterminism_t determinism;
- int reserved[4];
+ cudnnMathType_t mathType;
+ int reserved[3];
} cudnnConvolutionBwdDataAlgoPerf_t;
+cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithmMaxCount( cudnnHandle_t handle,
+ int *count);
cudnnStatus_t cudnnFindConvolutionBackwardDataAlgorithm(
cudnnHandle_t handle,
@@ -778,6 +836,16 @@ cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(
size_t memoryLimitInBytes,
cudnnConvolutionBwdDataAlgo_t *algo );
+cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm_v7(
+ cudnnHandle_t handle,
+ const cudnnFilterDescriptor_t filterDesc,
+ const cudnnTensorDescriptor_t diffDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t gradDesc,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
cudnnStatus_t cudnnGetConvolutionBackwardDataWorkspaceSize(
cudnnHandle_t handle,
@@ -1135,8 +1203,14 @@ typedef enum
/* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice)*/
CUDNN_BATCHNORM_PER_ACTIVATION = 0,
- /*bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors)*/
- CUDNN_BATCHNORM_SPATIAL = 1,
+ /* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
+ CUDNN_BATCHNORM_SPATIAL = 1,
+
+ /*
+ * bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors).
+ * May be faster than CUDNN_BATCHNORM_SPATIAL but imposes some limits on the range of values
+ */
+ CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2,
} cudnnBatchNormMode_t;
/* static const float CUDNN_BN_MIN_EPSILON = 1e-5; */ /* Minimum epsilon allowed to be used in the Batch Normalization formula*/
@@ -1325,33 +1399,47 @@ cudnnStatus_t cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t
/*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
cudnnStatus_t cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t * sizeInBytes);
-cudnnStatus_t cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
- cudnnHandle_t handle,
- float dropout,
- void * states,
- size_t stateSizeInBytes,
- unsigned long long seed);
-
-cudnnStatus_t cudnnDropoutForward(cudnnHandle_t handle,
- const cudnnDropoutDescriptor_t dropoutDesc,
- const cudnnTensorDescriptor_t xdesc,
- const void * x,
- const cudnnTensorDescriptor_t ydesc,
- void * y,
- void * reserveSpace,
- size_t reserveSpaceSizeInBytes);
+cudnnStatus_t cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+ cudnnHandle_t handle,
+ float dropout,
+ void * states,
+ size_t stateSizeInBytes,
+ unsigned long long seed);
+
+// Restores the dropout descriptor to a previously saved-off state
+cudnnStatus_t cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+ cudnnHandle_t handle,
+ float dropout,
+ void * states,
+ size_t stateSizeInBytes,
+ unsigned long long seed);
+
+cudnnStatus_t cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+ cudnnHandle_t handle,
+ float * dropout,
+ void ** states,
+ unsigned long long * seed);
+
+cudnnStatus_t cudnnDropoutForward(cudnnHandle_t handle,
+ const cudnnDropoutDescriptor_t dropoutDesc,
+ const cudnnTensorDescriptor_t xdesc,
+ const void * x,
+ const cudnnTensorDescriptor_t ydesc,
+ void * y,
+ void * reserveSpace,
+ size_t reserveSpaceSizeInBytes);
cudnnStatus_t cudnnDropoutBackward(cudnnHandle_t handle,
const cudnnDropoutDescriptor_t dropoutDesc,
- const cudnnTensorDescriptor_t dydesc,
- const void * dy,
- const cudnnTensorDescriptor_t dxdesc,
- void * dx,
- void * reserveSpace,
- size_t reserveSpaceSizeInBytes);
+ const cudnnTensorDescriptor_t dydesc,
+ const void * dy,
+ const cudnnTensorDescriptor_t dxdesc,
+ void * dx,
+ void * reserveSpace,
+ size_t reserveSpaceSizeInBytes);
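
cudnnGetDropoutDescriptor and cudnnRestoreDropoutDescriptor are new in R7 and
make the dropout RNG state inspectable and restorable (useful for
checkpointing). A hedged read-back sketch, with dropDesc standing in for an
existing dropout descriptor and errcheck for cudnn.errcheck:

local ffi = require 'ffi'
local p = ffi.new('float[1]')
local states = ffi.new('void*[1]')
local seed = ffi.new('unsigned long long[1]')
errcheck('cudnnGetDropoutDescriptor', dropDesc[0], cudnn.getHandle(), p, states, seed)
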
/* RNN API */
-typedef enum
+typedef enum
{
CUDNN_RNN_RELU = 0, /* Stock RNN with ReLu activation*/
CUDNN_RNN_TANH = 1, /* Stock RNN with tanh activation*/
@@ -1368,17 +1456,17 @@ typedef enum
typedef enum
{
CUDNN_LINEAR_INPUT = 0,
- CUDNN_SKIP_INPUT = 1
- } cudnnRNNInputMode_t;
-
-
-typedef enum
+ CUDNN_SKIP_INPUT = 1
+ } cudnnRNNInputMode_t;
+
+
+typedef enum
{
- CUDNN_RNN_ALGO_STANDARD = 0,
+ CUDNN_RNN_ALGO_STANDARD = 0,
CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2
- } cudnnRNNAlgo_t;
-
+ } cudnnRNNAlgo_t;
+
struct cudnnRNNStruct;
typedef struct cudnnRNNStruct* cudnnRNNDescriptor_t;
@@ -1388,91 +1476,83 @@ cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDes
struct cudnnPersistentRNNPlan;
typedef struct cudnnPersistentRNNPlan *cudnnPersistentRNNPlan_t;
-
-// Expensive. Creates the plan for the specific settings.
-cudnnStatus_t cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
- const int minibatch,
- const cudnnDataType_t dataType,
- cudnnPersistentRNNPlan_t * plan);
-
-// Attaches the plan to the descriptor.
-cudnnStatus_t cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
- cudnnPersistentRNNPlan_t plan);
-
-cudnnStatus_t cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan);
-
-
-
-cudnnStatus_t cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
- cudnnRNNDescriptor_t rnnDesc,
- const int hiddenSize,
- const int numLayers,
- cudnnDropoutDescriptor_t dropoutDesc, // Between layers, not between recurrent steps.
- cudnnRNNInputMode_t inputMode,
- cudnnDirectionMode_t direction,
- cudnnRNNMode_t mode,
- cudnnRNNAlgo_t algo,
- cudnnDataType_t dataType);
-
-
-cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc,
- int hiddenSize,
- int numLayers,
- cudnnDropoutDescriptor_t dropoutDesc,
- cudnnRNNInputMode_t inputMode,
- cudnnDirectionMode_t direction,
- cudnnRNNMode_t mode,
- cudnnDataType_t dataType);
-
-
-
-// dataType in the RNN descriptor is used to determine math precision
-// dataType in weight descriptors and input descriptors is used to describe storage
-
+
+/* Expensive. Creates the plan for the specific settings. */
+cudnnStatus_t cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
+ const int minibatch,
+ const cudnnDataType_t dataType,
+ cudnnPersistentRNNPlan_t * plan);
+
+/* Attaches the plan to the descriptor. */
+cudnnStatus_t cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
+ cudnnPersistentRNNPlan_t plan);
+
+cudnnStatus_t cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan);
+
+cudnnStatus_t cudnnSetRNNDescriptor(cudnnHandle_t handle,
+ cudnnRNNDescriptor_t rnnDesc,
+ const int hiddenSize,
+ const int numLayers,
+ cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */
+ cudnnRNNInputMode_t inputMode,
+ cudnnDirectionMode_t direction,
+ cudnnRNNMode_t mode,
+ cudnnRNNAlgo_t algo,
+ cudnnDataType_t dataType);
+
+cudnnStatus_t cudnnGetRNNDescriptor(cudnnHandle_t cudnnHandle,
+ cudnnRNNDescriptor_t rnnDesc,
+ int * hiddenSize,
+ int * numLayers,
+ cudnnDropoutDescriptor_t * dropoutDesc,
+ cudnnRNNInputMode_t * inputMode,
+ cudnnDirectionMode_t * direction,
+ cudnnRNNMode_t * mode,
+ cudnnRNNAlgo_t * algo,
+ cudnnDataType_t * dataType);
+
+cudnnStatus_t cudnnSetRNNMatrixMathType (cudnnRNNDescriptor_t desc, cudnnMathType_t math);
+
+/* dataType in the RNN descriptor is used to determine math precision */
+/* dataType in weight descriptors and input descriptors is used to describe storage */
cudnnStatus_t cudnnGetRNNWorkspaceSize( cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnnDesc,
- const int seqLength,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
- size_t *sizeInBytes
- );
-
+ size_t *sizeInBytes);
+
cudnnStatus_t cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnnDesc,
- const int seqLength,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
- size_t *sizeInBytes
- );
+ size_t *sizeInBytes);
-
-cudnnStatus_t cudnnGetRNNParamsSize( cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnnDesc,
+
+cudnnStatus_t cudnnGetRNNParamsSize( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
const cudnnTensorDescriptor_t xDesc,
- size_t *sizeInBytes,
- cudnnDataType_t dataType
- );
+ size_t *sizeInBytes,
+ cudnnDataType_t dataType);
cudnnStatus_t cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnnDesc,
- const int layer,
- const cudnnTensorDescriptor_t xDesc,
- const cudnnFilterDescriptor_t wDesc,
- const void * w,
- const int linLayerID,
- cudnnFilterDescriptor_t linLayerMatDesc,
- void ** linLayerMat
- );
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int layer,
+ const cudnnTensorDescriptor_t xDesc,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const int linLayerID,
+ cudnnFilterDescriptor_t linLayerMatDesc,
+ void ** linLayerMat);
cudnnStatus_t cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnnDesc,
- const int layer,
- const cudnnTensorDescriptor_t xDesc,
- const cudnnFilterDescriptor_t wDesc,
- const void * w,
- const int linLayerID,
- cudnnFilterDescriptor_t linLayerBiasDesc,
- void ** linLayerBias
- );
-
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int layer,
+ const cudnnTensorDescriptor_t xDesc,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const int linLayerID,
+ cudnnFilterDescriptor_t linLayerBiasDesc,
+ void ** linLayerBias);
cudnnStatus_t cudnnRNNForwardInference( cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
@@ -1494,8 +1574,6 @@ cudnnStatus_t cudnnRNNForwardInference( cudnnHandle_t handle,
void * workspace,
size_t workSpaceSizeInBytes);
-
-
cudnnStatus_t cudnnRNNForwardTraining( cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int seqLength,
@@ -1543,7 +1621,7 @@ cudnnStatus_t cudnnRNNBackwardData( cudnnHandle_t handle,
void * dcx,
void * workspace,
size_t workSpaceSizeInBytes,
- const void * reserveSpace,
+ void * reserveSpace,
size_t reserveSpaceSizeInBytes );
@@ -1554,41 +1632,88 @@ cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle,
const void * x,
const cudnnTensorDescriptor_t hxDesc,
const void * hx,
- const cudnnTensorDescriptor_t * yDesc,
+ const cudnnTensorDescriptor_t * yDesc,
const void * y,
- const void * workspace,
- size_t workSpaceSizeInBytes,
- const cudnnFilterDescriptor_t dwDesc,
+ const void * workspace,
+ size_t workSpaceSizeInBytes,
+ const cudnnFilterDescriptor_t dwDesc,
void * dw,
- const void * reserveSpace,
+ const void * reserveSpace,
size_t reserveSpaceSizeInBytes );
+typedef enum
+{
+ CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0,
+ CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1
+}cudnnCTCLossAlgo_t;
-/* DEPRECATED routines to be removed next release :
- User should use the non-suffixed version (which has the API and functionality of _v5 version)
- Routines with _v4 suffix has the functionality of the non-suffixed routines in the CUDNN V5
+/*
+* Create an instance of a CTC (Connectionist Temporal Classification) loss descriptor
+*/
+cudnnStatus_t cudnnCreateCTCLossDescriptor( cudnnCTCLossDescriptor_t* ctcLossDesc );
+
+cudnnStatus_t cudnnSetCTCLossDescriptor(
+ cudnnCTCLossDescriptor_t ctcLossDesc,
+ cudnnDataType_t compType );
+
+cudnnStatus_t cudnnGetCTCLossDescriptor(
+ cudnnCTCLossDescriptor_t ctcLossDesc,
+ cudnnDataType_t* compType );
+
+cudnnStatus_t cudnnDestroyCTCLossDescriptor( cudnnCTCLossDescriptor_t ctcLossDesc );
+
+/* return the ctc costs and gradients, given the probabilities and labels */
+cudnnStatus_t cudnnCTCLoss( cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */
+ const void * probs, /* probabilities after softmax, in GPU memory */
+ const int * labels, /* labels, in CPU memory */
+ const int * labelLengths, /* the length of each label, in CPU memory */
+ const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
+ void * costs, /* the returned costs of CTC, in GPU memory */
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
+ const void * gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
+ cudnnCTCLossDescriptor_t ctcLossDesc,
+ void * workspace, /* pointer to the workspace, in GPU memory */
+ size_t workSpaceSizeInBytes); /* the workspace size needed */
+
+/* return the workspace size needed for ctc */
+cudnnStatus_t cudnnGetCTCLossWorkspaceSize(
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A. To compute costs only, set it to NULL */
+ const int * labels, /* labels, in CPU memory */
+ const int * labelLengths, /* the length of each label, in CPU memory */
+ const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
+ cudnnCTCLossDescriptor_t ctcLossDesc,
+ size_t *sizeInBytes ); /* pointer to the returned workspace size */
+
+
+/* DEPRECATED routines to be removed next release :
+ User should use the non-suffixed version (which has the API and functionality of _v6 version)
+ Routines with _v5 suffix has the functionality of the non-suffixed routines in the CUDNN V6
*/
-cudnnStatus_t cudnnSetConvolution2dDescriptor_v4(
- cudnnConvolutionDescriptor_t convDesc,
- int pad_h, // zero-padding height
- int pad_w, // zero-padding width
- int u, // vertical filter stride
- int v, // horizontal filter stride
- int dilation_h, // filter dilation in the vertical dimension
- int dilation_w, // filter dilation in the horizontal dimension
- cudnnConvolutionMode_t mode );
-
-cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc,
- int pad_h, // zero-padding height
- int pad_w, // zero-padding width
- int u, // vertical filter stride
- int v, // horizontal filter stride
- int dilation_h, // filter dilation in the vertical dimension
- int dilation_w, // filter dilation in the horizontal dimension
- cudnnConvolutionMode_t mode,
- cudnnDataType_t computeType
- );
+cudnnStatus_t cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
+ cudnnRNNDescriptor_t rnnDesc,
+ const int hiddenSize,
+ const int numLayers,
+ cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */
+ cudnnRNNInputMode_t inputMode,
+ cudnnDirectionMode_t direction,
+ cudnnRNNMode_t mode,
+ cudnnRNNAlgo_t algo,
+ cudnnDataType_t dataType);
+
+cudnnStatus_t cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc,
+ int hiddenSize,
+ int numLayers,
+ cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */
+ cudnnRNNInputMode_t inputMode,
+ cudnnDirectionMode_t direction,
+ cudnnRNNMode_t mode,
+ cudnnDataType_t dataType);
cudnnStatus_t cudnnGetConvolution2dDescriptor_v4(
const cudnnConvolutionDescriptor_t convDesc,
@@ -1614,10 +1739,10 @@ cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvo
local CUDNN_PATH = os.getenv('CUDNN_PATH')
if CUDNN_PATH then
- print('Found Environment variable CUDNN_PATH = ' .. CUDNN_PATH)
+ io.stderr:write('Found Environment variable CUDNN_PATH = ' .. CUDNN_PATH)
cudnn.C = ffi.load(CUDNN_PATH)
else
- local libnames = {'libcudnn.so.6', 'libcudnn.6.dylib', 'cudnn64_6.dll'}
+ local libnames = {'libcudnn.so.7', 'libcudnn.7.dylib', 'cudnn64_7.dll'}
local ok = false
for i=1,#libnames do
ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end)
@@ -1625,22 +1750,22 @@ else
end
if not ok then
- error([['libcudnn (R6\) not found in library path.
+ error([['libcudnn (R7\) not found in library path.
Please install CuDNN from https://developer.nvidia.com/cuDNN
-Then make sure files named as libcudnn.so.6 or libcudnn.6.dylib are placed in
+Then make sure files named as libcudnn.so.7 or libcudnn.7.dylib are placed in
your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
-Alternatively, set the path to libcudnn.so.6 or libcudnn.6.dylib
+Alternatively, set the path to libcudnn.so.7 or libcudnn.7.dylib
to the environment variable CUDNN_PATH and rerun torch.
-For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.6"
+For example: export CUDNN_PATH = "/usr/local/cuda/lib64/libcudnn.so.7"
]])
end
end
-- check cuDNN version
cudnn.version = tonumber(cudnn.C.cudnnGetVersion())
-if cudnn.version < 6002 then
- error('These bindings are for version 6002 or above, '
+if cudnn.version < 7000 then
+ error('These bindings are for version 7000 or above, '
.. 'while the loaded CuDNN is version: ' .. cudnn.version
.. ' \nAre you using an older or newer version of CuDNN?')
end
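
Two quick sketches of what the bumped requirement buys at the Lua level: a
version guard for downstream code, and the CTC loss descriptor lifecycle
declared earlier in this diff (names taken from those declarations):

require 'cudnn'
assert(cudnn.version >= 7000, 'these bindings need libcudnn.so.7')

local ffi = require 'ffi'
local ctc = ffi.new('cudnnCTCLossDescriptor_t[1]')
cudnn.errcheck('cudnnCreateCTCLossDescriptor', ctc)
cudnn.errcheck('cudnnSetCTCLossDescriptor', ctc[0], 'CUDNN_DATA_FLOAT')
cudnn.errcheck('cudnnDestroyCTCLossDescriptor', ctc[0])
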
diff --git a/init.lua b/init.lua
index 13a77fd..1920ff0 100644
--- a/init.lua
+++ b/init.lua
@@ -150,8 +150,12 @@ function cudnn.getHandle()
end
function cudnn.call(f, ...)
- C.cudnnSetStream(cudnn.getHandle(),
+--context might be destroyed by the time gc calls destructors, in which case cudnnSetStream call would fail
+--and it is not necessary for cudnn destructors anyway
+ if not string.find(f, 'cudnnDestroy') then
+ C.cudnnSetStream(cudnn.getHandle(),
thc.THCState_getCurrentStream(cutorch.getState()))
+ end
return C[f](...)
end
@@ -212,7 +216,9 @@ function cudnn.setConvolutionDescriptor(data, desc)
if not data.arrayLength then data.arrayLength = #data.padA end
if not data.dilationA then data.dilationA = {1,1,1 } end -- assume maximum length==3
if not data.mode then data.mode = 'CUDNN_CROSS_CORRELATION' end
-
+ if not data.mathType then data.mathType = 'CUDNN_DEFAULT_MATH' end
+ if not data.groupCount then data.groupCount = 1 end
+
local myDesc = desc or cudnn.createDescriptors(
1, 'struct cudnnConvolutionStruct*[?]',
'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor')
@@ -227,6 +233,8 @@ function cudnn.setConvolutionDescriptor(data, desc)
upscaleATensor:data(),
data.mode,
data.dataType)
+ errcheck('cudnnSetConvolutionMathType', myDesc[0], data.mathType)
+ errcheck('cudnnSetConvolutionGroupCount', myDesc[0], data.groupCount)
return myDesc
end
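
With the new defaults, every convolution descriptor is programmed with a math
type and a group count even when the caller never mentions them. A hedged
sketch of setting both explicitly (field names from the function above):

local desc = cudnn.setConvolutionDescriptor{
   padA          = {1, 1},
   filterStrideA = {1, 1},
   dataType      = 'CUDNN_DATA_FLOAT',
   mathType      = 'CUDNN_TENSOR_OP_MATH',  -- Tensor Core opt-in
   groupCount    = 2,
}
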
diff --git a/test/bench_groups.lua b/test/bench_groups.lua
index cfcaa41..59310c6 100644
--- a/test/bench_groups.lua
+++ b/test/bench_groups.lua
@@ -1,6 +1,6 @@
require 'cudnn'
-m = cudnn.SpatialConvolution(512,512,13,13,1,1,1,1,512)
+m = cudnn.SpatialFullConvolution(512,512,13,13,1,1,1,1,512)
inp = torch.zeros(1,512,512,512)
diff --git a/test/test.lua b/test/test.lua
index ffe6cc3..6ce836f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -128,6 +128,10 @@ function cudnntest.SpatialConvolution()
local ini = (outi-1)*si+ki
local inj = (outj-1)*sj+kj
local scale = math.random()
+--for half, the total filter dim has to be even
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*from*to % 2 == 1 then
+ to = to+1
+ end
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,to,outj,outi):cuda()
@@ -164,6 +168,10 @@ function cudnntest.SpatialDilatedConvolution()
local inj = (outj-1)*sj+wj
local scale = math.random()
+--for half, the total filter dim has to be even
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*from*to % 2 == 1 then
+ to = to+1
+ end
local input = torch.randn(bs,from,inj,ini)
local gradOutput = torch.randn(bs,to,outj,outi)
@@ -207,6 +215,10 @@ function cudnntest.VolumetricDilatedConvolution()
local scale = math.random()
+--for half, the total filter dim has to be even
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*kk*from*to % 2 == 1 then
+ to = to+1
+ end
local input = torch.randn(bs,from,ini,ink,inj)
local gradOutput = torch.randn(bs,to,outi,outk,outj)
local sconv = nn.VolumetricDilatedConvolution(from,to,ki,kj,kk,si,sj,sk,0,0,0,di,dj,dk)
@@ -237,6 +249,10 @@ function cudnntest.SpatialFullConvolution()
local outi = (ini-1)*si+ki
local outj = (inj-1)*sj+kj
local scale = math.random()
+--for half, the total filter dim has to be even
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*from*to % 2 == 1 then
+ to = to+1
+ end
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,to,outj,outi):cuda()
@@ -264,6 +280,10 @@ function cudnntest.TemporalConvolution()
local outi = math.random(1,15)
local ini = (outi - 1) * si + ki
local scale = math.random()
+--for half, the total filter dim has to be even
+ if testparams.test_type=='torch.CudaHalfTensor' and inputFrameSize*outputFrameSize*ki % 2 == 1 then
+ outputFrameSize = outputFrameSize + 1
+ end
local input = torch.randn(bs,ini,inputFrameSize):cuda()
local gradOutput = torch.randn(bs,outi,outputFrameSize):cuda()
@@ -288,6 +308,9 @@ function cudnntest.TemporalConvolution_padding_batch()
local ini = (outi-1)*si+ki
local scale = math.random()
+ if testparams.test_type=='torch.CudaHalfTensor' and inputFrameSize*outputFrameSize*ki % 2 == 1 then
+ outputFrameSize = outputFrameSize + 1
+ end
local inputpadded = torch.randn(bs,ini,inputFrameSize):cuda()
for i=1,pad_h do
inputpadded:narrow(2,i,1):fill(0)
@@ -340,6 +363,9 @@ function cudnntest.TemporalConvolution_reduceBatchSize()
local ini = (outi-1)*si+ki
local batchSize = 128
local smallerBatchSize = batchSize/2
+ if testparams.test_type=='torch.CudaHalfTensor' and inputFrameSize*outputFrameSize*ki % 2 == 1 then
+ outputFrameSize = outputFrameSize + 1
+ end
local input = cast(torch.randn(batchSize,ini,inputFrameSize))
local conv = cast(cudnn.TemporalConvolution(inputFrameSize,outputFrameSize,ki,si):cuda())
@@ -372,6 +398,9 @@ function cudnntest.VolumetricConvolution()
local inj = outj*sj+kj-1
local ink = outk*sk+kk-1
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*kk*from*to % 2 == 1 then
+ to = to+1
+ end
local scale = math.random()
local input = torch.randn(bs,from,ink,inj,ini):cuda()
@@ -408,6 +437,9 @@ function cudnntest.VolumetricFullConvolution()
local outj = (inj-1)*sj+kj
local outk = (ink-1)*sk+kk
local scale = math.random()
+ if testparams.test_type=='torch.CudaHalfTensor' and ki*kj*kk*from*to % 2 == 1 then
+ to = to+1
+ end
if testparams.test_type == 'torch.CudaDoubleTensor' then
return
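
All of these guards apply the same rule: the half-precision kernels exercised
here require an even total filter volume (kW * kH (* kT) * nInputPlane *
nOutputPlane), so the tests bump the output plane or frame count whenever the
product is odd. Standalone:

local from, to, ki, kj = 3, 5, 3, 3
if (ki * kj * from * to) % 2 == 1 then
   to = to + 1   -- 3*3*3*5 = 135 is odd, so test with to = 6
end
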
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index 69cc740..1be5e37 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -456,11 +456,11 @@ function cudnntest.testVariableLengthSequences()
for _, pair in ipairs(corresponding) do
local sep, batched = unpack(pair)
local diff = torch.csub(separate[sep], packedOutput[batched]):abs():sum()
- mytester:assert(diff < 1e-7)
+ mytester:assertle(diff, 1.5e-7, "output difference is larger than expected")
end
local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum()
- mytester:assert(hdiff < 2e-7)
+ mytester:assertle(hdiff, 2e-7, "hidden output difference is larger than expected")
-- Step 2: update grad input as batch and individually
@@ -470,7 +470,7 @@ function cudnntest.testVariableLengthSequences()
for _, pair in ipairs(corresponding) do
sep, batched = unpack(pair)
local diff = torch.csub(igiTestable[sep], packedGradInput[batched]):abs():sum()
- mytester:assert(diff < 1e-7)
+ mytester:assertle(diff, 1e-7, "gradInput difference is larger than expected")
end
-- Step 3: Basically verify that accGradParameters works for batch