diff options
-rw-r--r-- BatchNormalization.lua | 2
-rw-r--r-- Pointwise.lua | 8
-rw-r--r-- Pooling.lua | 7
-rw-r--r-- Pooling3D.lua | 9
-rw-r--r-- README.md | 4
-rw-r--r-- RNN.lua | 35
-rw-r--r-- SpatialConvolution.lua | 2
-rw-r--r-- SpatialCrossMapLRN.lua | 7
-rw-r--r-- SpatialFullConvolution.lua | 6
-rw-r--r-- SpatialSoftMax.lua | 7
-rw-r--r-- VolumetricConvolution.lua | 8
-rw-r--r-- ffi.lua | 214
-rw-r--r-- test/test.lua | 2
-rw-r--r-- test/test_rnn.lua | 6
14 files changed, 155 insertions, 162 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua index a342063..83597d3 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -48,7 +48,6 @@ function BatchNormalization:createIODescriptors(input) local nFeature = self.running_mean:numel() self.iSize = input:size() self.output:resizeAs(input) - self.gradInput:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) local biasSize = torch.ones(self.nDim):totable() @@ -88,6 +87,7 @@ end local function backward(self,input,gradOutput, scale) assert(gradOutput:isContiguous()) self:createIODescriptors(input) + self.gradInput:resizeAs(input) scale = scale or 1 scaleTens:fill(scale) errcheck('cudnnBatchNormalizationBackward', diff --git a/Pointwise.lua b/Pointwise.lua index 92b3e45..93298ad 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -12,7 +12,6 @@ function Pointwise:createIODescriptors(input) assert(self.mode, 'mode is not set. (trying to use base class?)'); assert(input:isContiguous(), 'Non-contiguous inputs not supported yet'); if not self.inplace then - self.gradInput:resizeAs(input) self.output:resizeAs(input) end @@ -60,7 +59,12 @@ function Pointwise:updateGradInput(input, gradOutput) gradOutput = self._gradOutput end self:createIODescriptors(input) - if self.inplace then self.output:set(input); self.gradInput:set(gradOutput) end + if self.inplace then + self.output:set(input); + self.gradInput:set(gradOutput) + else + self.gradInput:resizeAs(input) + end errcheck('cudnnActivationBackward', cudnn.getHandle(), self.activDesc[0], one:data(), diff --git a/Pooling.lua b/Pooling.lua index d004563..45afccb 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -53,9 +53,6 @@ function Pooling:createIODescriptors(input) input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() - -- resize gradInput - self.gradInput:resizeAs(input) - -- resize output local 
oW, oH if self.ceil_mode then oW = math.ceil((input:size(4)+self.padW*2 - self.kW)/self.dW + 1) @@ -70,9 +67,6 @@ function Pooling:createIODescriptors(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) if not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - self.gradInput:size(3), - self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) @@ -102,6 +96,7 @@ function Pooling:updateGradInput(input, gradOutput) self._gradOutput:resizeAs(gradOutput):copy(gradOutput) gradOutput = self._gradOutput end + self.gradInput:resizeAs(input) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) errcheck('cudnnPoolingBackward', diff --git a/Pooling3D.lua b/Pooling3D.lua index cce67c3..e4c0218 100644 --- a/Pooling3D.lua +++ b/Pooling3D.lua @@ -58,8 +58,6 @@ function Pooling:createIODescriptors(input) or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] or input:size(5) ~= self.iSize[5] then self.iSize = input:size() - -- resize gradInput - self.gradInput:resizeAs(input) -- resize output local oW, oH, oT if self.ceil_mode then @@ -77,10 +75,6 @@ function Pooling:createIODescriptors(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) if not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - self.gradInput:size(3), - self.gradInput:size(4), - self.gradInput:size(5)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4), @@ -105,6 +99,9 @@ function Pooling:updateOutput(input) end function Pooling:updateGradInput(input, gradOutput) + if not self.gradInput then return end + self.gradInput:resizeAs(input) + assert(gradOutput:dim() == 4 or gradOutput:dim() == 5); if not gradOutput:isContiguous() then self._gradOutput = self._gradOutput or gradOutput.new() @@ -98,6 +98,4 @@ nn.Sequential { For version CuDNN R1, 
checkout the branch **R1** For version CuDNN R2, checkout the branch **R2** For version CuDNN R3, checkout the branch **R3** -For version CuDNN R4, checkout the branch **master** - -R5 Release Notes: +For version CuDNN R4, checkout the branch **R4** @@ -44,8 +44,9 @@ function RNN:reset(stdv) errcheck('cudnnGetRNNParamsSize', cudnn.getHandle(), self.rnnDesc[0], - self.xDescs, - weightSize:data()) + self.xDescs[0], + weightSize:data(), + self.datatype) weightSize[1] = (weightSize[1] + 3) / 4 -- sizeof(float) self.weight:resize(weightSize[1]) self.weight:uniform(-stdv, stdv) @@ -116,12 +117,10 @@ end function RNN:resetRNNDescriptor() if not self.rnnDesc then self.rnnDesc = self:createRNNDescriptors(1) - end - + end errcheck('cudnnSetRNNDescriptor', self.rnnDesc[0], self.hiddenSize, - self.seqLength, self.numLayers, self.dropoutDesc[0], self.inputMode, @@ -150,8 +149,8 @@ function RNN:resetIODescriptors() self.yDescs = self:createTensorDescriptors(self.seqLength) for i = 0, self.seqLength - 1 do - local dim = torch.IntTensor({self.inputSize, self.miniBatch, self.seqLength}) - local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]}) + local dim = torch.IntTensor({ self.miniBatch,self.inputSize, 1}) + local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1}) errcheck('cudnnSetTensorNdDescriptor', self.xDescs[i], self.datatype, @@ -159,8 +158,8 @@ function RNN:resetIODescriptors() dim:data(), stride:data()) - local dim = torch.IntTensor({self.hiddenSize * self.numDirections, self.miniBatch, self.seqLength}) - local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]}) + local dim = torch.IntTensor({self.miniBatch, self.hiddenSize * self.numDirections, 1}) + local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1}) errcheck('cudnnSetTensorNdDescriptor', self.yDescs[i], self.datatype, @@ -173,9 +172,8 @@ end function RNN:resetHiddenDescriptors() self.hxDesc = self:createTensorDescriptors(1) self.hyDesc = self:createTensorDescriptors(1) - - local dim = 
torch.IntTensor({self.hiddenSize, self.miniBatch, self.numLayers}) - local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]}) + local dim = torch.IntTensor({self.numLayers*self.numDirections, self.miniBatch, self.hiddenSize }) + local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1}) errcheck('cudnnSetTensorNdDescriptor', self.hxDesc[0], @@ -194,9 +192,8 @@ end function RNN:resetCellDescriptors() self.cxDesc = self:createTensorDescriptors(1) self.cyDesc = self:createTensorDescriptors(1) - - local dim = torch.IntTensor({self.hiddenSize, self.miniBatch, self.numLayers}) - local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]}) + local dim = torch.IntTensor({self.numLayers*self.numDirections, self.miniBatch, self.hiddenSize }) + local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1}) errcheck('cudnnSetTensorNdDescriptor', self.cxDesc[0], @@ -261,7 +258,7 @@ function RNN:updateOutput(input) -- Update descriptors/tensors if resetRNN then - self:resetDropoutDescriptor() + if not self.dropoutDesc then self:resetDropoutDescriptor() end self:resetRNNDescriptor() end if resetIO then @@ -305,6 +302,7 @@ function RNN:updateOutput(input) errcheck('cudnnGetRNNWorkspaceSize', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.xDescs, workspaceSize:data()) workspaceSize[1] = (workspaceSize[1] + 3) / 4 -- sizeof(float) @@ -317,6 +315,7 @@ function RNN:updateOutput(input) errcheck('cudnnGetRNNTrainingReserveSize', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.xDescs, reserveSize:data()) reserveSize[1] = (reserveSize[1] + 3) / 4 -- sizeof(float) @@ -328,6 +327,7 @@ function RNN:updateOutput(input) errcheck('cudnnRNNForwardTraining', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.xDescs, x:data(), self.hxDesc[0], hx and hx:data() or nil, self.cxDesc[0], cx and cx:data() or nil, @@ -341,6 +341,7 @@ function RNN:updateOutput(input) errcheck('cudnnRNNForwardInference', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.xDescs, 
x:data(), self.hxDesc[0], hx and hx:data() or nil, self.cxDesc[0], cx and cx:data() or nil, @@ -417,6 +418,7 @@ function RNN:updateGradInput(input, gradOutput) errcheck('cudnnRNNBackwardData', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.yDescs, y:data(), self.yDescs, dy:data(), self.hyDesc[0], dhy and dhy:data() or nil, @@ -480,6 +482,7 @@ function RNN:accGradParameters(input, gradOutput, scale) errcheck('cudnnRNNBackwardWeights', cudnn.getHandle(), self.rnnDesc[0], + self.seqLength, self.xDescs, x:data(), self.hxDesc[0], hx and hx:data() or nil, self.yDescs, y:data(), diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index b92dd57..1fd3ea0 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -388,8 +388,8 @@ end function SpatialConvolution:updateGradInput(input, gradOutput) if not self.gradInput then return end - self.gradInput:resizeAs(input) + input, gradOutput = makeContiguous(self, input, gradOutput) assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D'); if not self.weightDesc then self:resetWeightDescriptors() end diff --git a/SpatialCrossMapLRN.lua b/SpatialCrossMapLRN.lua index f6e7cd9..1f4ba33 100644 --- a/SpatialCrossMapLRN.lua +++ b/SpatialCrossMapLRN.lua @@ -36,15 +36,11 @@ function LRN:createIODescriptors(input) input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() - self.gradInput:resizeAs(input) self.output:resizeAs(input) -- create input/output descriptor self.iDesc = cudnn.toDescriptor(input) if not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - self.gradInput:size(3), - self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) @@ -70,6 +66,9 @@ function LRN:updateOutput(input) end function LRN:updateGradInput(input, gradOutput) + if not self.gradInput then return end + 
self.gradInput:resizeAs(input) + assert(gradOutput:dim() == 3 or gradOutput:dim() == 4); if not gradOutput:isContiguous() then self._gradOutput = self._gradOutput or gradOutput.new() diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua index 887e85b..cfbd61d 100644 --- a/SpatialFullConvolution.lua +++ b/SpatialFullConvolution.lua @@ -75,8 +75,6 @@ function SpatialFullConvolution:createIODescriptors(input) or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() - -- resize gradInput - if self.gradInput then self.gradInput:resizeAs(input); end assert(self.nInputPlane == input:size(2), 'input has to contain: ' .. self.nInputPlane .. ' feature maps, but received input of size: ' @@ -291,9 +289,6 @@ function SpatialFullConvolution:createIODescriptors(input) end if not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - self.gradInput:size(3), - self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) @@ -329,6 +324,7 @@ end function SpatialFullConvolution:updateGradInput(input, gradOutput) if not self.gradInput then return end + self.gradInput:resizeAs(input) assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D'); assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous') diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 4b3a488..167eb1f 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -35,22 +35,16 @@ function SpatialSoftMax:createIODescriptors(input) input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() - self.gradInput:resizeAs(input) self.output:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) if not singleDim and not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - 
self.gradInput:size(3), - self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) elseif singleDim and not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2)) self.output = self.output:view(self.output:size(2)) elseif singleDim and batch then - self.gradInput = self.gradInput:view(self.gradInput:size(1), self.gradInput:size(2)) self.output = self.output:view(self.output:size(1), self.output:size(2)) end end @@ -72,6 +66,7 @@ function SpatialSoftMax:updateOutput(input) end function SpatialSoftMax:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(input) if not gradOutput:isContiguous() then self._gradOutput = self._gradOutput or gradOutput.new() self._gradOutput:resizeAs(gradOutput):copy(gradOutput) diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index e9efb64..fd5e9c7 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -76,8 +76,6 @@ function VolumetricConvolution:createIODescriptors(input) or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] or input:size(5) ~= self.iSize[5] then self.iSize = input:size() - -- resize gradInput - if self.gradInput then self.gradInput:resizeAs(input); end -- create input descriptor self.iDesc = cudnn.toDescriptor(input) -- create conv descriptor @@ -287,10 +285,6 @@ function VolumetricConvolution:createIODescriptors(input) ----------------------------------------------------------------------- if not batch then - self.gradInput = self.gradInput:view(self.gradInput:size(2), - self.gradInput:size(3), - self.gradInput:size(4), - self.gradInput:size(5)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4), @@ -337,6 +331,8 @@ end function VolumetricConvolution:updateGradInput(input, gradOutput) if not self.gradInput then return end + self.gradInput:resizeAs(input) + input, gradOutput = makeContiguous(self, input, gradOutput) assert(gradOutput:dim() 
== 4 or gradOutput:dim() == 5, 'gradOutput has to be a 4D or 5D tensor'); @@ -724,7 +724,8 @@ typedef enum { CUDNN_POOLING_MAX = 0, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values*/ - CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2 /* count for average does not include padded values*/ + CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values*/ + CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility } cudnnPoolingMode_t; /* Create an instance of pooling descriptor */ @@ -1241,146 +1242,154 @@ cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDes cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc); cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int seqLength, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps.*/ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, + int hiddenSize, + int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnDataType_t dataType); -/* dataType in the RNN descriptor is used to determine math precision*/ -/* dataType in weight descriptors and input descriptors is used to describe storage*/ +// dataType in the RNN descriptor is used to determine math precision +// dataType in weight descriptors and input descriptors is used to describe storage cudnnStatus_t cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, size_t *sizeInBytes ); - + cudnnStatus_t cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t 
*xDesc, size_t *sizeInBytes ); - + cudnnStatus_t cudnnGetRNNParamsSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes + const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, + size_t *sizeInBytes, + cudnnDataType_t dataType ); cudnnStatus_t cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDescriptor_t rnnDesc, const int layer, - const cudnnTensorDescriptor_t * xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, + const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const void * w, + const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void ** linLayerMat ); cudnnStatus_t cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDescriptor_t rnnDesc, const int layer, - const cudnnTensorDescriptor_t * xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void ** linLayerBias + const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const void * w, + const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, + void ** linLayerBias ); -cudnnStatus_t cudnnRNNForwardInference( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, - size_t workSpaceSizeInBytes); - - - -cudnnStatus_t cudnnRNNForwardTraining( cudnnHandle_t handle, - const 
cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t *xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, +cudnnStatus_t cudnnRNNForwardInference( cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, + const cudnnTensorDescriptor_t * xDesc, + const void * x, + const cudnnTensorDescriptor_t hxDesc, + const void * hx, + const cudnnTensorDescriptor_t cxDesc, + const void * cx, + const cudnnFilterDescriptor_t wDesc, + const void * w, + const cudnnTensorDescriptor_t *yDesc, + void * y, + const cudnnTensorDescriptor_t hyDesc, + void * hy, + const cudnnTensorDescriptor_t cyDesc, + void * cy, + void * workspace, + size_t workSpaceSizeInBytes); + + + +cudnnStatus_t cudnnRNNForwardTraining( cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, + const cudnnTensorDescriptor_t *xDesc, + const void * x, + const cudnnTensorDescriptor_t hxDesc, + const void * hx, + const cudnnTensorDescriptor_t cxDesc, + const void * cx, + const cudnnFilterDescriptor_t wDesc, + const void * w, + const cudnnTensorDescriptor_t *yDesc, + void * y, + const cudnnTensorDescriptor_t hyDesc, + void * hy, + const cudnnTensorDescriptor_t cyDesc, + void * cy, + void * workspace, size_t workSpaceSizeInBytes, - void * reserveSpace, + void * reserveSpace, size_t reserveSpaceSizeInBytes); -cudnnStatus_t cudnnRNNBackwardData( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const cudnnTensorDescriptor_t * dyDesc, - const void * dy, - const cudnnTensorDescriptor_t dhyDesc, - const void * dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void * dcy, - 
const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnTensorDescriptor_t * dxDesc, - void * dx, +cudnnStatus_t cudnnRNNBackwardData( cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, + const cudnnTensorDescriptor_t * yDesc, + const void * y, + const cudnnTensorDescriptor_t * dyDesc, + const void * dy, + const cudnnTensorDescriptor_t dhyDesc, + const void * dhy, + const cudnnTensorDescriptor_t dcyDesc, + const void * dcy, + const cudnnFilterDescriptor_t wDesc, + const void * w, + const cudnnTensorDescriptor_t hxDesc, + const void * hx, + const cudnnTensorDescriptor_t cxDesc, + const void * cx, + const cudnnTensorDescriptor_t * dxDesc, + void * dx, const cudnnTensorDescriptor_t dhxDesc, void * dhx, const cudnnTensorDescriptor_t dcxDesc, void * dcx, void * workspace, size_t workSpaceSizeInBytes, - const void * reserveSpace, + const void * reserveSpace, size_t reserveSpaceSizeInBytes ); -cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t * yDesc, +cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, + const cudnnTensorDescriptor_t * xDesc, + const void * x, + const cudnnTensorDescriptor_t hxDesc, + const void * hx, + const cudnnTensorDescriptor_t * yDesc, const void * y, - const void * workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, + const void * workspace, + size_t workSpaceSizeInBytes, + const cudnnFilterDescriptor_t dwDesc, void * dw, - const void * reserveSpace, + const void * reserveSpace, size_t reserveSpaceSizeInBytes ); - + + + /* DEPRECATED routines to be removed next release : @@ -1576,7 +1585,6 @@ 
cudnnStatus_t cudnnActivationBackward_v4( const cudnnTensorDescriptor_t dxDesc, void *dx ); - ]] local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib'} @@ -1595,8 +1603,8 @@ Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in yo end cudnn.version = tonumber(cudnn.C.cudnnGetVersion()) -if cudnn.version < 5002 then - error('These bindings are for version 5002 or above, ' +if cudnn.version < 5005 then + error('These bindings are for version 5005 or above, ' .. 'while the loaded CuDNN is version: ' .. cudnn.version .. ' \nAre you using an older version of CuDNN?') end diff --git a/test/test.lua b/test/test.lua index ba1b5a1..3d2521e 100644 --- a/test/test.lua +++ b/test/test.lua @@ -820,6 +820,8 @@ function cudnntest.SpatialAveragePooling_single() local sconv = nn.SpatialAveragePooling(ki,kj,si,sj):cuda() local gconv = cudnn.SpatialAveragePooling(ki,kj,si,sj):cuda() + mytester:assert(cudnn.C.CUDNN_POOLING_AVERAGE ~= nil, 'back-compat broken') + local function test(sconv, gconv) local groundtruth = sconv:forward(input):clone() local groundgrad = sconv:backward(input, gradOutput) diff --git a/test/test_rnn.lua b/test/test_rnn.lua index e7ee3de..2476ce4 100644 --- a/test/test_rnn.lua +++ b/test/test_rnn.lua @@ -216,7 +216,7 @@ function getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numbe cudnn.getHandle(), rnn.rnnDesc[0], layer, - rnn.xDescs, + rnn.xDescs[0], rnn.wDesc[0], rnn.weight:data(), layerId, @@ -247,7 +247,7 @@ function getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numbe cudnn.getHandle(), rnn.rnnDesc[0], layer, - rnn.xDescs, + rnn.xDescs[0], rnn.wDesc[0], rnn.weight:data(), layerId, @@ -313,4 +313,4 @@ end mytester = torch.Tester() mytester:add(cudnntest) -mytester:run()
\ No newline at end of file +mytester:run()