Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Sergey Zagoruyko <zagoruyko2@gmail.com> 2016-07-08 11:08:18 +0300
committer Sergey Zagoruyko <zagoruyko2@gmail.com> 2016-07-08 11:08:18 +0300
commit 155d6f59211504e6dc72840133df1eb93a16cfb3 (patch)
tree d34b2bc005ad39aef7042ddecea316ef1f07b73d
parent 3fd281c6e5dcb90bc80ef0083dc778a164c31159 (diff)
handle half true/pseudo
-rw-r--r-- SpatialConvolution.lua 2
-rw-r--r-- SpatialFullConvolution.lua 2
-rw-r--r-- VolumetricConvolution.lua 2
-rw-r--r-- init.lua 26
4 files changed, 24 insertions, 8 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index a224a6d..58c78b2 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -123,7 +123,7 @@ function SpatialConvolution:createIODescriptors(input)
errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- cudnn.configmap[torch.type(self.weight)]);
+ cudnn.configmap(torch.type(self.weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua
index 76c95b2..ec8061c 100644
--- a/SpatialFullConvolution.lua
+++ b/SpatialFullConvolution.lua
@@ -100,7 +100,7 @@ function SpatialFullConvolution:createIODescriptors(input)
errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- cudnn.configmap[torch.type(self.weight)]);
+ cudnn.configmap(torch.type(self.weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 00a476d..b255467 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -85,7 +85,7 @@ function VolumetricConvolution:createIODescriptors(input)
errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
3, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- cudnn.configmap[torch.type(self.weight)]);
+ cudnn.configmap(torch.type(self.weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end
diff --git a/init.lua b/init.lua
index 147e1b3..b2e73d4 100644
--- a/init.lua
+++ b/init.lua
@@ -38,11 +38,27 @@ cudnn.typemap = {
-- TODO: determine if device supports true half and use true half on it
-- so far use float for half and float, double for double
-cudnn.configmap = {
- ['torch.CudaHalfTensor'] = 'CUDNN_DATA_FLOAT',
- ['torch.CudaTensor'] = 'CUDNN_DATA_FLOAT',
- ['torch.CudaDoubleTensor'] = 'CUDNN_DATA_DOUBLE',
-}
+local function determineHalfCapability(dev)
+ local prop = cutorch.getDeviceProperties(dev)
+ if prop.major >= 6 or prop.name:find'X1' then
+ return 'CUDNN_DATA_HALF'
+ else
+ return 'CUDNN_DATA_FLOAT'
+ end
+end
+
+local configmaps = {}
+for i=1,cutorch.getDeviceCount() do
+ configmaps[i] = {
+ ['torch.CudaHalfTensor'] = determineHalfCapability(i),
+ ['torch.CudaTensor'] = 'CUDNN_DATA_FLOAT',
+ ['torch.CudaDoubleTensor'] = 'CUDNN_DATA_DOUBLE',
+ }
+end
+
+cudnn.configmap = function(tensortype)
+ return configmaps[cutorch.getDevice()][tensortype]
+end
function cudnn.getHandle()
local device = cutorch.getDevice()