From 42b96c054aa5f205faba8d35b46837532f580227 Mon Sep 17 00:00:00 2001 From: soumith Date: Sun, 28 Feb 2016 19:11:17 -0500 Subject: fake half --- SpatialConvolution.lua | 2 +- init.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 09e421d..5ee4d83 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -43,7 +43,7 @@ function SpatialConvolution:resetWeightDescriptors() self.nInputPlane/self.groups, self.kH, self.kW}) errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], - 'CUDNN_DATA_FLOAT', 4, + 'CUDNN_DATA_HALF', 4, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); diff --git a/init.lua b/init.lua index 53cb7ea..71c7e57 100644 --- a/init.lua +++ b/init.lua @@ -79,7 +79,7 @@ function cudnn.toDescriptor(t) -- set descriptor local size = torch.LongTensor(t:size()):int() local stride = torch.LongTensor(t:stride()):int() - errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_FLOAT', + errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_HALF', t:dim(), size:data(), stride:data()) return descriptor end -- cgit v1.2.3