Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2016-02-29 03:11:17 +0300
committersoumith <soumith@gmail.com>2016-02-29 03:11:17 +0300
commit42b96c054aa5f205faba8d35b46837532f580227 (patch)
treee36627d36eb9a8dba5bdae6a04380efdb2b8f0d8
parentb80bdbab6faf66b711bf7c8a159703f62508c5e2 (diff)
fake halfhalf
-rw-r--r--SpatialConvolution.lua2
-rw-r--r--init.lua2
2 files changed, 2 insertions, 2 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 09e421d..5ee4d83 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -43,7 +43,7 @@ function SpatialConvolution:resetWeightDescriptors()
self.nInputPlane/self.groups,
self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- 'CUDNN_DATA_FLOAT', 4,
+ 'CUDNN_DATA_HALF', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
diff --git a/init.lua b/init.lua
index 53cb7ea..71c7e57 100644
--- a/init.lua
+++ b/init.lua
@@ -79,7 +79,7 @@ function cudnn.toDescriptor(t)
-- set descriptor
local size = torch.LongTensor(t:size()):int()
local stride = torch.LongTensor(t:stride()):int()
- errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_FLOAT',
+ errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_HALF',
t:dim(), size:data(), stride:data())
return descriptor
end