diff options
author | Boris Fomitchev <bfomitchev@nvidia.com> | 2016-09-22 12:19:15 +0300 |
---|---|---|
committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2016-09-23 04:06:43 +0300 |
commit | 9465aae4f41734c8218adaf2d50c7b3f5c9e80f7 (patch) | |
tree | f847cee2bf66726d11e1f5a6e402f936a108a401 /VolumetricConvolution.lua | |
parent | a17af4f12cbeb87103dbc514408eb64e1be85ba7 (diff) |
Revamped workspace handling in find.lua
Retired functional.lua: impossible to maintain consistently with Find.
Simplified FindEx state machine: replaced witgh warmup iterations concept, controllable by user.
FindEx still needs some work.
Improved cache handling and debug print
Diffstat (limited to 'VolumetricConvolution.lua')
-rw-r--r-- | VolumetricConvolution.lua | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index 6a06075..d38125b 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -37,13 +37,18 @@ function VolumetricConvolution:createIODescriptors(input) -- create conv descriptor self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') - local pad = torch.IntTensor({self.padT, self.padH, self.padW}) - local stride = torch.IntTensor({self.dT, self.dH, self.dW}) + self.pad = torch.IntTensor({self.padT, self.padH, self.padW}) + self.stride = torch.IntTensor({self.dT, self.dH, self.dW}) local upscale = torch.IntTensor({1,1,1}) + local mathtype=cudnn.configmap(torch.type(self.weight)) + -- 3D convolutions do not work in 16 bits + if mathtype == 'CUDNN_DATA_HALF' then + mathtype = 'CUDNN_DATA_FLOAT' + end errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0], - 3, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap(torch.type(self.weight))); + 3, self.pad:data(), + self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', + mathtype); -- create output descriptor and resize output local oSize = torch.IntTensor(5) |