Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Boris Fomitchev <bfomitchev@nvidia.com> 2016-09-22 12:19:15 +0300
committer Boris Fomitchev <bfomitchev@nvidia.com> 2016-09-23 04:06:43 +0300
commit 9465aae4f41734c8218adaf2d50c7b3f5c9e80f7 (patch)
tree f847cee2bf66726d11e1f5a6e402f936a108a401 /VolumetricConvolution.lua
parent a17af4f12cbeb87103dbc514408eb64e1be85ba7 (diff)
Revamped workspace handling in find.lua
Retired functional.lua: impossible to maintain consistently with Find. Simplified FindEx state machine: replaced with warmup iterations concept, controllable by user. FindEx still needs some work. Improved cache handling and debug print
Diffstat (limited to 'VolumetricConvolution.lua')
-rw-r--r--VolumetricConvolution.lua15
1 files changed, 10 insertions, 5 deletions
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 6a06075..d38125b 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -37,13 +37,18 @@ function VolumetricConvolution:createIODescriptors(input)
-- create conv descriptor
self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]',
'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor')
- local pad = torch.IntTensor({self.padT, self.padH, self.padW})
- local stride = torch.IntTensor({self.dT, self.dH, self.dW})
+ self.pad = torch.IntTensor({self.padT, self.padH, self.padW})
+ self.stride = torch.IntTensor({self.dT, self.dH, self.dW})
local upscale = torch.IntTensor({1,1,1})
+ local mathtype=cudnn.configmap(torch.type(self.weight))
+ -- 3D convolutions do not work in 16 bits
+ if mathtype == 'CUDNN_DATA_HALF' then
+ mathtype = 'CUDNN_DATA_FLOAT'
+ end
errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0],
- 3, pad:data(),
- stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- cudnn.configmap(torch.type(self.weight)));
+ 3, self.pad:data(),
+ self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+ mathtype);
-- create output descriptor and resize output
local oSize = torch.IntTensor(5)