Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author soumith <soumith@gmail.com> 2015-08-23 06:12:10 +0300
committer soumith <soumith@gmail.com> 2015-08-23 06:12:10 +0300
commit b3d74f5b882312863f11025fc38e0ca2b0e8f478 (patch)
tree bbe7dfbef9ea053381af81d6cd0cf292ee9d4e86
parent 4c2a0a568a232956f0cbe3c89c5e889df0f0ed94 (diff)
flag to enable or not to enable auto-tuner
-rw-r--r--  SpatialConvolution.lua  162
-rw-r--r--  ffi.lua                  75
-rw-r--r--  test/benchmark.lua        5
-rw-r--r--  test/test.lua             5
4 files changed, 175 insertions, 72 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 0c1f311..3442d89 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -35,8 +35,8 @@ function SpatialConvolution:resetWeightDescriptors()
self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
local desc = torch.IntTensor({self.nOutputPlane/self.groups,
- self.nInputPlane/self.groups,
- self.kH, self.kW})
+ self.nInputPlane/self.groups,
+ self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
'CUDNN_DATA_FLOAT', 4,
desc:data());
@@ -86,6 +86,7 @@ function SpatialConvolution:createIODescriptors(input)
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
self.iSize = input:size()
+
-- resize gradInput
if self.gradInput then self.gradInput:resizeAs(input); end
assert(self.nInputPlane == input:size(2), 'input has to contain: '
@@ -93,9 +94,11 @@ function SpatialConvolution:createIODescriptors(input)
.. ' feature maps, but received input of size: '
.. input:size(1) .. ' x ' .. input:size(2) ..
' x ' .. input:size(3) .. ' x ' .. input:size(4))
+
-- create input descriptor
local input_slice = {{},{1,self.nInputPlane/self.groups},{},{}}
self.iDesc = cudnn.toDescriptor(input[input_slice])
+
-- create conv descriptor
self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
@@ -111,7 +114,7 @@ function SpatialConvolution:createIODescriptors(input)
end
ffi.gc(self.convDesc, destroyConvDesc)
- -- create output descriptor and resize output
+ -- get output shape, resize output
local oSize = torch.IntTensor(4)
local oSizeD = oSize:data()
errcheck('cudnnGetConvolutionNdForwardOutputDim',
@@ -119,6 +122,7 @@ function SpatialConvolution:createIODescriptors(input)
self.weightDesc[0], 4, oSizeD)
oSize[2] = oSize[2] * self.groups
self.output:resize(oSize:long():storage())
+
-- create descriptor for output
local output_slice = {{},{1,self.nOutputPlane/self.groups},{},{}}
self.oDesc = cudnn.toDescriptor(self.output[output_slice])
@@ -126,17 +130,34 @@ function SpatialConvolution:createIODescriptors(input)
-----------------------------------------------------------------------
local maxBufSize = 0
- -- create forwardAlgorithm descriptors for
+
+ -- create forwardAlgorithm descriptors
local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
local algWorkspaceLimit = self.workspace_limit
or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
+
if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end
- errcheck('cudnnGetConvolutionForwardAlgorithm',
- cudnn.getHandle(),
- self.iDesc[0], self.weightDesc[0],
- self.convDesc[0], self.oDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ local perfResults = ffi.new("cudnnConvolutionFwdAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ if cudnn.verbose then
+ print('AutoTuning:', perfResults[0].time,
+ tonumber(perfResults[0].memory), tonumber(perfResults[0].algo))
+ end
+ else
+ errcheck('cudnnGetConvolutionForwardAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.weightDesc[0],
+ self.convDesc[0], self.oDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
algType[0] = self.fmode or algType[0]
self.fwdAlgType = algType
local bufSize = torch.LongTensor(1)
@@ -147,17 +168,35 @@ function SpatialConvolution:createIODescriptors(input)
algType[0], bufSize:data())
maxBufSize = math.max(maxBufSize, bufSize[1])
- -- create backwardFilterAlgorithm descriptors for
+ -- create backwardFilterAlgorithm descriptors
local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
local algWorkspaceLimit = self.workspace_limit
or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
- if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST' end
- errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
- cudnn.getHandle(),
- self.iDesc[0], self.oDesc[0],
- self.convDesc[0], self.weightDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
+ if self.fastest_mode then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
+ end
+
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ local perfResults = ffi.new("cudnnConvolutionBwdFilterAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.oDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ if cudnn.verbose then
+ print('AutoTuning:', perfResults[0].time,
+ tonumber(perfResults[0].memory), tonumber(perfResults[0].algo))
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+ cudnn.getHandle(),
+ self.iDesc[0], self.oDesc[0],
+ self.convDesc[0], self.weightDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
algType[0] = self.bwmode or algType[0]
self.bwdFilterAlgType = algType
local bufSize = torch.LongTensor(1)
@@ -168,17 +207,34 @@ function SpatialConvolution:createIODescriptors(input)
algType[0], bufSize:data())
maxBufSize = math.max(maxBufSize, bufSize[1])
- -- create backwardDataAlgorithm descriptors for
+ -- create backwardDataAlgorithm descriptors
local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
local algWorkspaceLimit = self.workspace_limit
or (self.nInputPlane * self.kH * self.kW * 4) -- 4 = sizeof int/float.
- if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST' end
- errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
- cudnn.getHandle(),
- self.weightDesc[0], self.oDesc[0],
- self.convDesc[0], self.iDesc[0],
- algSearchMode, algWorkspaceLimit, algType)
+ if self.fastest_mode then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
+ end
+ if cudnn.benchmark then -- the manual auto-tuner is run
+ local perfResults = ffi.new("cudnnConvolutionBwdDataAlgoPerf_t[?]", 1)
+ local intt = torch.IntTensor(1);
+ errcheck('cudnnFindConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.oDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ 1, intt:data(), perfResults)
+ algType[0] = perfResults[0].algo
+ if cudnn.verbose then
+ print('AutoTuning:', perfResults[0].time,
+ tonumber(perfResults[0].memory), tonumber(perfResults[0].algo))
+ end
+ else
+ errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+ cudnn.getHandle(),
+ self.weightDesc[0], self.oDesc[0],
+ self.convDesc[0], self.iDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+ end
algType[0] = self.bdmode or algType[0]
self.bwdDataAlgType = algType
local bufSize = torch.LongTensor(1)
@@ -198,10 +254,12 @@ function SpatialConvolution:createIODescriptors(input)
-----------------------------------------------------------------------
-- create offsets for groups
- self.input_offset = self.nInputPlane/self.groups*input:size(3)*input:size(4)
- self.output_offset = self.nOutputPlane/self.groups*oSize[3]*oSize[4]
- self.weight_offset =
- self.nInputPlane/self.groups*self.nOutputPlane/self.groups*self.kW*self.kH
+ local iH, iW = input:size(3), input:size(4)
+ local kH, kW = self.kH, self.kW
+ local oH, oW = oSize[3], oSize[4]
+ self.input_offset = self.nInputPlane / self.groups * iH * iW
+ self.output_offset = self.nOutputPlane / self.groups * oH * oW
+ self.weight_offset = self.nInputPlane / self.groups * self.nOutputPlane / self.groups * kH * kW
if not batch then
self.gradInput = self.gradInput:view(self.gradInput:size(2),
@@ -220,17 +278,20 @@ local zero = torch.FloatTensor({0});
function SpatialConvolution:updateOutput(input)
if not self.weightDesc then self:resetWeightDescriptors() end
self:createIODescriptors(input)
+
local prevStream
local streamQueue = {}
- if self.groups > 1 then
+ if self.groups > 1 then -- try to do stream parallelization
prevStream = cutorch.getStream()
+
--[[
Only if prevStream is 0, then do parallelization.
the justification for this is that this is a hard problem, there is no
way to know if one is doing other kinds of stream-parallelization
- (like GPUConcat), and if that's the case, streams are already
+ (like GPUConcat), and if thats the case, streams are already
being ideally exploited.
- ]]--
+ --]]
+
if prevStream == 0 then
cutorch.reserveStreams(self.groups)
for i=1,self.groups do
@@ -238,11 +299,14 @@ function SpatialConvolution:updateOutput(input)
end
end
end
- for g=0,self.groups-1 do
+
+ for g = 0, self.groups - 1 do
+ -- stream-parallelize if appropriate
if self.groups > 1 and prevStream == 0 then
- cutorch.setStream(g+1)
- table.insert(streamQueue, g+1)
+ cutorch.setStream(g + 1)
+ table.insert(streamQueue, g + 1)
end
+
errcheck('cudnnConvolutionForward', cudnn.getHandle(),
one:data(),
self.iDesc[0], input:data() + g*self.input_offset,
@@ -252,10 +316,13 @@ function SpatialConvolution:updateOutput(input)
zero:data(),
self.oDesc[0], self.output:data() + g*self.output_offset);
end
+
if prevStream == 0 then
cutorch.setStream(prevStream)
cutorch.streamWaitFor(prevStream, streamQueue)
end
+
+ -- add bias
errcheck('cudnnAddTensor', cudnn.getHandle(),
'CUDNN_ADD_SAME_C',
one:data(), self.biasDesc[0], self.bias:data(),
@@ -266,11 +333,13 @@ end
function SpatialConvolution:updateGradInput(input, gradOutput)
if not self.gradInput then return end
- assert((gradOutput:dim() == 3 or gradOutput:dim() == 4)
- and gradOutput:isContiguous());
- if not self.weightDesc then self:resetWeightDescriptors() end
+
+ assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
+ assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
self:createIODescriptors(input)
- for g=0,self.groups-1 do
+ if not self.weightDesc then self:resetWeightDescriptors() end
+
+ for g = 0,self.groups - 1 do
errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
one:data(),
self.weightDesc[0], self.weight:data() + g*self.weight_offset,
@@ -290,17 +359,20 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
self.scaleT = self.scaleT:float()
scale = scale or 1.0
self.scaleT[1] = scale
- assert((gradOutput:dim() == 3 or gradOutput:dim() == 4)
- and gradOutput:isContiguous());
+
+ assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
+ assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
self:createIODescriptors(input)
if not self.weightDesc then self:resetWeightDescriptors() end
+
-- gradBias
errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
self.scaleT:data(),
self.oDescForBias[0], gradOutput:data(),
one:data(),
self.biasDesc[0], self.gradBias:data())
- for g=0,self.groups-1 do
+
+ for g = 0, self.groups - 1 do
-- gradWeight
errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
self.scaleT:data(),
@@ -333,13 +405,3 @@ function SpatialConvolution:write(f)
end
f:writeObject(var)
end
-
---[[
-function SpatialConvolution:zeroGradParameters()
- -- gradWeight, gradBias to zero
- errcheck('cudnnSetTensor', cudnn.getHandle(),
- self.weightDesc, self.gradWeight:data(), zero:data());
- errcheck('cudnnSetTensor', cudnn.getHandle(),
- self.biasDesc, self.gradBias:data(), zero:data());
-end
-]]--
diff --git a/ffi.lua b/ffi.lua
index 43f5d9c..d749744 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -113,8 +113,8 @@ cudnnStatus_t
cudnnSetConvolutionNdDescriptor_v3( cudnnConvolutionDescriptor_t convDesc,
int arrayLength,
const int padA[],
- const int filterStrideA[],
- const int upscaleA[],
+ const int filterStrideA[],
+ const int upscaleA[],
cudnnConvolutionMode_t mode,
cudnnDataType_t dataType
);
@@ -158,7 +158,7 @@ cudnnStatus_t
cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
const cudnnTensorDescriptor_t srcDesc,
const cudnnFilterDescriptor_t filterDesc,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t destDesc,
const int requestedCount,
int *returnedCount,
@@ -172,7 +172,7 @@ cudnnStatus_t cudnnGetConvolutionForwardAlgorithm( cudnnHandle_t handle,
const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t destDesc,
cudnnConvolutionFwdPreference_t preference,
- size_t memoryLimitInbytes,
+ size_t memoryLimitInbytes,
cudnnConvolutionFwdAlgo_t *algo
);
@@ -214,8 +214,8 @@ typedef enum
{
CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE = 0,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST = 1
-} cudnnConvolutionBwdFilterPreference_t;
-
+} cudnnConvolutionBwdFilterPreference_t;
+
typedef enum
{
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, // non-deterministic
@@ -223,29 +223,47 @@ typedef enum
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2
} cudnnConvolutionBwdFilterAlgo_t;
+typedef struct {
+ cudnnConvolutionBwdFilterAlgo_t algo;
+ cudnnStatus_t status;
+ float time;
+ size_t memory;
+} cudnnConvolutionBwdFilterAlgoPerf_t;
+
+cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm( cudnnHandle_t handle,
+ const cudnnTensorDescriptor_t srcDesc,
+ const cudnnTensorDescriptor_t diffDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnFilterDescriptor_t gradDesc,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults
+ );
+
+
cudnnStatus_t
cudnnGetConvolutionBackwardFilterAlgorithm(
cudnnHandle_t handle,
const cudnnTensorDescriptor_t srcDesc,
const cudnnTensorDescriptor_t diffDesc,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
const cudnnFilterDescriptor_t gradDesc,
cudnnConvolutionBwdFilterPreference_t preference,
size_t memoryLimitInbytes,
cudnnConvolutionBwdFilterAlgo_t *algo
);
-
+
cudnnStatus_t
cudnnGetConvolutionBackwardFilterWorkspaceSize(
- cudnnHandle_t handle,
+ cudnnHandle_t handle,
const cudnnTensorDescriptor_t srcDesc,
const cudnnTensorDescriptor_t diffDesc,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
const cudnnFilterDescriptor_t gradDesc,
cudnnConvolutionBwdFilterAlgo_t algo,
size_t *sizeInBytes
);
-
+
cudnnStatus_t cudnnConvolutionBackwardFilter_v3(
cudnnHandle_t handle,
const void *alpha,
@@ -267,37 +285,54 @@ typedef enum
CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE = 0,
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST = 1
} cudnnConvolutionBwdDataPreference_t;
-
+
typedef enum
{
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, // non-deterministic
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
} cudnnConvolutionBwdDataAlgo_t;
-
+typedef struct {
+ cudnnConvolutionBwdDataAlgo_t algo;
+ cudnnStatus_t status;
+ float time;
+ size_t memory;
+} cudnnConvolutionBwdDataAlgoPerf_t;
+
+
+cudnnStatus_t cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
+ const cudnnFilterDescriptor_t filterDesc,
+ const cudnnTensorDescriptor_t diffDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnTensorDescriptor_t gradDesc,
+ const int requestedAlgoCount,
+ int *returnedAlgoCount,
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults
+ );
+
cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(
cudnnHandle_t handle,
const cudnnFilterDescriptor_t filterDesc,
const cudnnTensorDescriptor_t diffDesc,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t gradDesc,
- cudnnConvolutionBwdDataPreference_t preference,
+ cudnnConvolutionBwdDataPreference_t preference,
size_t memoryLimitInbytes,
cudnnConvolutionBwdDataAlgo_t *algo
);
cudnnStatus_t cudnnGetConvolutionBackwardDataWorkspaceSize(
- cudnnHandle_t handle,
+ cudnnHandle_t handle,
const cudnnFilterDescriptor_t filterDesc,
const cudnnTensorDescriptor_t diffDesc,
- const cudnnConvolutionDescriptor_t convDesc,
+ const cudnnConvolutionDescriptor_t convDesc,
const cudnnTensorDescriptor_t gradDesc,
cudnnConvolutionBwdDataAlgo_t algo,
size_t *sizeInBytes
- );
+ );
+
-
cudnnStatus_t cudnnConvolutionBackwardData_v3(
cudnnHandle_t handle,
const void *alpha,
diff --git a/test/benchmark.lua b/test/benchmark.lua
index 08218b9..4372502 100644
--- a/test/benchmark.lua
+++ b/test/benchmark.lua
@@ -28,6 +28,9 @@ iH = (outH-1)*sH+kH
print('CUDNN Version: ', tonumber(cudnn.C.cudnnGetVersion()))
+-- just auto-tuned by cudnn with CUDNN_CONVOLUTION_FWD_PREFER_FASTEST mode
+bench('Forward AutoTuned ', from, to, kH, kW, sH, sW, iH, iW, batchSize)
+
bench('Forward implicit gemm ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM',
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
@@ -43,8 +46,6 @@ bench('Forward gemm ', from, to, kH, kW, sH, sW, iH, iW, batchSi
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0',
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
--- just auto-tuned by cudnn with CUDNN_CONVOLUTION_FWD_PREFER_FASTEST mode
-bench('Forward AutoTuned ', from, to, kH, kW, sH, sW, iH, iW, batchSize)
bench('Forward FFT ', from, to, kH, kW, sH, sW, iH, iW, batchSize,
'CUDNN_CONVOLUTION_FWD_ALGO_FFT',
diff --git a/test/test.lua b/test/test.lua
index 5c8b31d..c2938de 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -764,6 +764,11 @@ math.randomseed(os.time())
mytester = torch.Tester()
mytester:add(cudnntest)
+if torch.random(1,2) == 1 then
+ cudnn.benchmark = true -- run manual auto-tuner
+end
+
+
for i=1,cutorch.getDeviceCount() do
print('Running test on device: ' .. i)
cutorch.setDevice(i)