diff options
author | Boris Fomitchev <bfomitchev@nvidia.com> | 2015-10-03 06:40:27 +0300 |
---|---|---|
committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2015-10-03 06:40:27 +0300 |
commit | edf35b479307594e48fb3a77f32b462865eb18a2 (patch) | |
tree | c3965dd1ec412000ecdd39ac4772636e3b2238de /ffi.lua | |
parent | fb31eba8a4252f07a6891e423d883b79371257a3 (diff) |
Initial cudnn_v4 compatibility patch. cudnn4 Batch Normalization Support added. TODO: convert addTensor_v2 to new addTensor API.
Diffstat (limited to 'ffi.lua')
-rw-r--r-- | ffi.lua | 245 |
1 files changed, 235 insertions, 10 deletions
@@ -65,7 +65,7 @@ typedef enum CUDNN_ADD_FULL_TENSOR = 3 } cudnnAddMode_t; -cudnnStatus_t cudnnAddTensor(cudnnHandle_t handle, +cudnnStatus_t cudnnAddTensor_v2(cudnnHandle_t handle, cudnnAddMode_t mode, const void *alpha, const cudnnTensorDescriptor_t biasDesc, @@ -99,7 +99,9 @@ typedef enum CUDNN_CONVOLUTION_WEIGHT_GRAD = 1, /* Weight Gradient update function */ CUDNN_CONVOLUTION_DATA_GRAD = 2 /* Data Gradient update function */ } cudnnConvolutionPath_t; + cudnnStatus_t cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc); + cudnnStatus_t cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc, cudnnDataType_t dataType, int nbDims, @@ -107,8 +109,10 @@ cudnnStatus_t cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc, ); cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnFilterDescriptor_t filterDesc); + cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc ); + cudnnStatus_t cudnnSetConvolutionNdDescriptor_v3( cudnnConvolutionDescriptor_t convDesc, int arrayLength, @@ -120,6 +124,24 @@ cudnnSetConvolutionNdDescriptor_v3( cudnnConvolutionDescriptor_t convDesc, ); cudnnStatus_t + cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, + int arrayLength, /* nbDims-2 size */ + const int padA[], + const int filterStrideA[], + const int upscaleA[], + cudnnConvolutionMode_t mode + ); + +cudnnStatus_t + cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc, + int arrayLengthRequested, + int *arrayLength, + int padA[], + int strideA[], + int upscaleA[], + cudnnConvolutionMode_t *mode + ); +cudnnStatus_t cudnnGetConvolutionNdForwardOutputDim( const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t inputTensorDesc, @@ -128,6 +150,7 @@ cudnnStatus_t int tensorOuputDimA[] ); +/* Destroy an instance of convolution descriptor */ cudnnStatus_t cudnnDestroyConvolutionDescriptor( cudnnConvolutionDescriptor_t convDesc ); @@ -186,6 +209,7 @@ cudnnStatus_t 
cudnnGetConvolutionForwardWorkspaceSize( cudnnHandle_t handle, ); +/* Function to perform the forward multiconvolution */ cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t srcDesc, @@ -201,6 +225,7 @@ cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle, void *destData ); +/* Functions to perform the backward multiconvolution */ cudnnStatus_t cudnnConvolutionBackwardBias( cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t srcDesc, @@ -350,6 +375,35 @@ cudnnStatus_t cudnnConvolutionBackwardData_v3( ); +cudnnStatus_t cudnnConvolutionBackwardFilter( cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t srcDesc, + const void *srcData, + const cudnnTensorDescriptor_t diffDesc, + const void *diffData, + const cudnnConvolutionDescriptor_t convDesc, + const void *beta, + const cudnnFilterDescriptor_t gradDesc, + void *gradData + ); + + +cudnnStatus_t cudnnConvolutionBackwardData( cudnnHandle_t handle, + const void *alpha, + const cudnnFilterDescriptor_t filterDesc, + const void *filterData, + const cudnnTensorDescriptor_t diffDesc, + const void *diffData, + const cudnnConvolutionDescriptor_t convDesc, + const void *beta, + const cudnnTensorDescriptor_t gradDesc, + void *gradData + ); + + +/* + * softmax algorithm + */ typedef enum { CUDNN_SOFTMAX_FAST = 0, @@ -398,6 +452,7 @@ typedef enum CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility } cudnnPoolingMode_t; +/* Create an instance of pooling descriptor */ cudnnStatus_t cudnnCreatePoolingDescriptor( cudnnPoolingDescriptor_t *poolingDesc); cudnnStatus_t cudnnSetPoolingNdDescriptor( @@ -420,14 +475,15 @@ cudnnStatus_t cudnnGetPoolingNdDescriptor( ); cudnnStatus_t cudnnGetPoolingNdForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[]); - + const cudnnPoolingDescriptor_t 
poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int nbDims, + int outputTensorDimA[]); +/* Destroy an instance of pooling descriptor */ cudnnStatus_t cudnnDestroyPoolingDescriptor( - cudnnPoolingDescriptor_t poolingDesc ); + cudnnPoolingDescriptor_t poolingDesc ); +/* Function to perform forward pooling */ cudnnStatus_t cudnnPoolingForward( cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, @@ -438,7 +494,8 @@ cudnnStatus_t cudnnPoolingForward( cudnnHandle_t handle, void *destData ); -cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle, +/* Function to perform backward pooling */ +cudnnStatus_t cudnnPoolingBackward( cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, const cudnnTensorDescriptor_t srcDesc, @@ -459,7 +516,8 @@ typedef enum CUDNN_ACTIVATION_TANH = 2 } cudnnActivationMode_t; -cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle, +/* Function to perform forward activation */ +cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle, cudnnActivationMode_t mode, const void *alpha, const cudnnTensorDescriptor_t srcDesc, @@ -469,7 +527,8 @@ cudnnStatus_t cudnnActivationForward( cudnnHandle_t handle, void *destData ); -cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle, +/* Function to perform backward activation */ +cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle, cudnnActivationMode_t mode, const void *alpha, const cudnnTensorDescriptor_t srcDesc, @@ -486,6 +545,172 @@ cudnnStatus_t cudnnActivationBackward( cudnnHandle_t handle, cudnnStatus_t cudnnCreateLRNDescriptor( cudnnLRNDescriptor_t* normDesc ); typedef enum +{ + CUDNN_BATCHNORM_PER_ACTIVATION = 0, + CUDNN_BATCHNORM_SPATIAL = 1 +} cudnnBatchNormMode_t; + +// Derives a tensor descriptor from layer data descriptor for BatchNormalization scale, invVariance, bnBias, bnScale subtensors +// Use the tensor desc produced by these functions as the bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc 
parameters in +// Spatial and Per-activation Batch Normalization forward and backward functions. +// Note - derivedBnDesc has to be first created using cudnnCreateTensorDescriptor +// Note - dataDesc is the descriptor for the layer data and has to be setup with proper dimensions prior to calling these functions. +cudnnStatus_t cudnnDeriveBNTensorDescriptor( + cudnnTensorDescriptor_t derivedBnDesc, + const cudnnTensorDescriptor_t dataDesc, + cudnnBatchNormMode_t mode); + +// This function performs a forward pass for Batch Normalization layer. +// In addition to resultTopData it accumulates the moving averages of the mean and inverse variances +cudnnStatus_t cudnnBatchNormalizationForwardTraining( + cudnnHandle_t handle, + cudnnBatchNormMode_t mode, + + const void *alpha, // alpha[0] = result blend factor + const void *beta, // beta[0] = dest layer blend factor + + const cudnnTensorDescriptor_t bottomDesc, + const void *bottomData, // NxCxHxW + void *resultTopData, // NxCxHxW + + // Same shared desc for all the 6 tensors below in the argument list. + // Note that the data type for this descriptor has to be set as follows: + // type = (typeOf(bottomData) == half) ? 
float : typeof(bottomData) + // The dimensions for this tensor descriptor are dependent on the normalization mode + // For spatial normalization the tensors are expected to be 1D (of size C) + // (in this case normalization is performed across NxHxW) + // In per-activation mode the normalization is performed across N dimension only + // So the tensors are expected to have dimensions of CxHxW + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + + // Note - bnScale is 'gamma' in paper's notation + const void *bnScaleData, // Mode-dependent dims + // Note - this bias parameter can effectively replace the bias in Conv and FCN layers + // (Which can be set to zero for efficiency) + // Note - bnBias is 'beta' in paper's notation + const void *bnBiasData, // Mode-dependent dims + + // It is required that factor=1 is used for the very first call of a complete training cycle. + // This is necessary to properly initialize the moving average. + // Use a factor=1/(1+n) at N-th call to the function to get + // Cumulative Moving Average (CMA) behavior + // CMA[n] = (x[1]+...+x[n])/n + // Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = + // ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = + // CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) + double exponentialAverageFactor, + + // runningMean = newMean*factor + runningMean*(1-factor) + // if isTrainingPhase == false, these tensors will remain const + // and exponentialAverageFactor parameter is not used. + + // Both of these pointers (running mean, inv variance) can be NULL but only at the same time. + void *resultRunningMean, + // The value stored here (or passed as an input in inference mode) is the moving average + // of the expression 1 / sqrt( epsilon + variance[bottomData] ) + void *resultRunningInvVariance, + + // Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. + // Same epsilon value should be used in forward and backward functions. 
+ double epsilon, + + // Optional cache to save intermediate results computed during the forward pass + // - these can then be reused to speed up backward pass. For this to work correctly, + // the bottom layer data has to remain unchanged until the backward function is called. + // Note that both of these parameters can be NULL but only at the same time. + // It is recommended to use this cache since memory overhead is relatively small. + void *resultSaveMean, + void *resultSaveInvVariance + ); + +// This function will compute a linear transform of the inputs as follows: +// topData[i] = bnScale[k]*(bottomData[i]-estimatedMean[k])*estimatedInvVariance[k] + bnBias[k] +// with bnScale, bnBias, runningMean, runningInvVariance tensors indexed +// according to spatial or per-activation mode (please refer to the paper for details). +// During inference estimatedMean and estimatedVariance are treated +// as const inputs (accumulated and saved during the training phase) +cudnnStatus_t cudnnBatchNormalizationForwardInference( + cudnnHandle_t handle, + cudnnBatchNormMode_t mode, + + const void *alpha, // alpha[0] = result blend factor + const void *beta, // beta[0] = dest layer blend factor + + const cudnnTensorDescriptor_t bottomDesc, + const void *bottomData, // NxCxHxW + void *resultTopData, // NxCxHxW + + // Same desc for all 4 tensors below + // Note that the data type for this descriptor has to be set as follows: + // type = (typeOf(bottomData) == half) ? 
float : typeof(bottomData) + // The dimensions for this tensor descriptor are dependent on the normalization mode + // For spatial normalization the tensors are expected to be 1D (of size C) + // (in this case normalization is performed across NxHxW) + // In per-activation mode the normalization is performed across N dimension only + // So the tensors are expected to have dimensions of CxHxW + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + + // Note - bnScale is 'gamma' in paper's notation + const void *bnScaleData, // Mode-dependent dims + // Note - this bias parameter can effectively replace the bias in Conv and FCN layers + // (Which can be set to zero for efficiency) + // Note - bnBias is 'beta' in paper's notation + const void *bnBiasData, // Mode-dependent dims + + // runningMean = newMean*factor + runningMean*(1-factor) + // if isTrainingPhase == false, these tensors will remain const + // and exponentialAverageFactor parameter is not used. + + // An estimate of the batch mean, can be accumulated over multiple calls to + // batchNormalizationForwardTraining + const void *estimatedMean, + // An estimate of the expression 1 / sqrt( epsilon + variance[bottomData] ), + // Can also be accumulated over multiple calls to batchNormalizationForwardTraining. + const void *estimatedInvVariance, + + // Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. + // Same epsilon value should be used in forward and backward functions. + double epsilon + ); + + +// This function performs a backward pass for Batch Normalization layer. +// The results are +// 1. bottom layer data differential +// 2. bnScale differential +// 3. 
bnBias differential +cudnnStatus_t cudnnBatchNormalizationBackward( + cudnnHandle_t handle, + cudnnBatchNormMode_t mode, + + const void *alpha, // result blend factor = alpha[0] + const void *beta, // bottom blend factor = beta[0] + + const cudnnTensorDescriptor_t bottomDesc, // same desc for topDiff, bottomDiff + const void *bottomData, // NxCxHxW + const void *topDiff, // NxCxHxW + void *resultBottomDiff, // NxCxHxW + + // this tensor desc is used for all the 4 tensors below + const cudnnTensorDescriptor_t bnScaleBiasDiffDesc, + const void *bottomBnScale, // bottomBnBias doesn't affect backpropagation + + // scale and bias diff are not backpropagated below this layer (dead-end computation DAG nodes) + void *resultBnScaleDiff, // mode-dependent dims + void *resultBnBiasDiff, // mode-dependent dims + // Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. + // Same epsilon value should be used in forward and backward functions. + double epsilon, + + // Optional cache parameters containing saved intermediate results computed during the forward pass + // For this to work correctly, the bottom layer data has to remain unchanged until the backward function is called. + // Note that both of these parameters can be NULL but only at the same time. + // It is recommended to use this cache since memory overhead is relatively small. + const void *savedMean, + const void *savedInvVariance + ); + +typedef enum { CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0, } cudnnLRNMode_t; |