Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-09-29 21:33:35 +0300
committerGitHub <noreply@github.com>2016-09-29 21:33:35 +0300
commit01e765d8ef8cadb079fb5063918b1524061b3241 (patch)
tree5171cacdc7eccd8572d2826d7d6a30e0ae64d6d8
parent6b454379d7b784c56b1d4710bcac3c0f5dbc471c (diff)
parent40a9ab188affe55e2faa2a0cf4bf226199f3c454 (diff)
Merge pull request #265 from szagoruyko/fastest-functional
cudnn.fastest in functional
-rw-r--r--functional.lua74
1 files changed, 62 insertions, 12 deletions
diff --git a/functional.lua b/functional.lua
index d721b0d..84bd7a5 100644
--- a/functional.lua
+++ b/functional.lua
@@ -103,23 +103,39 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou
-- create forwardAlgorithm descriptors for
local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
- local algWorkspaceLimit = 0
- if workspace then
- algWorkspaceLimit = workspace:nElement() * 4 -- 4 = sizeof float
+ if cudnn.fastest then
+ algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST'
end
+ local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(weight)
+
errcheck('cudnnGetConvolutionForwardAlgorithm',
handle,
iDesc[0], weightDesc[0],
convDesc[0], oDesc[0],
algSearchMode, algWorkspaceLimit, algType)
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionForwardWorkspaceSize',
+ handle,
+ iDesc[0], weightDesc[0],
+ convDesc[0], oDesc[0],
+ algType[0], bufSize:data())
+ local maxBufSize = bufSize[1]
+
+ local extraBuffer = workspace or cudnn.getSharedWorkspace()
+ local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float
+ if maxBufSize > extraBufferSizeInBytes then
+ extraBuffer:resize(math.ceil(maxBufSize / 4))
+ extraBufferSizeInBytes = maxBufSize
+ end
+
-- do convolution
errcheck('cudnnConvolutionForward', handle,
cudnn.scalar(input, 1),
iDesc[0], input:data(),
weightDesc[0], weight:data(),
convDesc[0], algType[0],
- workspace and workspace:data() or nil, algWorkspaceLimit,
+ extraBuffer:data(), extraBufferSizeInBytes,
cudnn.scalar(input, 0),
oDesc[0], output:data());
end
@@ -167,12 +183,31 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+ if cudnn.fastest then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
+ end
+ local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(weight)
errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
- cudnn.getHandle(),
+ handle,
weightDesc[0], oDesc[0],
convDesc[0], iDesc[0],
- algSearchMode, 0, algType)
+ algSearchMode, algWorkspaceLimit, algType)
+
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
+ handle,
+ weightDesc[0], oDesc[0],
+ convDesc[0], iDesc[0],
+ algType[0], bufSize:data())
+ local maxBufSize = bufSize[1]
+
+ local extraBuffer = cudnn.getSharedWorkspace()
+ local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float
+ if maxBufSize > extraBufferSizeInBytes then
+ extraBuffer:resize(math.ceil(maxBufSize / 4))
+ extraBufferSizeInBytes = maxBufSize
+ end
-- do convolution
errcheck('cudnnConvolutionBackwardData', handle,
@@ -181,11 +216,9 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
oDesc[0], gradOutput:data(),
convDesc[0],
algType[0],
- NULL, 0,
+ extraBuffer:data(), extraBufferSizeInBytes,
cudnn.scalar(input, 0),
iDesc[0], gradInput:data());
-
-
end
-- accumulates the gradients into gradWeight.
@@ -232,14 +265,31 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
- local algWorkspaceLimit = 0
+ if cudnn.fastest then
+ algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
+ end
+ local algWorkspaceLimit = nInputPlane * kH * kW * cudnn.sizeof(gradWeight)
errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
- cudnn.getHandle(),
+ handle,
iDesc[0], oDesc[0],
convDesc[0], weightDesc[0],
algSearchMode, algWorkspaceLimit, algType)
+ local bufSize = torch.LongTensor(1)
+ errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
+ handle,
+ iDesc[0], oDesc[0],
+ convDesc[0], weightDesc[0],
+ algType[0], bufSize:data())
+ local maxBufSize = bufSize[1]
+
+ local extraBuffer = cudnn.getSharedWorkspace()
+ local extraBufferSizeInBytes = extraBuffer:nElement() * 4 -- extraBuffer is always float
+ if maxBufSize > extraBufferSizeInBytes then
+ extraBuffer:resize(math.ceil(maxBufSize / 4))
+ extraBufferSizeInBytes = maxBufSize
+ end
-- do convolution
errcheck('cudnnConvolutionBackwardFilter', handle,
@@ -248,7 +298,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
oDesc[0], gradOutput:data(),
convDesc[0],
algType[0],
- NULL, 0,
+ extraBuffer:data(), extraBufferSizeInBytes,
cudnn.scalar(input, 1),
weightDesc[0], gradWeight:data());
end