diff options
author | soumith <soumith@gmail.com> | 2015-03-27 22:12:09 +0300 |
---|---|---|
committer | soumith <soumith@gmail.com> | 2015-03-27 22:12:09 +0300 |
commit | d35848abf687fa813af375b97278578e1b3c578b (patch) | |
tree | 6b62e0aded9032064d8c5ed69d8da713cbacfaa6 | |
parent | 3f62241aeb1459d3c27513a978dfbc097928bd57 (diff) |
choosing reasonable defaults for buffer-limits in convolution module, with optional fastest mode
-rw-r--r-- | SpatialConvolution.lua | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 5ec6f8a..23775fa 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -46,6 +46,12 @@ function SpatialConvolution:resetWeightDescriptors() self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1)[bias_slice]) end +function SpatialConvolution:fastest(mode) + mode = mode or true + self.fastest_mode = mode + return self +end + function SpatialConvolution:createIODescriptors(input) local batch = true if input:dim() == 3 then @@ -95,11 +101,14 @@ function SpatialConvolution:createIODescriptors(input) -- create forwardAlgorithm descriptors for local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) + local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT' + local algWorkspaceLimit = self.nInputPlane * self.kH * self.kW * 4 -- 4 = sizeof int. + if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end errcheck('cudnnGetConvolutionForwardAlgorithm', cudnn.handle[cutorch.getDevice()-1], self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], - 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType) + algSearchMode, algWorkspaceLimit, algType) self.algType = algType local bufSize = torch.LongTensor(1) errcheck('cudnnGetConvolutionForwardWorkspaceSize', |