Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2015-03-27 22:12:09 +0300
committersoumith <soumith@gmail.com>2015-03-27 22:12:09 +0300
commitd35848abf687fa813af375b97278578e1b3c578b (patch)
tree6b62e0aded9032064d8c5ed69d8da713cbacfaa6
parent3f62241aeb1459d3c27513a978dfbc097928bd57 (diff)
choosing reasonable defaults for buffer-limits in convolution module, with optional fastest mode
-rw-r--r--SpatialConvolution.lua11
1 files changed, 10 insertions, 1 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 5ec6f8a..23775fa 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -46,6 +46,12 @@ function SpatialConvolution:resetWeightDescriptors()
self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1)[bias_slice])
end
+function SpatialConvolution:fastest(mode)
+ mode = mode or true
+ self.fastest_mode = mode
+ return self
+end
+
function SpatialConvolution:createIODescriptors(input)
local batch = true
if input:dim() == 3 then
@@ -95,11 +101,14 @@ function SpatialConvolution:createIODescriptors(input)
-- create forwardAlgorithm descriptors for
local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
+ local algWorkspaceLimit = self.nInputPlane * self.kH * self.kW * 4 -- 4 = sizeof int.
+ if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end
errcheck('cudnnGetConvolutionForwardAlgorithm',
cudnn.handle[cutorch.getDevice()-1],
self.iDesc[0], self.weightDesc[0],
self.convDesc[0], self.oDesc[0],
- 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType)
+ algSearchMode, algWorkspaceLimit, algType)
self.algType = algType
local bufSize = torch.LongTensor(1)
errcheck('cudnnGetConvolutionForwardWorkspaceSize',