diff options
Diffstat (limited to 'VolumetricConvolution.lua')
-rw-r--r-- | VolumetricConvolution.lua | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index fd5e9c7..b255467 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -10,17 +10,15 @@ autotunerCache[3] = {} -- backwardData -- if you change the configuration of the module manually, call this function VolumetricConvolution:resetWeightDescriptors() - assert(torch.typename(self.weight) == 'torch.CudaTensor', - 'Only Cuda supported duh!') - assert(torch.typename(self.bias) == 'torch.CudaTensor', - 'Only Cuda supported duh!') + assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!') + assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!') -- create filterDescriptor for weight self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]') errcheck('cudnnCreateFilterDescriptor', self.weightDesc) local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane, self.kT, self.kH, self.kW}) errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0], - 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 5, + cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', 5, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); @@ -87,7 +85,7 @@ function VolumetricConvolution:createIODescriptors(input) errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0], 3, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - 'CUDNN_DATA_FLOAT'); + cudnn.configmap(torch.type(self.weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end |