diff options
author | soumith <soumith@gmail.com> | 2016-02-29 03:11:17 +0300 |
---|---|---|
committer | soumith <soumith@gmail.com> | 2016-02-29 03:11:17 +0300 |
commit | 42b96c054aa5f205faba8d35b46837532f580227 (patch) | |
tree | e36627d36eb9a8dba5bdae6a04380efdb2b8f0d8 | |
parent | b80bdbab6faf66b711bf7c8a159703f62508c5e2 (diff) |
fake halfhalf
-rw-r--r-- | SpatialConvolution.lua | 2 | ||||
-rw-r--r-- | init.lua | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 09e421d..5ee4d83 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -43,7 +43,7 @@ function SpatialConvolution:resetWeightDescriptors() self.nInputPlane/self.groups, self.kH, self.kW}) errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], - 'CUDNN_DATA_FLOAT', 4, + 'CUDNN_DATA_HALF', 4, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); @@ -79,7 +79,7 @@ function cudnn.toDescriptor(t) -- set descriptor local size = torch.LongTensor(t:size()):int() local stride = torch.LongTensor(t:stride()):int() - errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_FLOAT', + errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_HALF', t:dim(), size:data(), stride:data()) return descriptor end |