
github.com/soumith/cudnn.torch.git
author     Sergey Zagoruyko <zagoruyko2@gmail.com>   2016-04-13 16:01:55 +0300
committer  Sergey Zagoruyko <zagoruyko2@gmail.com>   2016-04-13 19:01:48 +0300
commit     70adafcae259f67129c2de6e1048594aa0283e59 (patch)
tree       432358dc7274ec86132e2c3bb40b655b0f4a4f61
parent     0a040cf2bb3d0f0cdf217c8e92efaddf29ed2efc (diff)
R5 rebase
-rw-r--r--   Pointwise.lua                 23
-rw-r--r--   Pooling.lua                    2
-rw-r--r--   Pooling3D.lua                  2
-rw-r--r--   README.md                      9
-rw-r--r--   SpatialConvolution.lua        18
-rw-r--r--   SpatialFullConvolution.lua     8
-rw-r--r--   VolumetricConvolution.lua    261
-rw-r--r--   ffi.lua                     1226
-rw-r--r--   functional.lua                20
-rw-r--r--   test/benchmark.lua            84
10 files changed, 1060 insertions, 593 deletions
diff --git a/Pointwise.lua b/Pointwise.lua
index 8d0a06e..92b3e45 100644
--- a/Pointwise.lua
+++ b/Pointwise.lua
@@ -1,5 +1,7 @@
local Pointwise, parent = torch.class('cudnn._Pointwise','nn.Module')
+
local errcheck = cudnn.errcheck
+local ffi = require 'ffi'
function Pointwise:__init(inplace)
parent.__init(self)
@@ -13,11 +15,27 @@ function Pointwise:createIODescriptors(input)
self.gradInput:resizeAs(input)
self.output:resizeAs(input)
end
+
+ if not self.activDesc then
+ self.activDesc = ffi.new('struct cudnnActivationStruct*[1]')
+ errcheck('cudnnCreateActivationDescriptor', self.activDesc)
+ errcheck('cudnnSetActivationDescriptor', self.activDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', 0.0);
+
+ local function destroyADesc(a)
+ if (a[0]) then
+ errcheck('cudnnDestroyActivationDescriptor', a[0]);
+ a[0] = nil
+ end
+ end
+ ffi.gc(self.activDesc, destroyADesc)
+ end
+
local nElem = input:nElement()
self.nElem = self.nElem or nElem -- this goes to the second branch only once
if self.iDesc and nElem == self.nElem then return end
self.nElem = nElem
self.iDesc = cudnn.toDescriptor(input:view(1,1,1,nElem))
+
end
local one = torch.FloatTensor({1});
@@ -27,7 +45,7 @@ function Pointwise:updateOutput(input)
self:createIODescriptors(input)
if self.inplace then self.output:set(input) end
errcheck('cudnnActivationForward',
- cudnn.getHandle(), self.mode,
+ cudnn.getHandle(), self.activDesc[0],
one:data(),
self.iDesc[0], input:data(),
zero:data(),
@@ -44,7 +62,7 @@ function Pointwise:updateGradInput(input, gradOutput)
self:createIODescriptors(input)
if self.inplace then self.output:set(input); self.gradInput:set(gradOutput) end
errcheck('cudnnActivationBackward',
- cudnn.getHandle(), self.mode,
+ cudnn.getHandle(), self.activDesc[0],
one:data(),
self.iDesc[0], self.output:data(),
self.iDesc[0], gradOutput:data(),
@@ -56,6 +74,7 @@ end
function Pointwise:clearDesc()
self.iDesc = nil
+ self.activDesc = nil
end
function Pointwise:write(f)
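
The Pointwise change above captures the central R5 API shift for activations: cudnnActivationForward/Backward no longer take a mode enum directly, they take an activation descriptor that is created once and destroyed with the module. A minimal sketch of that descriptor lifecycle, assuming cudnn.errcheck and LuaJIT's ffi are in scope as they are in Pointwise.lua (the ReLU mode string is illustrative; the module passes self.mode):

local ffi = require 'ffi'
local errcheck = cudnn.errcheck

-- one-element array holding the opaque cudnnActivationDescriptor_t pointer
local activDesc = ffi.new('struct cudnnActivationStruct*[1]')
errcheck('cudnnCreateActivationDescriptor', activDesc)
errcheck('cudnnSetActivationDescriptor', activDesc[0],
         'CUDNN_ACTIVATION_RELU', 'CUDNN_PROPAGATE_NAN', 0.0)

-- tie descriptor lifetime to garbage collection of the Lua-side handle
ffi.gc(activDesc, function(a)
   if a[0] then
      errcheck('cudnnDestroyActivationDescriptor', a[0])
      a[0] = nil
   end
end)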
diff --git a/Pooling.lua b/Pooling.lua
index 00bd0a8..d004563 100644
--- a/Pooling.lua
+++ b/Pooling.lua
@@ -33,7 +33,7 @@ function Pooling:resetPoolDescriptors()
local ker = torch.IntTensor({self.kH, self.kW})
local str = torch.IntTensor({self.dH, self.dW})
local pad = torch.IntTensor({self.padH, self.padW})
- errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 2,
+ errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', 2,
ker:data(), pad:data(), str:data());
local function destroyPoolDesc(d)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
diff --git a/Pooling3D.lua b/Pooling3D.lua
index 489865c..cce67c3 100644
--- a/Pooling3D.lua
+++ b/Pooling3D.lua
@@ -37,7 +37,7 @@ function Pooling:resetPoolDescriptors()
local ker = torch.IntTensor({self.kT, self.kH, self.kW})
local str = torch.IntTensor({self.dT, self.dH, self.dW})
local pad = torch.IntTensor({self.padT, self.padH, self.padW})
- errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 3,
+ errcheck('cudnnSetPoolingNdDescriptor', self.poolDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', 3,
ker:data(), pad:data(), str:data());
local function destroyPoolDesc(d)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
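
Both pooling modules pick up the same R5 signature change: cudnnSetPoolingNdDescriptor now takes a NaN-propagation flag ahead of the dimension count. A sketch of the 2D call, assuming poolDesc was created with cudnnCreatePoolingDescriptor and kH/kW, dH/dW, padH/padW are plain numbers (names follow Pooling.lua; the CUDNN_POOLING_MAX mode is illustrative, the module passes self.mode):

local ker = torch.IntTensor({kH, kW})
local str = torch.IntTensor({dH, dW})
local pad = torch.IntTensor({padH, padW})
-- R5 inserts the NaN-propagation argument before nbDims (2 for spatial pooling)
errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], 'CUDNN_POOLING_MAX',
         'CUDNN_PROPAGATE_NAN', 2, ker:data(), pad:data(), str:data())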
diff --git a/README.md b/README.md
index 1ac5cb3..3f1a43f 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
cudnn.torch
===========
-Torch7 FFI bindings for NVIDIA cuDNN (R4) kernels!
+Torch7 FFI bindings for NVIDIA cuDNN (R5) kernels!
Modules are API compatible their [`nn`](https://github.com/torch/nn) equivalents. Fully unit-tested against `nn` implementations.
Conversion between `nn` and `cudnn` is available through `cudnn.convert` function.
#### Installation
-* Install cuDNN (version R4 EA)
+* Install cuDNN (version R5 EA)
* Have at least CUDA 7.0
* Have `libcudnn.so` in your library path (Install it from https://developer.nvidia.com/cuDNN )
@@ -89,8 +89,7 @@ nn.Sequential {
For version CuDNN R1, checkout the branch **R1**
For version CuDNN R2, checkout the branch **R2**
For version CuDNN R3, checkout the branch **R3**
+For version CuDNN R4, checkout the branch **master**
-R4 Release Notes:
-- Rather than resolving v3-v4 diffs, I have imported new cudnn.h with its entirety and converted comments and defines. This should be less error-prone.
-- addTensor_v2 uses changed to new AddTensor API.
+R5 Release Notes:
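
The README hunk above also points at cudnn.convert, the helper that swaps a network's modules between their nn and cudnn implementations. A minimal usage sketch (the toy network is illustrative, not part of this commit):

require 'cudnn'

local net = nn.Sequential()
   :add(nn.SpatialConvolution(3, 16, 3, 3))
   :add(nn.ReLU())

cudnn.convert(net, cudnn)   -- replace nn modules with their cudnn equivalents
-- cudnn.convert(net, nn)   -- converts back the other way
net:cuda()                  -- cudnn modules still expect CUDA tensors at run time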
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 20f31ef..b92dd57 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -43,7 +43,7 @@ function SpatialConvolution:resetWeightDescriptors()
self.nInputPlane/self.groups,
self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- 'CUDNN_DATA_FLOAT', 4,
+ 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -122,7 +122,7 @@ function SpatialConvolution:createIODescriptors(input)
local pad = torch.IntTensor({self.padH, self.padW})
local stride = torch.IntTensor({self.dH, self.dW})
local upscale = torch.IntTensor({1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -177,7 +177,7 @@ function SpatialConvolution:createIODescriptors(input)
if autotunerCache[1][autotunerHash] then
algType[0] = autotunerCache[1][autotunerHash]
if cudnn.verbose then
- print('Using cached benchmark for: ', autotunerHash)
+ print('Autotuning SC FW: using cached algo = ', algType[0], ' for: ', autotunerHash)
end
else
local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1)
@@ -191,7 +191,7 @@ function SpatialConvolution:createIODescriptors(input)
autotunerCache[1][autotunerHash] = perfResults[0].algo
if cudnn.verbose then
print(string.format(
- "Autotuning Forward: Time: %3.5f Memory: %8d Algorithm: %d"
+ "\nAutotuning SC Forward: Time: %3.5f Memory: %8d Algorithm: %d"
.. " Weight: %15s Input: %15s Output: %15s",
perfResults[0].time, tonumber(perfResults[0].memory),
tonumber(perfResults[0].algo),
@@ -228,6 +228,9 @@ function SpatialConvolution:createIODescriptors(input)
if cudnn.benchmark then -- the manual auto-tuner is run
if autotunerCache[2][autotunerHash] then
algType[0] = autotunerCache[2][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning SC BW: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
else
local perfResults = ffi.new("cudnnConvolutionBwdFilterAlgoPerf_t[?]", 1)
local intt = torch.IntTensor(1);
@@ -276,6 +279,9 @@ function SpatialConvolution:createIODescriptors(input)
if cudnn.benchmark then -- the manual auto-tuner is run
if autotunerCache[3][autotunerHash] then
algType[0] = autotunerCache[3][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning SC BWD: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
else
local perfResults = ffi.new("cudnnConvolutionBwdDataAlgoPerf_t[?]", 1)
local intt = torch.IntTensor(1);
@@ -390,7 +396,7 @@ function SpatialConvolution:updateGradInput(input, gradOutput)
self:createIODescriptors(input)
for g = 0,self.groups - 1 do
- errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
+ errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
one:data(),
self.weightDesc[0], self.weight:data() + g*self.weight_offset,
self.oDesc[0], gradOutput:data() + g*self.output_offset,
@@ -427,7 +433,7 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
for g = 0, self.groups - 1 do
-- gradWeight
- errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
+ errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
self.scaleT:data(),
self.iDesc[0], input:data() + g*self.input_offset,
self.oDesc[0], gradOutput:data() + g*self.output_offset,
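
The verbose messages added above report per-shape cache hits from the manual auto-tuner; the caches are keyed by a weight/input/output shape string, with one cache each for forward, backwardFilter and backwardData. Whether the tuner runs at all is controlled by module-level flags, used exactly as in the code above:

require 'cudnn'

cudnn.benchmark = true   -- run cudnnFind*Algorithm once per shape hash and cache the winner
cudnn.verbose   = true   -- print the measured choice, or the cached algo on later hits
cudnn.fastest   = false  -- when true, the heuristic path prefers the *_PREFER_FASTEST modes

-- subsequent :forward/:backward calls on cudnn convolutions reuse the cached algorithms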
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua
index d00a8a2..887e85b 100644
--- a/SpatialFullConvolution.lua
+++ b/SpatialFullConvolution.lua
@@ -21,7 +21,7 @@ function SpatialFullConvolution:resetWeightDescriptors()
self.nOutputPlane,
self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- 'CUDNN_DATA_FLOAT', 4,
+ 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -93,7 +93,7 @@ function SpatialFullConvolution:createIODescriptors(input)
local pad = torch.IntTensor({self.padH, self.padW})
local stride = torch.IntTensor({self.dH, self.dW})
local upscale = torch.IntTensor({1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -147,7 +147,7 @@ function SpatialFullConvolution:createIODescriptors(input)
if autotunerCache[1][autotunerHash] then
algType[0] = autotunerCache[1][autotunerHash]
if cudnn.verbose then
- print('Using cached benchmark for: ', autotunerHash)
+ print('Autotuning SFC: using cached algo = ', algType[0], ' for: ', autotunerHash)
end
else
local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1)
@@ -309,7 +309,7 @@ function SpatialFullConvolution:updateOutput(input)
self:createIODescriptors(input)
-- Because SpatialFullConvolution is performing the adjoint of the forward
- -- convolution operator, we need to swap the forward and backward passes.
+ -- convolution operator, we need to swap the forward and backward passes.
errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
one:data(),
self.weightDesc[0], self.weight:data(),
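
The retained comment above is the key to SpatialFullConvolution: a full (transposed) convolution is the adjoint of the forward convolution operator, so its output pass runs cudnnConvolutionBackwardData (as shown) and, correspondingly, its data-gradient pass uses the forward kernel. A small sketch of the output-size relation that follows from this (the formula is the one nn.SpatialFullConvolution documents; the helper name is illustrative):

-- output of a full convolution is the *larger* tensor: the ordinary convolution
-- input-size formula read backwards, plus the user-supplied adjustment adjH/adjW
local function fullConvOutputSize(iH, iW, kH, kW, dH, dW, padH, padW, adjH, adjW)
   return (iH - 1) * dH - 2 * padH + kH + adjH,
          (iW - 1) * dW - 2 * padW + kW + adjW
end
print(fullConvOutputSize(8, 8, 3, 3, 2, 2, 1, 1, 0, 0))  -- 15  15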
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 6ec3302..e9efb64 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -3,6 +3,11 @@ local VolumetricConvolution, parent
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
+local autotunerCache = {}
+autotunerCache[1] = {} -- forward
+autotunerCache[2] = {} -- backwardFilter
+autotunerCache[3] = {} -- backwardData
+
-- if you change the configuration of the module manually, call this
function VolumetricConvolution:resetWeightDescriptors()
assert(torch.typename(self.weight) == 'torch.CudaTensor',
@@ -15,7 +20,7 @@ function VolumetricConvolution:resetWeightDescriptors()
local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane,
self.kT, self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- 'CUDNN_DATA_FLOAT', 5,
+ 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 5,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -81,7 +86,7 @@ function VolumetricConvolution:createIODescriptors(input)
local pad = torch.IntTensor({self.padT, self.padH, self.padW})
local stride = torch.IntTensor({self.dT, self.dH, self.dW})
local upscale = torch.IntTensor({1,1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
3, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -104,85 +109,183 @@ function VolumetricConvolution:createIODescriptors(input)
self.output:size(2),
self.output:size(3)*self.output:size(4),
self.output:size(5)))
- -----------------------------------------------------------------
- local maxBufSize = 0
- -- create forwardAlgorithm descriptors for
- local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
- local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
- local algWorkspaceLimit = self.workspace_limit
- or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
- if self.fastest_mode or cudnn.fastest == true then
- algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST'
- end
- errcheck('cudnnGetConvolutionForwardAlgorithm',
- cudnn.getHandle(),
- self.iDesc[0], self.weightDesc[0],
- self.convDesc[0], self.oDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
- algType[0] = self.fmode or algType[0]
- self.fwdAlgType = algType
- local bufSize = torch.LongTensor(1)
- errcheck('cudnnGetConvolutionForwardWorkspaceSize',
- cudnn.getHandle(),
- self.iDesc[0], self.weightDesc[0],
- self.convDesc[0], self.oDesc[0],
- algType[0], bufSize:data())
- maxBufSize = math.max(maxBufSize, bufSize[1])
- -- create backwardFilterAlgorithm descriptors for
- local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
- local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
- local algWorkspaceLimit = self.workspace_limit
- or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
- if self.fastest_mode or cudnn.fastest == true then
- algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
- end
- errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
- cudnn.getHandle(),
- self.iDesc[0], self.oDesc[0],
- self.convDesc[0], self.weightDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
- algType[0] = self.bwmode or algType[0]
- self.bwdFilterAlgType = algType
- local bufSize = torch.LongTensor(1)
- errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
- cudnn.getHandle(),
- self.iDesc[0], self.oDesc[0],
- self.convDesc[0], self.weightDesc[0],
- algType[0], bufSize:data())
- maxBufSize = math.max(maxBufSize, bufSize[1])
- -- create backwardDataAlgorithm descriptors for
- local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
- local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
- local algWorkspaceLimit = self.workspace_limit
- or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
- if self.fastest_mode or cudnn.fastest == true then
- algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
- end
- errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
- cudnn.getHandle(),
- self.weightDesc[0], self.oDesc[0],
- self.convDesc[0], self.iDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
- algType[0] = self.bdmode or algType[0]
- self.bwdDataAlgType = algType
- local bufSize = torch.LongTensor(1)
- errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
- cudnn.getHandle(),
- self.weightDesc[0], self.oDesc[0],
- self.convDesc[0], self.iDesc[0],
- algType[0], bufSize:data())
- maxBufSize = math.max(maxBufSize, bufSize[1])
- self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace()
- self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
- if maxBufSize > self.extraBufferSizeInBytes then
- self.extraBuffer:resize(math.ceil(maxBufSize/4))
- self.extraBufferSizeInBytes = maxBufSize
- end
+ -----------------------------------------------------------------------
+ local function shape(x)
+ return table.concat(x:size():totable(),'x')
+ end
+ local autotunerHash = shape(self.weight) .. ';'
+ .. shape(input) .. ';'
+ .. shape(self.output)
+
+ local maxBufSize = 0
+
+ -- create forwardAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST'
+ end
+
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[1][autotunerHash] then
+ algType[0] = autotunerCache[1][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC FW: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[1][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "\nAutotuning VMC Forward: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.fmode or algType[0]
+ self.fwdAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionForwardWorkspaceSize',
+ cudnn.getHandle(),
+ self.iDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ -- create backwardFilterAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
+ end
+
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[2][autotunerHash] then
+ algType[0] = autotunerCache[2][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC BWF: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionBwdFilterAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.oDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[2][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "Autotuning backwardFilter: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.oDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.bwmode or algType[0]
+ self.bwdFilterAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
+ cudnn.getHandle(),
+ self.iDesc[0], self.oDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ -- create backwardDataAlgorithm descriptors
+ local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+ local algWorkspaceLimit = self.workspace_limit
+ or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+ if self.fastest_mode or cudnn.fastest == true then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
+ end
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ if autotunerCache[3][autotunerHash] then
+ algType[0] = autotunerCache[3][autotunerHash]
+ if cudnn.verbose then
+ print('Autotuning VMC BWD: using cached algo = ', algType[0], ' for: ', autotunerHash)
+ end
+ else
+ local perfResults = ffi.new("cudnnConvolutionBwdDataAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.oDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ autotunerCache[3][autotunerHash] = perfResults[0].algo
+ if cudnn.verbose then
+ print(string.format(
+ "Autotuning backwardData: Time: %3.5f Memory: %8d Algorithm: %d"
+ .. " Weight: %15s Input: %15s Output: %15s\n",
+ perfResults[0].time, tonumber(perfResults[0].memory),
+ tonumber(perfResults[0].algo),
+ shape(self.weight), shape(input),
+ shape(self.output)))
+ end
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.oDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
+ algType[0] = self.bdmode or algType[0]
+ self.bwdDataAlgType = algType
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.oDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ algType[0], bufSize:data())
+ maxBufSize = math.max(maxBufSize, bufSize[1])
+
+ self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace()
+ self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
+ if maxBufSize > self.extraBufferSizeInBytes then
+ self.extraBuffer:resize(math.ceil(maxBufSize/4))
+ self.extraBufferSizeInBytes = maxBufSize
+ end
+ -----------------------------------------------------------------------
- -----------------------------------------------------------------
if not batch then
self.gradInput = self.gradInput:view(self.gradInput:size(2),
self.gradInput:size(3),
@@ -239,7 +342,7 @@ function VolumetricConvolution:updateGradInput(input, gradOutput)
'gradOutput has to be a 4D or 5D tensor');
if not self.weightDesc then self:resetWeightDescriptors() end
self:createIODescriptors(input)
- errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
+ errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
one:data(),
self.weightDesc[0], self.weight:data(),
self.oDesc[0], gradOutput:data(),
@@ -270,7 +373,7 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
one:data(),
self.biasDesc[0], self.gradBias:data());
-- gradWeight
- errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
+ errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
self.scaleT:data(),
self.iDesc[0], input:data(),
self.oDesc[0], gradOutput:data(),
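
VolumetricConvolution now sizes its workspace the same way the 2D module does: each of the three algorithm queries reports a byte requirement, the maximum is kept, and the shared buffer returned by cudnn.getSharedWorkspace is grown only when it is too small. A sketch of that sizing logic in isolation (ensureWorkspace is an illustrative helper, not a function in this repository):

-- grow a shared float CudaTensor so it covers the largest of several byte requirements;
-- the sizes would come from the cudnnGetConvolution*WorkspaceSize calls above
local function ensureWorkspace(buffer, byteSizes)
   local maxBytes = 0
   for _, b in ipairs(byteSizes) do
      maxBytes = math.max(maxBytes, b)
   end
   if maxBytes > buffer:nElement() * 4 then   -- 4 bytes per float element
      buffer:resize(math.ceil(maxBytes / 4))  -- round up to whole floats
   end
   return buffer, buffer:nElement() * 4       -- buffer plus its capacity in bytes
end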
diff --git a/ffi.lua b/ffi.lua
index e2b5b16..91ca885 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -1,7 +1,15 @@
local ffi = require 'ffi'
ffi.cdef[[
-size_t cudnnGetVersion();
+
+
+typedef enum {
+ CUDNN_MAJOR = 5,
+ CUDNN_MINOR = 0,
+ CUDNN_PATCHLEVEL = 4,
+ CUDNN_VERSION = (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
+} cudnnVerFakeEnum;
+
struct cudnnContext;
typedef struct cudnnContext *cudnnHandle_t;
@@ -25,22 +33,24 @@ typedef enum
CUDNN_STATUS_LICENSE_ERROR = 10
} cudnnStatus_t;
+/* human-readable error messages*/
const char * cudnnGetErrorString(cudnnStatus_t status);
-typedef struct CUstream_st *cudaStream_t;
-cudnnStatus_t cudnnCreate(cudnnHandle_t *handle);
-cudnnStatus_t cudnnDestroy(cudnnHandle_t handle);
-cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
-cudnnStatus_t cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId);
+cudnnStatus_t cudnnCreate (cudnnHandle_t *handle);
+cudnnStatus_t cudnnDestroy (cudnnHandle_t handle);
+cudnnStatus_t cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId);
+cudnnStatus_t cudnnGetStream (cudnnHandle_t handle, cudaStream_t *streamId);
/* Data structures to represent Image/Filter and the Neural Network Layer */
-typedef struct cudnnTensorStruct* cudnnTensorDescriptor_t;
-typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t;
-typedef struct cudnnPoolingStruct* cudnnPoolingDescriptor_t;
-typedef struct cudnnFilterStruct* cudnnFilterDescriptor_t;
-typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t;
-typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t;
+typedef struct cudnnTensorStruct* cudnnTensorDescriptor_t;
+typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t;
+typedef struct cudnnPoolingStruct* cudnnPoolingDescriptor_t;
+typedef struct cudnnFilterStruct* cudnnFilterDescriptor_t;
+typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t;
+typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t;
+typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t;
+typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t;
/*
* CUDNN data type
*/
@@ -75,20 +85,20 @@ typedef enum
cudnnStatus_t cudnnSetTensor4dDescriptor(
cudnnTensorDescriptor_t tensorDesc,
cudnnTensorFormat_t format,
- cudnnDataType_t dataType, /* image data type */
- int n, /* number of inputs (batch size) */
- int c, /* number of input feature maps */
- int h, /* height of input section */
- int w ); /* width of input section */
+ cudnnDataType_t dataType, /* image data type*/
+ int n, /* number of inputs (batch size)*/
+ int c, /* number of input feature maps*/
+ int h, /* height of input section*/
+ int w ); /* width of input section*/
cudnnStatus_t cudnnSetTensor4dDescriptorEx(
cudnnTensorDescriptor_t tensorDesc,
- cudnnDataType_t dataType, /* image data type */
- int n, /* number of inputs (batch size) */
- int c, /* number of input feature maps */
- int h, /* height of input section */
- int w, /* width of input section */
+ cudnnDataType_t dataType, /* image data type*/
+ int n, /* number of inputs (batch size)*/
+ int c, /* number of input feature maps*/
+ int h, /* height of input section*/
+ int w, /* width of input section*/
int nStride,
int cStride,
int hStride,
@@ -96,11 +106,11 @@ cudnnStatus_t cudnnSetTensor4dDescriptorEx(
cudnnStatus_t cudnnGetTensor4dDescriptor(
const cudnnTensorDescriptor_t tensorDesc,
- cudnnDataType_t *dataType, /* image data type */
- int *n, /* number of inputs (batch size) */
- int *c, /* number of input feature maps */
- int *h, /* height of input section */
- int *w, /* width of input section */
+ cudnnDataType_t *dataType, /* image data type*/
+ int *n, /* number of inputs (batch size)*/
+ int *c, /* number of input feature maps*/
+ int *h, /* height of input section*/
+ int *w, /* width of input section*/
int *nStride,
int *cStride,
int *hStride,
@@ -159,55 +169,69 @@ cudnnStatus_t cudnnTransformTensor(
const cudnnTensorDescriptor_t yDesc,
void *y );
-typedef enum
-{
- /* add one image to every feature maps of each input */
- CUDNN_ADD_IMAGE = 0,
- CUDNN_ADD_SAME_HW = 0,
-
- /* add a set of feature maps to a batch of inputs : tensorBias has n=1 , same number of features as x and y */
- CUDNN_ADD_FEATURE_MAP = 1,
- CUDNN_ADD_SAME_CHW = 1,
- /* add a tensor of size 1,c,1,1 to every corresponding point of n,c,h,w input */
- CUDNN_ADD_SAME_C = 2,
-
- /* add 2 tensors with same n,c,h,w */
- CUDNN_ADD_FULL_TENSOR = 3
-} cudnnAddMode_t;
-
-/* Tensor Bias addition : y = alpha * b + beta * y */
+/* Tensor Bias addition : C = alpha * A + beta * C */
cudnnStatus_t cudnnAddTensor(
cudnnHandle_t handle,
const void *alpha,
- const cudnnTensorDescriptor_t bDesc,
- const void *b,
+ const cudnnTensorDescriptor_t aDesc,
+ const void *A,
const void *beta,
- cudnnTensorDescriptor_t yDesc,
- void *y );
+ const cudnnTensorDescriptor_t cDesc,
+ void *C );
-/* cudnnAddTensor_v3 is now mapped to cudnnAddTensor
- and will be removed at the same time as cudnnAddTensor_v2
- Use cudnnAddTensor instead
- */
-cudnnStatus_t cudnnAddTensor_v3(
+/*
+* CUDNN OpTensor op type
+*/
+typedef enum
+{
+ CUDNN_OP_TENSOR_ADD = 0,
+ CUDNN_OP_TENSOR_MUL = 1,
+ CUDNN_OP_TENSOR_MIN = 2,
+ CUDNN_OP_TENSOR_MAX = 3,
+} cudnnOpTensorOp_t;
+
+cudnnStatus_t cudnnCreateOpTensorDescriptor(
+ cudnnOpTensorDescriptor_t *opTensorDesc );
+
+cudnnStatus_t cudnnSetOpTensorDescriptor(
+ cudnnOpTensorDescriptor_t opTensorDesc,
+ cudnnOpTensorOp_t opTensorOp,
+ cudnnDataType_t opTensorCompType,
+ cudnnNanPropagation_t opTensorNanOpt );
+
+cudnnStatus_t cudnnGetOpTensorDescriptor(
+ const cudnnOpTensorDescriptor_t opTensorDesc,
+ cudnnOpTensorOp_t *opTensorOp,
+ cudnnDataType_t *opTensorCompType,
+ cudnnNanPropagation_t *opTensorNanOpt );
+
+cudnnStatus_t cudnnDestroyOpTensorDescriptor(
+ cudnnOpTensorDescriptor_t opTensorDesc );
+
+/* Tensor Bias operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
+cudnnStatus_t cudnnOpTensor(
cudnnHandle_t handle,
- const void *alpha,
+ const cudnnOpTensorDescriptor_t opTensorDesc,
+ const void *alpha1,
+ const cudnnTensorDescriptor_t aDesc,
+ const void *A,
+ const void *alpha2,
const cudnnTensorDescriptor_t bDesc,
- const void *b,
+ const void *B,
const void *beta,
- cudnnTensorDescriptor_t yDesc,
- void *y );
+ const cudnnTensorDescriptor_t cDesc,
+ void *C );
/* Set all values of a tensor to a given value : y[i] = value[0] */
-cudnnStatus_t cudnnSetTensor(
+cudnnStatus_t cudnnSetTensor(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t yDesc,
void *y,
const void *valuePtr );
/* Scale all values of a tensor by a given factor : y[i] = alpha * y[i] */
-cudnnStatus_t cudnnScaleTensor(
+cudnnStatus_t cudnnScaleTensor(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t yDesc,
void *y,
@@ -224,53 +248,33 @@ typedef enum
/* Create an instance of FilterStruct */
-cudnnStatus_t cudnnCreateFilterDescriptor(
+cudnnStatus_t cudnnCreateFilterDescriptor(
cudnnFilterDescriptor_t *filterDesc );
-cudnnStatus_t cudnnSetFilter4dDescriptor(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type */
- int k, /* number of output feature maps */
- int c, /* number of input feature maps */
- int h, /* height of each input filter */
- int w ); /* width of each input fitler */
-cudnnStatus_t cudnnSetFilter4dDescriptor_v4(
+cudnnStatus_t cudnnSetFilter4dDescriptor(
cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type */
+ cudnnDataType_t dataType, /* image data type*/
cudnnTensorFormat_t format,
- int k, /* number of output feature maps */
- int c, /* number of input feature maps */
- int h, /* height of each input filter */
- int w ); /* width of each input fitler */
+ int k, /* number of output feature maps*/
+ int c, /* number of input feature maps*/
+ int h, /* height of each input filter*/
+ int w ); /* width of each input filter*/
-cudnnStatus_t cudnnGetFilter4dDescriptor(
- const cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t *dataType, /* image data type */
- int *k, /* number of output feature maps */
- int *c, /* number of input feature maps */
- int *h, /* height of each input filter */
- int *w ); /* width of each input fitler */
-cudnnStatus_t cudnnGetFilter4dDescriptor_v4(
+cudnnStatus_t cudnnGetFilter4dDescriptor(
const cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t *dataType, /* image data type */
+ cudnnDataType_t *dataType, /* image data type*/
cudnnTensorFormat_t *format,
- int *k, /* number of output feature maps */
- int *c, /* number of input feature maps */
- int *h, /* height of each input filter */
- int *w ); /* width of each input fitler */
+ int *k, /* number of output feature maps*/
+ int *c, /* number of input feature maps*/
+ int *h, /* height of each input filter*/
+ int *w ); /* width of each input filter*/
-cudnnStatus_t cudnnSetFilterNdDescriptor(
- cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type */
- int nbDims,
- const int filterDimA[] );
-
-cudnnStatus_t cudnnSetFilterNdDescriptor_v4(
+cudnnStatus_t cudnnSetFilterNdDescriptor(
cudnnFilterDescriptor_t filterDesc,
- cudnnDataType_t dataType, /* image data type */
+ cudnnDataType_t dataType, /* image data type*/
cudnnTensorFormat_t format,
int nbDims,
const int filterDimA[] );
@@ -278,47 +282,63 @@ cudnnStatus_t cudnnSetFilterNdDescriptor_v4(
cudnnStatus_t cudnnGetFilterNdDescriptor(
const cudnnFilterDescriptor_t filterDesc,
int nbDimsRequested,
- cudnnDataType_t *dataType,
- int *nbDims,
- int filterDimA[] );
-
-cudnnStatus_t cudnnGetFilterNdDescriptor_v4(
- const cudnnFilterDescriptor_t filterDesc,
- int nbDimsRequested,
- cudnnDataType_t *dataType,
+ cudnnDataType_t *dataType, /* image data type*/
cudnnTensorFormat_t *format,
int *nbDims,
int filterDimA[] );
-cudnnStatus_t cudnnDestroyFilterDescriptor( cudnnFilterDescriptor_t filterDesc);
+
+cudnnStatus_t cudnnDestroyFilterDescriptor(
+ cudnnFilterDescriptor_t filterDesc );
/* Create an instance of convolution descriptor */
-cudnnStatus_t cudnnCreateConvolutionDescriptor(
+cudnnStatus_t cudnnCreateConvolutionDescriptor(
cudnnConvolutionDescriptor_t *convDesc );
-cudnnStatus_t cudnnSetConvolution2dDescriptor(
+cudnnStatus_t cudnnSetConvolution2dDescriptor(
cudnnConvolutionDescriptor_t convDesc,
- int pad_h, /* zero-padding height */
- int pad_w, /* zero-padding width */
- int u, /* vertical filter stride */
- int v, /* horizontal filter stride */
- int upscalex, /* upscale the input in x-direction */
- int upscaley, /* upscale the input in y-direction */
+ int pad_h, /* zero-padding height*/
+ int pad_w, /* zero-padding width*/
+ int u, /* vertical filter stride*/
+ int v, /* horizontal filter stride*/
+ int upscalex, /* upscale the input in x-direction*/
+ int upscaley, /* upscale the input in y-direction*/
cudnnConvolutionMode_t mode );
-
-cudnnStatus_t cudnnGetConvolution2dDescriptor(
+cudnnStatus_t cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc,
+ int pad_h, /* zero-padding height*/
+ int pad_w, /* zero-padding width*/
+ int u, /* vertical filter stride*/
+ int v, /* horizontal filter stride*/
+ int upscalex, /* upscale the input in x-direction*/
+ int upscaley, /* upscale the input in y-direction*/
+ cudnnConvolutionMode_t mode,
+ cudnnDataType_t dataType
+ );
+
+cudnnStatus_t cudnnGetConvolution2dDescriptor(
const cudnnConvolutionDescriptor_t convDesc,
- int *pad_h, /* zero-padding height */
- int *pad_w, /* zero-padding width */
- int *u, /* vertical filter stride */
- int *v, /* horizontal filter stride */
- int *upscalex, /* upscale the input in x-direction */
- int *upscaley, /* upscale the input in y-direction */
+ int *pad_h, /* zero-padding height*/
+ int *pad_w, /* zero-padding width*/
+ int *u, /* vertical filter stride*/
+ int *v, /* horizontal filter stride*/
+ int *upscalex, /* upscale the input in x-direction*/
+ int *upscaley, /* upscale the input in y-direction*/
cudnnConvolutionMode_t *mode );
+cudnnStatus_t cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc,
+ int* pad_h, /* zero-padding height*/
+ int* pad_w, /* zero-padding width*/
+ int* u, /* vertical filter stride*/
+ int* v, /* horizontal filter stride*/
+ int* upscalex, /* upscale the input in x-direction*/
+ int* upscaley, /* upscale the input in y-direction*/
+ cudnnConvolutionMode_t* mode,
+ cudnnDataType_t *dataType
+ );
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
-cudnnStatus_t cudnnGetConvolution2dForwardOutputDim(
+cudnnStatus_t cudnnGetConvolution2dForwardOutputDim(
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t inputTensorDesc,
const cudnnFilterDescriptor_t filterDesc,
@@ -328,16 +348,16 @@ cudnnStatus_t cudnnGetConvolution2dForwardOutputDim(
int *w );
-cudnnStatus_t cudnnSetConvolutionNdDescriptor(
+cudnnStatus_t cudnnSetConvolutionNdDescriptor(
cudnnConvolutionDescriptor_t convDesc,
int arrayLength, /* nbDims-2 size */
const int padA[],
const int filterStrideA[],
const int upscaleA[],
cudnnConvolutionMode_t mode,
- cudnnDataType_t dataType ); /* convolution data type */
+ cudnnDataType_t dataType ); /* convolution data type*/
-cudnnStatus_t cudnnGetConvolutionNdDescriptor(
+cudnnStatus_t cudnnGetConvolutionNdDescriptor(
const cudnnConvolutionDescriptor_t convDesc,
int arrayLengthRequested,
int *arrayLength,
@@ -345,36 +365,11 @@ cudnnStatus_t cudnnGetConvolutionNdDescriptor(
int strideA[],
int upscaleA[],
cudnnConvolutionMode_t *mode,
- cudnnDataType_t *dataType ); /* convolution data type */
+ cudnnDataType_t *dataType ); /* convolution data type*/
-/* cudnnSetConvolutionNdDescriptor_v3 is now mapped to cudnnSetConvolutionNdDescriptor
- and will be removed at the same time than cudnnSetConvolutionNdDescriptor_v2
- Use cudnnSetConvolutionNdDescriptor instead */
-cudnnStatus_t cudnnSetConvolutionNdDescriptor_v3(
- cudnnConvolutionDescriptor_t convDesc,
- int arrayLength, /* nbDims-2 size */
- const int padA[],
- const int filterStrideA[],
- const int upscaleA[],
- cudnnConvolutionMode_t mode,
- cudnnDataType_t dataType ); /* convolution data type */
-
-/* cudnnGetConvolutionNdDescriptor_v3 is now mapped to cudnnGetConvolutionNdDescriptor
- and will be removed at the same time thancudnnGetConvolutionNdDescriptor_v2
- Use cudnnGetConvolutionNdDescriptor instead
- */
-cudnnStatus_t cudnnGetConvolutionNdDescriptor_v3(
- const cudnnConvolutionDescriptor_t convDesc,
- int arrayLengthRequested,
- int *arrayLength,
- int padA[],
- int strideA[],
- int upscaleA[],
- cudnnConvolutionMode_t *mode,
- cudnnDataType_t *dataType ); /* convolution data type */
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
-cudnnStatus_t cudnnGetConvolutionNdForwardOutputDim(
+cudnnStatus_t cudnnGetConvolutionNdForwardOutputDim(
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t inputTensorDesc,
const cudnnFilterDescriptor_t filterDesc,
@@ -402,8 +397,8 @@ typedef enum
CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
- /* CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_BATCHED_GEMM = 100, */
- CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6
} cudnnConvolutionFwdAlgo_t;
typedef struct {
@@ -423,14 +418,30 @@ cudnnStatus_t cudnnFindConvolutionForwardAlgorithm(
int *returnedAlgoCount,
cudnnConvolutionFwdAlgoPerf_t *perfResults );
+cudnnStatus_t cudnnFindConvolutionForwardAlgorithmEx(
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
+ const cudnnFilterDescriptor_t wDesc,
+ const void *w,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t yDesc,
+ void *y,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionFwdAlgoPerf_t *perfResults,
+ void *workSpace,
+ size_t workSpaceSizeInBytes );
+
+
cudnnStatus_t cudnnGetConvolutionForwardAlgorithm(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t xDesc,
- const cudnnFilterDescriptor_t filterDesc,
+ const cudnnFilterDescriptor_t wDesc,
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t yDesc,
cudnnConvolutionFwdPreference_t preference,
- size_t memoryLimitInbytes,
+ size_t memoryLimitInBytes,
cudnnConvolutionFwdAlgo_t *algo );
/*
@@ -441,7 +452,7 @@ cudnnStatus_t cudnnGetConvolutionForwardAlgorithm(
cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t xDesc,
- const cudnnFilterDescriptor_t filterDesc,
+ const cudnnFilterDescriptor_t wDesc,
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t yDesc,
cudnnConvolutionFwdAlgo_t algo,
@@ -487,10 +498,10 @@ typedef enum
typedef enum
{
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic*/
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3 /* non-deterministic, algo0 with workspace */
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3 /* non-deterministic, algo0 with workspace*/
} cudnnConvolutionBwdFilterAlgo_t;
@@ -506,20 +517,35 @@ cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm(
const cudnnTensorDescriptor_t xDesc,
const cudnnTensorDescriptor_t dyDesc,
const cudnnConvolutionDescriptor_t convDesc,
- const cudnnFilterDescriptor_t wDesc,
+ const cudnnFilterDescriptor_t dwDesc,
const int requestedAlgoCount,
- int *returnedAlgoCount,
- cudnnConvolutionBwdFilterAlgoPerf_t*perfResults );
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults );
+
+cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithmEx(
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
+ const cudnnTensorDescriptor_t dyDesc,
+ const void *y,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnFilterDescriptor_t dwDesc,
+ void *dw,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
+ void *workSpace,
+ size_t workSpaceSizeInBytes );
cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm(
- cudnnHandle_t handle,
- const cudnnTensorDescriptor_t xDesc,
- const cudnnTensorDescriptor_t dyDesc,
- const cudnnConvolutionDescriptor_t convDesc,
- const cudnnFilterDescriptor_t wDesc,
+ cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t xDesc,
+ const cudnnTensorDescriptor_t dyDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnFilterDescriptor_t dwDesc,
cudnnConvolutionBwdFilterPreference_t preference,
- size_t memoryLimitInbytes,
- cudnnConvolutionBwdFilterAlgo_t *algo );
+ size_t memoryLimitInBytes,
+ cudnnConvolutionBwdFilterAlgo_t *algo );
/*
* convolution algorithm (which requires potentially some workspace)
@@ -550,24 +576,6 @@ cudnnStatus_t cudnnConvolutionBackwardFilter(
const cudnnFilterDescriptor_t dwDesc,
void *dw );
-/* cudnnConvolutionBackwardFilter_v3 is now mapped to cudnnConvolutionBackwardFilter
- and will be removed at the same time thancudnnConvolutionBackwardFilter_v2
- Use cudnnConvolutionBackwardFilter instead */
-cudnnStatus_t cudnnConvolutionBackwardFilter_v3(
- cudnnHandle_t handle,
- const void *alpha,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const cudnnTensorDescriptor_t dyDesc,
- const void *dy,
- const cudnnConvolutionDescriptor_t convDesc,
- cudnnConvolutionBwdFilterAlgo_t algo,
- void *workSpace,
- size_t workSpaceSizeInBytes,
- const void *beta,
- const cudnnFilterDescriptor_t dwDesc,
- void *dw );
-
/*********************************************************/
/* helper function to provide the convolution algo that fit best the requirement */
typedef enum
@@ -579,10 +587,11 @@ typedef enum
typedef enum
{
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic*/
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4
} cudnnConvolutionBwdDataAlgo_t;
typedef struct {
@@ -603,6 +612,21 @@ cudnnStatus_t cudnnFindConvolutionBackwardDataAlgorithm(
int *returnedAlgoCount,
cudnnConvolutionBwdDataAlgoPerf_t *perfResults );
+cudnnStatus_t cudnnFindConvolutionBackwardDataAlgorithmEx(
+ cudnnHandle_t handle,
+ const cudnnFilterDescriptor_t wDesc,
+ const void *w,
+ const cudnnTensorDescriptor_t dyDesc,
+ const void *dy,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t dxDesc,
+ void *dx,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
+ void *workSpace,
+ size_t workSpaceSizeInBytes );
+
cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(
cudnnHandle_t handle,
const cudnnFilterDescriptor_t wDesc,
@@ -610,7 +634,7 @@ cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t dxDesc,
cudnnConvolutionBwdDataPreference_t preference,
- size_t memoryLimitInbytes,
+ size_t memoryLimitInBytes,
cudnnConvolutionBwdDataAlgo_t *algo );
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
@@ -639,23 +663,6 @@ cudnnStatus_t cudnnConvolutionBackwardData(
const cudnnTensorDescriptor_t dxDesc,
void *dx );
-/* cudnnConvolutionBackwardData_v3 is now mapped to cudnnConvolutionBackwardData
- and will be removed at the same time thancudnnConvolutionBackwardData_v2
- Use cudnnConvolutionBackwardData instead */
-cudnnStatus_t cudnnConvolutionBackwardData_v3(
- cudnnHandle_t handle,
- const void *alpha,
- const cudnnFilterDescriptor_t wDesc,
- const void *w,
- const cudnnTensorDescriptor_t dyDesc,
- const void *dy,
- const cudnnConvolutionDescriptor_t convDesc,
- cudnnConvolutionBwdDataAlgo_t algo,
- void *workSpace,
- size_t workSpaceSizeInBytes,
- const void *beta,
- const cudnnTensorDescriptor_t dxDesc,
- void *dx );
cudnnStatus_t cudnnIm2Col(
cudnnHandle_t handle,
@@ -687,7 +694,7 @@ typedef enum
/* Function to perform forward softmax */
cudnnStatus_t cudnnSoftmaxForward(
cudnnHandle_t handle,
- cudnnSoftmaxAlgorithm_t algorithm,
+ cudnnSoftmaxAlgorithm_t algo,
cudnnSoftmaxMode_t mode,
const void *alpha,
const cudnnTensorDescriptor_t xDesc,
@@ -699,7 +706,7 @@ cudnnStatus_t cudnnSoftmaxForward(
/* Function to perform backward softmax */
cudnnStatus_t cudnnSoftmaxBackward(
cudnnHandle_t handle,
- cudnnSoftmaxAlgorithm_t algorithm,
+ cudnnSoftmaxAlgorithm_t algo,
cudnnSoftmaxMode_t mode,
const void *alpha,
const cudnnTensorDescriptor_t yDesc,
@@ -716,8 +723,8 @@ cudnnStatus_t cudnnSoftmaxBackward(
typedef enum
{
CUDNN_POOLING_MAX = 0,
- CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values */
- CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values */
+ CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values*/
+ CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values*/
CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility
} cudnnPoolingMode_t;
@@ -728,16 +735,6 @@ cudnnStatus_t cudnnCreatePoolingDescriptor(
cudnnStatus_t cudnnSetPooling2dDescriptor(
cudnnPoolingDescriptor_t poolingDesc,
cudnnPoolingMode_t mode,
- int windowHeight,
- int windowWidth,
- int verticalPadding,
- int horizontalPadding,
- int verticalStride,
- int horizontalStride );
-
-cudnnStatus_t cudnnSetPooling2dDescriptor_v4(
- cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t mode,
cudnnNanPropagation_t maxpoolingNanOpt,
int windowHeight,
int windowWidth,
@@ -749,16 +746,6 @@ cudnnStatus_t cudnnSetPooling2dDescriptor_v4(
cudnnStatus_t cudnnGetPooling2dDescriptor(
const cudnnPoolingDescriptor_t poolingDesc,
cudnnPoolingMode_t *mode,
- int *windowHeight,
- int *windowWidth,
- int *verticalPadding,
- int *horizontalPadding,
- int *verticalStride,
- int *horizontalStride );
-
-cudnnStatus_t cudnnGetPooling2dDescriptor_v4(
- const cudnnPoolingDescriptor_t poolingDesc,
- cudnnPoolingMode_t *mode,
cudnnNanPropagation_t *maxpoolingNanOpt,
int *windowHeight,
int *windowWidth,
@@ -770,14 +757,6 @@ cudnnStatus_t cudnnGetPooling2dDescriptor_v4(
cudnnStatus_t cudnnSetPoolingNdDescriptor(
cudnnPoolingDescriptor_t poolingDesc,
const cudnnPoolingMode_t mode,
- int nbDims,
- const int windowDimA[],
- const int paddingA[],
- const int strideA[] );
-
-cudnnStatus_t cudnnSetPoolingNdDescriptor_v4(
- cudnnPoolingDescriptor_t poolingDesc,
- const cudnnPoolingMode_t mode,
const cudnnNanPropagation_t maxpoolingNanOpt,
int nbDims,
const int windowDimA[],
@@ -786,15 +765,6 @@ cudnnStatus_t cudnnSetPoolingNdDescriptor_v4(
cudnnStatus_t cudnnGetPoolingNdDescriptor(
const cudnnPoolingDescriptor_t poolingDesc,
- const int nbDimsRequested,
- cudnnPoolingMode_t *mode,
- int *nbDims,
- int windowDimA[],
- int paddingA[],
- int strideA[] );
-
-cudnnStatus_t cudnnGetPoolingNdDescriptor_v4(
- const cudnnPoolingDescriptor_t poolingDesc,
int nbDimsRequested,
cudnnPoolingMode_t *mode,
cudnnNanPropagation_t *maxpoolingNanOpt,
@@ -812,10 +782,10 @@ cudnnStatus_t cudnnGetPoolingNdForwardOutputDim(
cudnnStatus_t cudnnGetPooling2dForwardOutputDim(
const cudnnPoolingDescriptor_t poolingDesc,
const cudnnTensorDescriptor_t inputTensorDesc,
- int *outN,
- int *outC,
- int *outH,
- int *outW );
+ int *n,
+ int *c,
+ int *h,
+ int *w );
/* Destroy an instance of pooling descriptor */
@@ -826,7 +796,7 @@ cudnnStatus_t cudnnDestroyPoolingDescriptor(
/* Function to perform forward pooling */
cudnnStatus_t cudnnPoolingForward(
- cudnnHandle_t handle,
+ cudnnHandle_t handle,
const cudnnPoolingDescriptor_t poolingDesc,
const void *alpha,
const cudnnTensorDescriptor_t xDesc,
@@ -883,16 +853,6 @@ cudnnStatus_t cudnnDestroyActivationDescriptor(
/* Function to perform forward activation */
cudnnStatus_t cudnnActivationForward(
cudnnHandle_t handle,
- cudnnActivationMode_t mode,
- const void *alpha,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t yDesc,
- void *y );
-
-cudnnStatus_t cudnnActivationForward_v4(
- cudnnHandle_t handle,
cudnnActivationDescriptor_t activationDesc,
const void *alpha,
const cudnnTensorDescriptor_t xDesc,
@@ -904,20 +864,6 @@ cudnnStatus_t cudnnActivationForward_v4(
/* Function to perform backward activation */
cudnnStatus_t cudnnActivationBackward(
cudnnHandle_t handle,
- cudnnActivationMode_t mode,
- const void *alpha,
- const cudnnTensorDescriptor_t yDesc,
- const void *y,
- const cudnnTensorDescriptor_t dyDesc,
- const void *dy,
- const cudnnTensorDescriptor_t xDesc,
- const void *x,
- const void *beta,
- const cudnnTensorDescriptor_t dxDesc,
- void *dx );
-
-cudnnStatus_t cudnnActivationBackward_v4(
- cudnnHandle_t handle,
cudnnActivationDescriptor_t activationDesc,
const void *alpha,
const cudnnTensorDescriptor_t yDesc,
@@ -930,37 +876,41 @@ cudnnStatus_t cudnnActivationBackward_v4(
const cudnnTensorDescriptor_t dxDesc,
void *dx );
-/* Create an instance of LRN (Local Response Normalization) descriptor */
-/* This function will set lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper */
+/*
+* Create an instance of LRN (Local Response Normalization) descriptor
+* Uses lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper
+*/
cudnnStatus_t cudnnCreateLRNDescriptor(
cudnnLRNDescriptor_t *normDesc );
-typedef enum { CUDNN_LRN_MIN_N = 1, /* minimum allowed lrnN */
+typedef enum { CUDNN_LRN_MIN_N = 1, /* minimum allowed lrnN */
CUDNN_LRN_MAX_N = 16 } /* maximum allowed lrnN */
- LRN_MinMaxFakeEnum;
+ LRN_MinMaxFakeEnum;
-/* define CUDNN_LRN_MIN_K 1e-5 -- minimum allowed lrnK */
-/* define CUDNN_LRN_MIN_BETA 0.01 -- minimum allowed lrnBeta */
+/* static const float CUDNN_LRN_MIN_K = 1e-5; */ /* minimum allowed lrnK*/
+/* static const float CUDNN_LRN_MIN_BETA = 0.01; */ /* minimum allowed lrnBeta*/
-/* LRN layer mode, currently only cross-channel is supported (across the tensor's dimA[1] dimension) */
+/* LRN layer mode */
typedef enum
{
- CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0,
+ CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0,/* Normalize across tensor's dimA[1] dimension*/
} cudnnLRNMode_t;
-/* LRN uses a window [center-lookBehind, center+lookAhead], where */
-/* lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1. */
-/* So for n=10, the window is [k-4...k...k+5] with a total of 10 samples. */
-/* Values of double parameters will be cast down to tensor data type. */
+/*
+* Uses a window [center-lookBehind, center+lookAhead], where
+* lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1.
+* Values of double parameters cast to tensor data type.
+*/
cudnnStatus_t cudnnSetLRNDescriptor(
cudnnLRNDescriptor_t normDesc,
unsigned lrnN,
double lrnAlpha,
double lrnBeta,
double lrnK );
-
-/* Retrieve the settings currently stored in an LRN layer descriptor */
-/* Any of the provided pointers can be NULL (no corresponding value will be returned) */
+/*
+* Retrieve the settings currently stored in an LRN layer descriptor
+* Any of the provided pointers can be NULL (no corresponding value will be returned)
+*/
cudnnStatus_t cudnnGetLRNDescriptor(
cudnnLRNDescriptor_t normDesc,
unsigned* lrnN,
@@ -968,13 +918,12 @@ cudnnStatus_t cudnnGetLRNDescriptor(
double* lrnBeta,
double* lrnK );
-/* Destroy an instance of LRN descriptor */
+/* Destroy an instance of LRN descriptor */
cudnnStatus_t cudnnDestroyLRNDescriptor( cudnnLRNDescriptor_t lrnDesc );
-/* LRN functions: of the form "output = alpha * normalize(x) + beta * old_y" */
+/* LRN functions: output = alpha * normalize(x) + beta * old_y */
-/* Function to perform LRN forward cross-channel computation */
-/* Values of double parameters will be cast down to tensor data type */
+/* LRN cross-channel forward computation. Double parameters cast to tensor data type */
cudnnStatus_t cudnnLRNCrossChannelForward(
cudnnHandle_t handle,
cudnnLRNDescriptor_t normDesc,
@@ -986,8 +935,7 @@ cudnnStatus_t cudnnLRNCrossChannelForward(
const cudnnTensorDescriptor_t yDesc,
void *y );
-/* Function to perform LRN cross-channel backpropagation */
-/* values of double parameters will be cast down to tensor data type */
+/* LRN cross-channel backward computation. Double parameters cast to tensor data type */
cudnnStatus_t cudnnLRNCrossChannelBackward(
cudnnHandle_t handle,
cudnnLRNDescriptor_t normDesc,
@@ -1008,16 +956,15 @@ typedef enum
CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
} cudnnDivNormMode_t;
-/* LCN/divisive normalization functions: of the form "y = alpha * normalize(x) + beta * y" */
-/* means can be NULL to reproduce Caffe's LRN within-channel behavior */
+/* LCN/divisive normalization functions: y = alpha * normalize(x) + beta * y */
cudnnStatus_t cudnnDivisiveNormalizationForward(
cudnnHandle_t handle,
cudnnLRNDescriptor_t normDesc,
cudnnDivNormMode_t mode,
const void *alpha,
- const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
+ const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2*/
const void *x,
- const void *means, /* if NULL, means are assumed to be zero */
+ const void *means, /* if NULL, means are assumed to be zero*/
void *temp,
void *temp2,
const void *beta,
@@ -1029,157 +976,114 @@ cudnnStatus_t cudnnDivisiveNormalizationBackward(
cudnnLRNDescriptor_t normDesc,
cudnnDivNormMode_t mode,
const void *alpha,
- const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2*/
const void *x,
- const void *means, /* if NULL, means are assumed to be zero */
+ const void *means, /* if NULL, means are assumed to be zero*/
const void *dy,
void *temp,
void *temp2,
const void *beta,
- const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
- void *dx, /* output x differential */
- void *dMeans ); /* output means differential, can be NULL */
+ const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans*/
+ void *dx, /* output x differential*/
+ void *dMeans ); /* output means differential, can be NULL*/
typedef enum
{
- /* Use for non-convolution layers. */
- /* bnScale, bnBias tensors dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
+ /* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice)*/
CUDNN_BATCHNORM_PER_ACTIVATION = 0,
- /* Use after convolution layers. bnScale, bnBias tensors dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
+ /*bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors)*/
CUDNN_BATCHNORM_SPATIAL = 1,
} cudnnBatchNormMode_t;
-/* CUDNN_BN_MIN_EPSILON 1e-5 -- Minimum epsilon allowed to be used in the Batch Normalization formula */
+/* static const float CUDNN_BN_MIN_EPSILON = 1e-5; */ /* Minimum epsilon allowed to be used in the Batch Normalization formula*/
-/* Derives a tensor descriptor from layer data descriptor for BatchNormalization scale, invVariance, bnBias, bnScale subtensors */
-/* Use the tensor desc produced by these functions as the bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc parameters in */
-/* Spatial and Per-activation Batch Normalization forward and backward functions. */
-/* Note - derivedBnDesc has to be first created using cudnnCreateTensorDescriptor */
-/* Note - dataDesc is the descriptor for the layer data and has to be setup with proper dimensions prior to calling these functions. */
+/*
+* Derives a tensor descriptor from layer data descriptor for BatchNormalization
+* scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
+* bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc in Batch Normalization forward and backward functions.
+*/
cudnnStatus_t cudnnDeriveBNTensorDescriptor(
cudnnTensorDescriptor_t derivedBnDesc,
const cudnnTensorDescriptor_t xDesc,
cudnnBatchNormMode_t mode );
-/* This function performs a forward pass for Batch Normalization layer. */
-/* In addition to computing y = BN(x) it accumulates the moving averages of the mean and inverse variances */
+/* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
cudnnStatus_t cudnnBatchNormalizationForwardTraining(
cudnnHandle_t handle,
cudnnBatchNormMode_t mode,
- const void *alpha, /* alpha[0] = result blend factor */
- const void *beta, /* beta[0] = dest layer blend factor */
+ const void *alpha, /* alpha[0] = result blend factor*/
+ const void *beta, /* beta[0] = dest layer blend factor*/
const cudnnTensorDescriptor_t xDesc,
- const void *x, /* NxCxHxW */
- const cudnnTensorDescriptor_t yDesc,
- void *y, /* NxCxHxW */
-
- /* Same shared desc for all the 6 tensors below in the argument list. */
- /* Note that the data type for this descriptor has to be set as follows: */
- /* type = (typeOf(x) == half) ? float : typeof(x) */
- /* The dimensions for this tensor descriptor are dependent on the normalization mode */
- /* For spatial normalization the tensors are expected to be 1D (of size C) */
- /* (in this case normalization is performed across NxHxW) */
- /* In per-activation mode the normalization is performed across N dimension only */
- /* So the tensors are expected to have dimensions of CxHxW */
+ const void *x, /* NxCxHxW*/
+ const cudnnTensorDescriptor_t yDesc,
+ void *y, /* NxCxHxW*/
+
+ /* Shared desc for the next 6 tensors in the argument list.
+ Data type to be set as follows:
+ type = (typeOf(x) == double) ? double : float
+ Dimensions for this descriptor depend on normalization mode
+ - Spatial Normalization : tensors are expected to have dims 1xCx1x1
+ (normalization is performed across NxHxW)
+ - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
+ (normalization is performed across N) */
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
- /* Note - bnScale is 'gamma' in paper's notation */
- const void *bnScale, /* Mode-dependent dims */
- /* Note - this bias parameter can effectively replace the bias in Conv and FCN layers */
- /* (Which can be set to zero for efficiency) */
- /* Note - bnBias is 'beta' in paper's notation */
- const void *bnBias, /* Mode-dependent dims */
-
- /* It is required that factor=1 is used for the very first call of a complete training cycle. */
- /* This is necessary to properly initialize the moving average. */
- /* Use a factor=1/(1+n) at N-th call to the function to get */
- /* Cumulative Moving Average (CMA) behavior */
- /* CMA[n] = (x[1]+...+x[n])/n */
- /* Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = */
- /* ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = */
- /* CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
+ /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation*/
+ const void *bnScale,
+ const void *bnBias,
+
+ /* MUST use factor=1 in the very first call of a complete training cycle.
+ Use a factor=1/(1+n) at N-th call to the function to get
+ Cumulative Moving Average (CMA) behavior
+ CMA[n] = (x[1]+...+x[n])/n
+ Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
+ ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
+ CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
double exponentialAverageFactor,
- /* runningMean = newMean*factor + runningMean*(1-factor) */
- /* if isTrainingPhase == false, these tensors will remain const */
- /* and exponentialAverageFactor parameter is not used. */
-
- /* Both of these pointers (running mean, inv variance) can be NULL but only at the same time. */
+ /* Used in Training phase only.
+ runningMean = newMean*factor + runningMean*(1-factor) */
void *resultRunningMean,
- /* The value stored here (or passed as an input in inference mode) is the moving average */
- /* of the expression 1 / sqrt( epsilon + variance[x] ) */
- void *resultRunningInvVariance,
+ /* Output in training mode, input in inference. Is the moving average
+ of variance[x] (factor is applied in the same way as for runningMean) */
+ void *resultRunningVariance,
- /* Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. */
- /* Same epsilon value should be used in forward and backward functions. */
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
double epsilon,
- /* Optional cache to save intermediate results computed during the forward pass */
- /* - these can then be reused to speed up backward pass. For this to work correctly, */
- /* the x data has to remain unchanged until the backward function is called. */
- /* Note that both of these parameters can be NULL but only at the same time. */
- /* It is recommended to use this cache since memory overhead is relatively small. */
+ /* Optionally save intermediate results from the forward pass here
+ - can be reused to speed up backward pass. NULL if unused */
void *resultSaveMean,
void *resultSaveInvVariance );
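The exponentialAverageFactor scheme described above can be checked with a few lines of plain Lua (a hedged sketch with made-up batch means, no cuDNN calls): using factor = 1/n at the n-th call reproduces the cumulative moving average of the per-batch means.

   local batchMeans = {0.2, 0.6, 0.4}          -- hypothetical per-batch means
   local runningMean = 0
   for n, newMean in ipairs(batchMeans) do
      local factor = 1 / n                     -- equals 1 on the very first call, as required
      runningMean = newMean * factor + runningMean * (1 - factor)
   end
   -- runningMean == (0.2 + 0.6 + 0.4) / 3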
-/* This function will compute a linear transform of the inputs as follows: */
-/* y[i] = bnScale[k]*(x[i]-estimatedMean[k])*estimatedInvVariance[k] + bnBias[k] */
-/* with bnScale, bnBias, runningMean, runningInvVariance tensors indexed */
-/* according to spatial or per-activation mode (please refer to the paper for details). */
-/* During inference estimatedMean and estimatedVariance are treated */
-/* as const inputs (accumulated and saved during the training phase) */
+/*
+* Performs Batch Normalization during Inference:
+* y[i] = bnScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + bnBias[k]
+* with bnScale, bnBias, runningMean, runningInvVariance tensors indexed
+* according to spatial or per-activation mode. Refer to cudnnBatchNormalizationForwardTraining
+* above for notes on function arguments.
+*/
cudnnStatus_t cudnnBatchNormalizationForwardInference(
cudnnHandle_t handle,
cudnnBatchNormMode_t mode,
-
- const void *alpha, /* alpha[0] = result blend factor */
- const void *beta, /* beta[0] = dest layer blend factor */
-
+ const void *alpha, /* alpha[0] = result blend factor*/
+ const void *beta, /* beta[0] = dest layer blend factor*/
const cudnnTensorDescriptor_t xDesc,
- const void *x, /* NxCxHxW */
- const cudnnTensorDescriptor_t yDesc,
- void *y, /* NxCxHxW */
-
- /* Same desc for all 4 tensors below */
- /* Note that the data type for this descriptor has to be set as follows: */
- /* type = (typeOf(x) == half) ? float : typeof(x) */
- /* The dimensions for this tensor descriptor are dependent on the normalization mode */
- /* For spatial normalization the tensors are expected to be 1D (of size C) */
- /* (in this case normalization is performed across NxHxW) */
- /* In per-activation mode the normalization is performed across N dimension only */
- /* So the tensors are expected to have dimensions of CxHxW */
+ const void *x, /* NxCxHxW*/
+ const cudnnTensorDescriptor_t yDesc,
+ void *y, /* NxCxHxW*/
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
-
- /* Note - bnScale is 'gamma' in paper's notation */
- const void *bnScale, /* Mode-dependent dims */
- /* Note - this bias parameter can effectively replace the bias in Conv and FCN layers */
- /* (Which can be set to zero for efficiency) */
- /* Note - bnBias is 'beta' in paper's notation */
- const void *bnBias, /* Mode-dependent dims */
-
- /* runningMean = newMean*factor + runningMean*(1-factor) */
- /* if isTrainingPhase == false, these tensors will remain const */
- /* and exponentialAverageFactor parameter is not used. */
-
- /* An estimate of the batch mean, can be accumulated over multiple calls to */
- /* batchNormalizationForwardTraining */
+ const void *bnScale,
+ const void *bnBias,
const void *estimatedMean,
- /* An estimate of the expression 1 / sqrt( epsilon + variance[x] ), */
- /* Can also be accumulated over multiple calls to batchNormalizationForwardTraining. */
- const void *estimatedInvVariance,
-
- /* Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. */
- /* Same epsilon value should be used in forward and backward functions. */
+ const void *estimatedVariance,
double epsilon );
-/* This function performs a backward pass for Batch Normalization layer. */
-/* The results are */
-/* 1. x gradient */
-/* 2. bnScale gradient */
-/* 3. bnBias gradient */
+/* Performs backward pass of Batch Normalization layer. Returns x gradient,
+* bnScale gradient and bnBias gradient */
cudnnStatus_t cudnnBatchNormalizationBackward(
cudnnHandle_t handle,
cudnnBatchNormMode_t mode,
@@ -1187,87 +1091,496 @@ cudnnStatus_t cudnnBatchNormalizationBackward(
const void *betaDataDiff,
const void *alphaParamDiff,
const void *betaParamDiff,
-
- const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy*/
const void *x,
- const cudnnTensorDescriptor_t dyDesc,
+ const cudnnTensorDescriptor_t dyDesc,
const void *dy,
- const cudnnTensorDescriptor_t dxDesc,
+ const cudnnTensorDescriptor_t dxDesc,
void *dx,
-
- /* this tensor desc is used for all the 4 tensors below */
+ /* Shared tensor desc for the 4 tensors below */
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
- const void *bnScale, /* bnBias doesn't affect backpropagation */
-
- /* scale and bias diff are not backpropagated below this layer (dead-end computation DAG nodes) */
+ const void *bnScale, /* bnBias doesn't affect backpropagation*/
+ /* scale and bias diff are not backpropagated below this layer */
void *dBnScaleResult,
void *dBnBiasResult,
- /* Constant used to prevent divides by zero variance. Has to be >= CUDNN_BN_MIN_EPSILON. */
- /* Same epsilon value should be used in forward and backward functions. */
+ /* Same epsilon as forward pass */
double epsilon,
- /* Optional cache parameters containing saved intermediate results computed during the forward pass */
- /* For this to work correctly, the x data has to remain unchanged until the backward function is called. */
- /* Note that both of these parameters can be NULL but only at the same time. */
- /* It is recommended to use this cache since memory overhead is relatively small. */
+ /* Optionally cached intermediate results from
+ forward pass */
const void *savedMean,
const void *savedInvVariance );
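To show how the save/restore arguments pair up across the training forward and backward entry points, here is a hedged sketch in the cudnn.torch errcheck style; descriptor and tensor setup are elided and every name (xDesc, bnDesc, saveMean, ...) is a placeholder for this example, not part of this patch.

   local errcheck = cudnn.errcheck
   -- the forward training pass caches the per-batch statistics ...
   errcheck('cudnnBatchNormalizationForwardTraining', cudnn.getHandle(),
            'CUDNN_BATCHNORM_SPATIAL', one:data(), zero:data(),
            xDesc[0], x:data(), yDesc[0], y:data(),
            bnDesc[0], scale:data(), bias:data(),
            1.0,                                  -- factor=1 on the first call of a cycle
            runMean:data(), runVar:data(),
            1e-5, saveMean:data(), saveInvVar:data())
   -- ... and the backward pass reuses them (pass NULLs to have them recomputed instead)
   errcheck('cudnnBatchNormalizationBackward', cudnn.getHandle(),
            'CUDNN_BATCHNORM_SPATIAL',
            one:data(), zero:data(), one:data(), zero:data(),
            xDesc[0], x:data(), dyDesc[0], dy:data(), dxDesc[0], dx:data(),
            bnDesc[0], scale:data(), dScale:data(), dBias:data(),
            1e-5, saveMean:data(), saveInvVar:data())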
-/* DEPRECATED API THAT WILL BE REMOVED SOON */
-cudnnStatus_t cudnnSetConvolutionNdDescriptor_v2(
- cudnnConvolutionDescriptor_t convDesc,
- int arrayLength, /* nbDims-2 size */
- const int padA[],
- const int filterStrideA[],
- const int upscaleA[],
- cudnnConvolutionMode_t mode );
-cudnnStatus_t cudnnGetConvolutionNdDescriptor_v2(
- const cudnnConvolutionDescriptor_t convDesc,
- int arrayLengthRequested,
- int *arrayLength,
- int padA[],
- int strideA[],
- int upscaleA[],
- cudnnConvolutionMode_t *mode );
+/* APIs for spatial transformer network*/
+typedef enum {
+ CUDNN_SAMPLER_BILINEAR=0,
+} cudnnSamplerType_t;
+
+cudnnStatus_t cudnnCreateSpatialTransformerDescriptor(
+        cudnnSpatialTransformerDescriptor_t *stDesc);
+
+cudnnStatus_t cudnnSetSpatialTransformerNdDescriptor(
+ cudnnSpatialTransformerDescriptor_t stDesc,
+ cudnnSamplerType_t samplerType,
+ cudnnDataType_t dataType,
+ const int nbDims,
+ const int dimA[]);
+
+cudnnStatus_t cudnnDestroySpatialTransformerDescriptor(
+ cudnnSpatialTransformerDescriptor_t stDesc);
+
+cudnnStatus_t cudnnSpatialTfGridGeneratorForward(
+ cudnnHandle_t handle,
+ const cudnnSpatialTransformerDescriptor_t stDesc,
+ const void *theta,
+ void *grid);
+
+cudnnStatus_t cudnnSpatialTfGridGeneratorBackward(
+ cudnnHandle_t handle,
+ const cudnnSpatialTransformerDescriptor_t stDesc,
+ const void *dgrid,
+ void *dtheta);
+
+cudnnStatus_t cudnnSpatialTfSamplerForward(
+ cudnnHandle_t handle,
+ cudnnSpatialTransformerDescriptor_t stDesc,
+ const void *alpha,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
+ const void *grid,
+ const void *beta,
+ cudnnTensorDescriptor_t yDesc,
+ void *y);
+
+cudnnStatus_t cudnnSpatialTfSamplerBackward(
+ cudnnHandle_t handle,
+ cudnnSpatialTransformerDescriptor_t stDesc,
+ const void *alpha,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
+ const void *beta,
+ const cudnnTensorDescriptor_t dxDesc,
+ void *dx,
+ const void *alphaDgrid,
+ const cudnnTensorDescriptor_t dyDesc,
+ const void *dy,
+ const void *grid,
+ const void *betaDgrid,
+ void *dgrid);
+
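A rough sketch of how these spatial transformer entry points chain together, in the errcheck/ffi style used elsewhere in these bindings; the theta/grid/x/y buffers, their descriptors and the N, C, H, W sizes are assumed to exist, and the ffi type name used for the descriptor array is an assumption based on the typedef, not shown in this hunk:

   local ffi = require 'ffi'
   local errcheck = cudnn.errcheck

   -- create and configure a descriptor for an N x C x H x W sampling grid
   local stDesc = ffi.new('cudnnSpatialTransformerDescriptor_t[1]')   -- assumes this typedef is declared above
   errcheck('cudnnCreateSpatialTransformerDescriptor', stDesc)
   local dims = torch.IntTensor({N, C, H, W})                         -- placeholder sizes
   errcheck('cudnnSetSpatialTransformerNdDescriptor', stDesc[0],
            'CUDNN_SAMPLER_BILINEAR', 'CUDNN_DATA_FLOAT', 4, dims:data())

   -- theta (N x 2 x 3 affine transforms) -> sampling grid -> warped output y
   errcheck('cudnnSpatialTfGridGeneratorForward', cudnn.getHandle(), stDesc[0],
            theta:data(), grid:data())
   errcheck('cudnnSpatialTfSamplerForward', cudnn.getHandle(), stDesc[0],
            one:data(), xDesc[0], x:data(), grid:data(),
            zero:data(), yDesc[0], y:data())
   errcheck('cudnnDestroySpatialTransformerDescriptor', stDesc[0])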
+typedef struct cudnnDropoutStruct * cudnnDropoutDescriptor_t;
+
+cudnnStatus_t cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t * dropoutDesc);
+
+cudnnStatus_t cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc);
+
+/*helper function to determine size of the states to be passed to cudnnSetDropoutDescriptor */
+cudnnStatus_t cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t * sizeInBytes);
+
+/*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
+cudnnStatus_t cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t * sizeInBytes);
+
+cudnnStatus_t cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
+ cudnnHandle_t handle,
+ float dropout,
+ void * states,
+ size_t stateSizeInBytes,
+ unsigned long long seed);
+
+cudnnStatus_t cudnnDropoutForward(cudnnHandle_t handle,
+ const cudnnDropoutDescriptor_t dropoutDesc,
+ const cudnnTensorDescriptor_t xdesc,
+ const void * x,
+ const cudnnTensorDescriptor_t ydesc,
+ void * y,
+ void * reserveSpace,
+ size_t reserveSpaceSizeInBytes);
+
+cudnnStatus_t cudnnDropoutBackward(cudnnHandle_t handle,
+ const cudnnDropoutDescriptor_t dropoutDesc,
+ const cudnnTensorDescriptor_t dydesc,
+ const void * dy,
+ const cudnnTensorDescriptor_t dxdesc,
+ void * dx,
+ void * reserveSpace,
+ size_t reserveSpaceSizeInBytes);
+
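For orientation, a hedged sketch of the intended call order for the dropout API above (same errcheck/ffi conventions; xDesc and the float CudaTensor buffers used to hold the byte counts are placeholders chosen for this example):

   local ffi = require 'ffi'
   local errcheck = cudnn.errcheck

   local dropDesc = ffi.new('cudnnDropoutDescriptor_t[1]')
   errcheck('cudnnCreateDropoutDescriptor', dropDesc)

   -- query how much RNG state is needed, then hand cuDNN a GPU buffer of that size
   local stateSize = ffi.new('size_t[1]')
   errcheck('cudnnDropoutGetStatesSize', cudnn.getHandle(), stateSize)
   local states = torch.CudaTensor(math.ceil(tonumber(stateSize[0]) / 4))
   errcheck('cudnnSetDropoutDescriptor', dropDesc[0], cudnn.getHandle(),
            0.5, states:data(), stateSize[0], 12345ULL)   -- dropout p = 0.5, fixed seed

   -- reserve space is sized from the input descriptor and must be passed,
   -- unchanged, to both cudnnDropoutForward and cudnnDropoutBackward
   local reserveSize = ffi.new('size_t[1]')
   errcheck('cudnnDropoutGetReserveSpaceSize', xDesc[0], reserveSize)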
+/* RNN API */
+typedef enum
+ {
+ CUDNN_RNN_RELU = 0, /* Stock RNN with ReLu activation*/
+ CUDNN_RNN_TANH = 1, /* Stock RNN with tanh activation*/
+ CUDNN_LSTM = 2, /* LSTM with no peephole connections*/
+ CUDNN_GRU = 3 /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1);*/
+ } cudnnRNNMode_t;
+
+typedef enum
+ {
+ CUDNN_UNIDIRECTIONAL = 0,
+     CUDNN_BIDIRECTIONAL = 1      /* Using output concatenation at each step. Do we also want to support output sum?*/
+ } cudnnDirectionMode_t;
-cudnnStatus_t cudnnAddTensor_v2(
+typedef enum
+ {
+ CUDNN_LINEAR_INPUT = 0,
+ CUDNN_SKIP_INPUT = 1
+ } cudnnRNNInputMode_t;
+
+
+struct cudnnRNNStruct;
+typedef struct cudnnRNNStruct* cudnnRNNDescriptor_t;
+
+cudnnStatus_t cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc);
+cudnnStatus_t cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
+
+cudnnStatus_t cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc,
+ int hiddenSize,
+ int seqLength,
+ int numLayers,
+ cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps.*/
+ cudnnRNNInputMode_t inputMode,
+ cudnnDirectionMode_t direction,
+ cudnnRNNMode_t mode,
+ cudnnDataType_t dataType);
+
+
+/* dataType in the RNN descriptor is used to determine math precision*/
+/* dataType in weight descriptors and input descriptors is used to describe storage*/
+
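A minimal sketch of configuring an RNN descriptor with these declarations (hypothetical sizes; dropDesc is assumed to be a dropout descriptor set up as sketched in the dropout section above):

   local ffi = require 'ffi'
   local errcheck = cudnn.errcheck

   local rnnDesc = ffi.new('cudnnRNNDescriptor_t[1]')
   errcheck('cudnnCreateRNNDescriptor', rnnDesc)
   -- hiddenSize=512, seqLength=32, numLayers=2, FP32 math precision
   errcheck('cudnnSetRNNDescriptor', rnnDesc[0], 512, 32, 2, dropDesc[0],
            'CUDNN_LINEAR_INPUT', 'CUDNN_UNIDIRECTIONAL', 'CUDNN_LSTM',
            'CUDNN_DATA_FLOAT')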
+cudnnStatus_t cudnnGetRNNWorkspaceSize( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t *xDesc,
+ size_t *sizeInBytes
+ );
+
+cudnnStatus_t cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t *xDesc,
+ size_t *sizeInBytes
+ );
+
+
+cudnnStatus_t cudnnGetRNNParamsSize( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t *xDesc,
+ size_t *sizeInBytes
+ );
+
+cudnnStatus_t cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int layer,
+ const cudnnTensorDescriptor_t * xDesc,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const int linLayerID,
+ cudnnFilterDescriptor_t linLayerMatDesc,
+ void ** linLayerMat
+ );
+
+cudnnStatus_t cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const int layer,
+ const cudnnTensorDescriptor_t * xDesc,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const int linLayerID,
+ cudnnFilterDescriptor_t linLayerBiasDesc,
+ void ** linLayerBias
+ );
+
+
+cudnnStatus_t cudnnRNNForwardInference( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t * xDesc,
+ const void * x,
+ const cudnnTensorDescriptor_t hxDesc,
+ const void * hx,
+ const cudnnTensorDescriptor_t cxDesc,
+ const void * cx,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const cudnnTensorDescriptor_t *yDesc,
+ void * y,
+ const cudnnTensorDescriptor_t hyDesc,
+ void * hy,
+ const cudnnTensorDescriptor_t cyDesc,
+ void * cy,
+ void * workspace,
+ size_t workSpaceSizeInBytes);
+
+
+
+cudnnStatus_t cudnnRNNForwardTraining( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t *xDesc,
+ const void * x,
+ const cudnnTensorDescriptor_t hxDesc,
+ const void * hx,
+ const cudnnTensorDescriptor_t cxDesc,
+ const void * cx,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const cudnnTensorDescriptor_t *yDesc,
+ void * y,
+ const cudnnTensorDescriptor_t hyDesc,
+ void * hy,
+ const cudnnTensorDescriptor_t cyDesc,
+ void * cy,
+ void * workspace,
+ size_t workSpaceSizeInBytes,
+ void * reserveSpace,
+ size_t reserveSpaceSizeInBytes);
+
+cudnnStatus_t cudnnRNNBackwardData( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t * yDesc,
+ const void * y,
+ const cudnnTensorDescriptor_t * dyDesc,
+ const void * dy,
+ const cudnnTensorDescriptor_t dhyDesc,
+ const void * dhy,
+ const cudnnTensorDescriptor_t dcyDesc,
+ const void * dcy,
+ const cudnnFilterDescriptor_t wDesc,
+ const void * w,
+ const cudnnTensorDescriptor_t hxDesc,
+ const void * hx,
+ const cudnnTensorDescriptor_t cxDesc,
+ const void * cx,
+ const cudnnTensorDescriptor_t * dxDesc,
+ void * dx,
+ const cudnnTensorDescriptor_t dhxDesc,
+ void * dhx,
+ const cudnnTensorDescriptor_t dcxDesc,
+ void * dcx,
+ void * workspace,
+ size_t workSpaceSizeInBytes,
+ const void * reserveSpace,
+ size_t reserveSpaceSizeInBytes );
+
+
+cudnnStatus_t cudnnRNNBackwardWeights( cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnnDesc,
+ const cudnnTensorDescriptor_t * xDesc,
+ const void * x,
+ const cudnnTensorDescriptor_t hxDesc,
+ const void * hx,
+ const cudnnTensorDescriptor_t * yDesc,
+ const void * y,
+ const void * workspace,
+ size_t workSpaceSizeInBytes,
+ const cudnnFilterDescriptor_t dwDesc,
+ void * dw,
+ const void * reserveSpace,
+ size_t reserveSpaceSizeInBytes );
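Tying the RNN entry points together, a hedged sketch of the sizing calls that precede training (xDescs stands in for an array of per-timestep tensor descriptors and is a placeholder; the float-element CudaTensor buffers are one possible way to hold the byte counts returned by cuDNN):

   local ffi = require 'ffi'
   local errcheck = cudnn.errcheck

   local workSize, reserveSize = ffi.new('size_t[1]'), ffi.new('size_t[1]')
   errcheck('cudnnGetRNNWorkspaceSize',       cudnn.getHandle(), rnnDesc[0], xDescs, workSize)
   errcheck('cudnnGetRNNTrainingReserveSize', cudnn.getHandle(), rnnDesc[0], xDescs, reserveSize)

   local workspace    = torch.CudaTensor(math.ceil(tonumber(workSize[0]) / 4))
   local reserveSpace = torch.CudaTensor(math.ceil(tonumber(reserveSize[0]) / 4))
   -- cudnnRNNForwardTraining fills reserveSpace; the same buffer, unmodified,
   -- must then be passed to cudnnRNNBackwardData and cudnnRNNBackwardWeights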
+
+
+
+/* DEPRECATED routines to be removed in the next release:
+   Users should use the non-suffixed versions (which have the API and functionality of the _v4 versions).
+   Routines with the _v3 suffix have the functionality of the non-suffixed routines in cuDNN V4.
+ */
+
+cudnnStatus_t cudnnSetFilter4dDescriptor_v3(
+ cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t dataType, /* image data type*/
+ int k, /* number of output feature maps*/
+ int c, /* number of input feature maps*/
+ int h, /* height of each input filter*/
+ int w ); /* width of each input filter*/
+
+cudnnStatus_t cudnnSetFilter4dDescriptor_v4(
+ cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t dataType, /* image data type*/
+ cudnnTensorFormat_t format,
+ int k, /* number of output feature maps*/
+ int c, /* number of input feature maps*/
+ int h, /* height of each input filter*/
+ int w ); /* width of each input filter*/
+
+cudnnStatus_t cudnnGetFilter4dDescriptor_v3(
+ const cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t *dataType, /* image data type*/
+ int *k, /* number of output feature maps*/
+ int *c, /* number of input feature maps*/
+ int *h, /* height of each input filter*/
+ int *w ); /* width of each input filter*/
+
+cudnnStatus_t cudnnGetFilter4dDescriptor_v4(
+ const cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t *dataType, /* image data type*/
+ cudnnTensorFormat_t *format,
+ int *k, /* number of output feature maps*/
+ int *c, /* number of input feature maps*/
+ int *h, /* height of each input filter*/
+ int *w ); /* width of each input filter */
+
+cudnnStatus_t cudnnSetFilterNdDescriptor_v3(
+ cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t dataType, /* image data type*/
+ int nbDims,
+ const int filterDimA[] );
+
+
+cudnnStatus_t cudnnSetFilterNdDescriptor_v4(
+ cudnnFilterDescriptor_t filterDesc,
+ cudnnDataType_t dataType, /* image data type*/
+ cudnnTensorFormat_t format,
+ int nbDims,
+ const int filterDimA[] );
+
+cudnnStatus_t cudnnGetFilterNdDescriptor_v3(
+ const cudnnFilterDescriptor_t filterDesc,
+ int nbDimsRequested,
+ cudnnDataType_t *dataType, /* image data type*/
+ int *nbDims,
+ int filterDimA[] );
+
+cudnnStatus_t cudnnGetFilterNdDescriptor_v4(
+ const cudnnFilterDescriptor_t filterDesc,
+ int nbDimsRequested,
+ cudnnDataType_t *dataType, /* image data type*/
+ cudnnTensorFormat_t *format,
+ int *nbDims,
+ int filterDimA[] );
+
+cudnnStatus_t cudnnSetPooling2dDescriptor_v3(
+ cudnnPoolingDescriptor_t poolingDesc,
+ cudnnPoolingMode_t mode,
+ int windowHeight,
+ int windowWidth,
+ int verticalPadding,
+ int horizontalPadding,
+ int verticalStride,
+ int horizontalStride );
+
+cudnnStatus_t cudnnSetPooling2dDescriptor_v4(
+ cudnnPoolingDescriptor_t poolingDesc,
+ cudnnPoolingMode_t mode,
+ cudnnNanPropagation_t maxpoolingNanOpt,
+ int windowHeight,
+ int windowWidth,
+ int verticalPadding,
+ int horizontalPadding,
+ int verticalStride,
+ int horizontalStride );
+cudnnStatus_t cudnnGetPooling2dDescriptor_v3(
+ const cudnnPoolingDescriptor_t poolingDesc,
+ cudnnPoolingMode_t *mode,
+ int *windowHeight,
+ int *windowWidth,
+ int *verticalPadding,
+ int *horizontalPadding,
+ int *verticalStride,
+ int *horizontalStride );
+
+cudnnStatus_t cudnnGetPooling2dDescriptor_v4(
+ const cudnnPoolingDescriptor_t poolingDesc,
+ cudnnPoolingMode_t *mode,
+ cudnnNanPropagation_t *maxpoolingNanOpt,
+ int *windowHeight,
+ int *windowWidth,
+ int *verticalPadding,
+ int *horizontalPadding,
+ int *verticalStride,
+ int *horizontalStride );
+
+cudnnStatus_t cudnnSetPoolingNdDescriptor_v3(
+ cudnnPoolingDescriptor_t poolingDesc,
+ const cudnnPoolingMode_t mode,
+ int nbDims,
+ const int windowDimA[],
+ const int paddingA[],
+ const int strideA[] );
+
+cudnnStatus_t cudnnSetPoolingNdDescriptor_v4(
+ cudnnPoolingDescriptor_t poolingDesc,
+ const cudnnPoolingMode_t mode,
+ const cudnnNanPropagation_t maxpoolingNanOpt,
+ int nbDims,
+ const int windowDimA[],
+ const int paddingA[],
+ const int strideA[] );
+
+cudnnStatus_t cudnnGetPoolingNdDescriptor_v3(
+ const cudnnPoolingDescriptor_t poolingDesc,
+ const int nbDimsRequested,
+ cudnnPoolingMode_t *mode,
+ int *nbDims,
+ int windowDimA[],
+ int paddingA[],
+ int strideA[] );
+
+cudnnStatus_t cudnnGetPoolingNdDescriptor_v4(
+ const cudnnPoolingDescriptor_t poolingDesc,
+ int nbDimsRequested,
+ cudnnPoolingMode_t *mode,
+ cudnnNanPropagation_t *maxpoolingNanOpt,
+ int *nbDims,
+ int windowDimA[],
+ int paddingA[],
+ int strideA[] );
+
+cudnnStatus_t cudnnActivationForward_v3(
cudnnHandle_t handle,
- cudnnAddMode_t mode,
+ cudnnActivationMode_t mode,
const void *alpha,
- const cudnnTensorDescriptor_t bDesc,
- const void *b,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
const void *beta,
- cudnnTensorDescriptor_t yDesc,
+ const cudnnTensorDescriptor_t yDesc,
void *y );
-cudnnStatus_t cudnnConvolutionBackwardFilter_v2(
+cudnnStatus_t cudnnActivationForward_v4(
cudnnHandle_t handle,
+ cudnnActivationDescriptor_t activationDesc,
const void *alpha,
const cudnnTensorDescriptor_t xDesc,
const void *x,
+ const void *beta,
+ const cudnnTensorDescriptor_t yDesc,
+ void *y );
+
+cudnnStatus_t cudnnActivationBackward_v3(
+ cudnnHandle_t handle,
+ cudnnActivationMode_t mode,
+ const void *alpha,
+ const cudnnTensorDescriptor_t yDesc,
+ const void *y,
const cudnnTensorDescriptor_t dyDesc,
const void *dy,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
const void *beta,
- const cudnnFilterDescriptor_t dxDesc,
+ const cudnnTensorDescriptor_t dxDesc,
void *dx );
-cudnnStatus_t cudnnConvolutionBackwardData_v2(
+cudnnStatus_t cudnnActivationBackward_v4(
cudnnHandle_t handle,
+ cudnnActivationDescriptor_t activationDesc,
const void *alpha,
- const cudnnFilterDescriptor_t xDesc,
- const void *x,
+ const cudnnTensorDescriptor_t yDesc,
+ const void *y,
const cudnnTensorDescriptor_t dyDesc,
const void *dy,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t xDesc,
+ const void *x,
const void *beta,
const cudnnTensorDescriptor_t dxDesc,
void *dx );
-]]
-local libnames = {'libcudnn.so.4', 'libcudnn.4.dylib'}
+]]
+
+local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib'}
local ok = false
for i=1,#libnames do
ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end)
@@ -1275,15 +1588,16 @@ for i=1,#libnames do
end
if not ok then
- error([['libcudnn (R4) not found in library path.
+ print(err)
+ error([['libcudnn (R5) not found in library path.
Please install CuDNN from https://developer.nvidia.com/cuDNN
-Then make sure files named as libcudnn.so.4 or libcudnn.4.dylib are placed in your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
+Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
]])
end
cudnn.version = tonumber(cudnn.C.cudnnGetVersion())
-if cudnn.version < 4005 then
- error('These bindings are for version 4005 or above, '
+if cudnn.version < 5002 then
+ error('These bindings are for version 5002 or above, '
.. 'while the loaded CuDNN is version: ' .. cudnn.version
.. ' \nAre you using an older version of CuDNN?')
end
diff --git a/functional.lua b/functional.lua
index 04db746..4564fb7 100644
--- a/functional.lua
+++ b/functional.lua
@@ -60,7 +60,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou
local nOutputPlane, nInputPlane, kH, kW
= weight:size(1), weight:size(2), weight:size(3), weight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -73,7 +73,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou
local pad = torch.IntTensor({padH, padW})
local stride = torch.IntTensor({strideH, strideW})
local upscale = torch.IntTensor({1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -139,7 +139,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
local nOutputPlane, nInputPlane, kH, kW
= weight:size(1), weight:size(2), weight:size(3), weight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -152,7 +152,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
local pad = torch.IntTensor({padH, padW})
local stride = torch.IntTensor({strideH, strideW})
local upscale = torch.IntTensor({1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -175,7 +175,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
algSearchMode, 0, algType)
-- do convolution
- errcheck('cudnnConvolutionBackwardData_v3', handle,
+ errcheck('cudnnConvolutionBackwardData', handle,
one:data(),
weightDesc[0], weight:data(),
oDesc[0], gradOutput:data(),
@@ -204,7 +204,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
local nOutputPlane, nInputPlane, kH, kW
= gradWeight:size(1), gradWeight:size(2), gradWeight:size(3), gradWeight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -217,7 +217,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
local pad = torch.IntTensor({padH, padW})
local stride = torch.IntTensor({strideH, strideW})
local upscale = torch.IntTensor({1,1})
- errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
'CUDNN_DATA_FLOAT');
@@ -242,7 +242,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
-- do convolution
- errcheck('cudnnConvolutionBackwardFilter_v3', handle,
+ errcheck('cudnnConvolutionBackwardFilter', handle,
scaleT:data(),
iDesc[0], input:data(),
oDesc[0], gradOutput:data(),
@@ -284,7 +284,7 @@ cudnn.functional.Pooling_updateOutput = function(handle, mode, input, output,
local ker = torch.IntTensor({kH, kW})
local str = torch.IntTensor({dH, dW})
local pad = torch.IntTensor({padH, padW})
- errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 2,
+ errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 2,
ker:data(), pad:data(), str:data());
local function destroyPoolDesc(d)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
@@ -347,7 +347,7 @@ cudnn.functional.Pooling_updateGradInput = function(handle, mode, input, output,
local ker = torch.IntTensor({kH, kW})
local str = torch.IntTensor({dH, dW})
local pad = torch.IntTensor({padH, padW})
- errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 2,
+ errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 2,
ker:data(), pad:data(), str:data());
local function destroyPoolDesc(d)
errcheck('cudnnDestroyPoolingDescriptor', d[0]);
diff --git a/test/benchmark.lua b/test/benchmark.lua
index 4372502..553e918 100644
--- a/test/benchmark.lua
+++ b/test/benchmark.lua
@@ -1,10 +1,11 @@
require 'cudnn'
require 'torch'
-function bench(title, nInputC, nOutputC, kH, kW, sH, sW, iH, iW, nBatch, ...)
+function benchSpatial(title, nInputC, nOutputC, kH, kW, sH, sW, iH, iW, nBatch, ...)
local m1 = cudnn.SpatialConvolution(nInputC,nOutputC,kW,kH, sW, sH):setMode(...):fastest():cuda()
local i1 = torch.zeros(nBatch, nInputC, iH, iW):cuda()
local o1 = m1:forward(i1)
+ cutorch.synchronize()
local t1 = torch.Timer()
local o1 = m1:forward(i1)
@@ -27,47 +28,72 @@ iH = (outH-1)*sH+kH
print('CUDNN Version: ', tonumber(cudnn.C.cudnnGetVersion()))
+print("cudnn.SpatialConvolution")
-- just auto-tuned by cudnn with CUDNN_CONVOLUTION_FWD_PREFER_FASTEST mode
-bench('Forward AutoTuned ', from, to, kH, kW, sH, sW, iH, iW, batchSize)
-
-bench('Forward implicit gemm ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
- 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM',
- 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
- 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
-
-bench('Forward implicit precomp gemm', from, to, kH, kW, sH, sW, iH, iW, batchSize,
- 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM',
- 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
- 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
-
-bench('Forward gemm ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
- 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM',
- 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
- 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+for i, mode_desc in ipairs({
+ {'Forward AutoTuned ', nil},
+ {'Forward implicit gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'},
+ {'Forward implicit precomp gemm', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'},
+ {'Forward gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'},
+ {'Forward FFT ', 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'},
+ {'Forward FFT tiling ', 'CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING'},
+-- {'Forward Winograd ', 'CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD'} -- not supported for this size
+}) do
+ local title = mode_desc[1]
+ local mode = mode_desc[2]
+
+ benchSpatial(title, from, to, kH, kW, sH, sW, iH, iW, batchSize, mode)
+end
-bench('Forward FFT ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
- 'CUDNN_CONVOLUTION_FWD_ALGO_FFT',
- 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
- 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
+function benchVolumetric(title, nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, kT_input, kW_input, kH_input, nBatch, ...)
+ local gconv = cudnn.VolumetricConvolution(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH):setMode(...):fastest():cuda()
+ local input = torch.zeros(nBatch, nInputPlane, kT_input, kW_input, kH_input):cuda()
+ local output = gconv:forward(input)
+ cutorch.synchronize()
+ local t1 = torch.Timer()
+ local output = gconv:forward(input)
+ cutorch.synchronize()
+ print(title .. ': ', nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH, kT_input, kW_input, kH_input, nBatch, t1:time().real)
+end
+print("cudnn.VolumetricConvolution")
+
+for i, mode_desc in ipairs({
+ {'Forward AutoTuned ', nil},
+ {'Forward implicit gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'},
+ {'Forward implicit precomp gemm', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'},
+-- {'Forward gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'}, -- not supported for this size
+-- {'Forward FFT ', 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'}, -- not supported for this size
+ {'Forward FFT tiling ', 'CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING'},
+-- {'Forward Winograd ', 'CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD'} -- not supported for this size
+}) do
+ local title = mode_desc[1]
+ local mode = mode_desc[2]
+
+ benchVolumetric(title, 256, 256, 3,3,3, 1,1,1, 1,1,1, 8, 28, 28, 50, mode)
+end
-- For reference, CuDNN Convolution modes
--[[
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
- CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3, // Placeholder
- CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, // non-deterministic
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2
-
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, // non-deterministic
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
-
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3 // non-deterministic, algo0 with workspace
+
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, // non-deterministic
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4
]]--