diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-07-08 11:08:18 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-07-08 11:08:18 +0300 |
commit | 155d6f59211504e6dc72840133df1eb93a16cfb3 (patch) | |
tree | d34b2bc005ad39aef7042ddecea316ef1f07b73d | |
parent | 3fd281c6e5dcb90bc80ef0083dc778a164c31159 (diff) |
handle half true/pseudo
-rw-r--r-- | SpatialConvolution.lua | 2 | ||||
-rw-r--r-- | SpatialFullConvolution.lua | 2 | ||||
-rw-r--r-- | VolumetricConvolution.lua | 2 | ||||
-rw-r--r-- | init.lua | 26 |
4 files changed, 24 insertions, 8 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index a224a6d..58c78b2 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -123,7 +123,7 @@ function SpatialConvolution:createIODescriptors(input) errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 2, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap[torch.type(self.weight)]); + cudnn.configmap(torch.type(self.weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua index 76c95b2..ec8061c 100644 --- a/SpatialFullConvolution.lua +++ b/SpatialFullConvolution.lua @@ -100,7 +100,7 @@ function SpatialFullConvolution:createIODescriptors(input) errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 2, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap[torch.type(self.weight)]); + cudnn.configmap(torch.type(self.weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index 00a476d..b255467 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -85,7 +85,7 @@ function VolumetricConvolution:createIODescriptors(input) errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 3, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - cudnn.configmap[torch.type(self.weight)]); + cudnn.configmap(torch.type(self.weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end @@ -38,11 +38,27 @@ cudnn.typemap = { -- TODO: determine if device supports true half and use true half on it -- so far use float for half and float, double for double -cudnn.configmap = { - ['torch.CudaHalfTensor'] = 'CUDNN_DATA_FLOAT', - ['torch.CudaTensor'] = 'CUDNN_DATA_FLOAT', - ['torch.CudaDoubleTensor'] = 'CUDNN_DATA_DOUBLE', -} +local function determineHalfCapability(dev) + local prop = cutorch.getDeviceProperties(dev) + if prop.major >= 6 or prop.name:find'X1' then + return 'CUDNN_DATA_HALF' + else + return 'CUDNN_DATA_FLOAT' + end +end + +local configmaps = {} +for i=1,cutorch.getDeviceCount() do + configmaps[i] = { + ['torch.CudaHalfTensor'] = determineHalfCapability(i), + ['torch.CudaTensor'] = 'CUDNN_DATA_FLOAT', + ['torch.CudaDoubleTensor'] = 'CUDNN_DATA_DOUBLE', + } +end + +cudnn.configmap = function(tensortype) + return configmaps[cutorch.getDevice()][tensortype] +end function cudnn.getHandle() local device = cutorch.getDevice() |