Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'VolumetricConvolution.lua')
-rw-r--r--VolumetricConvolution.lua10
1 files changed, 4 insertions, 6 deletions
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index fd5e9c7..b255467 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -10,17 +10,15 @@ autotunerCache[3] = {} -- backwardData
-- if you change the configuration of the module manually, call this
function VolumetricConvolution:resetWeightDescriptors()
- assert(torch.typename(self.weight) == 'torch.CudaTensor',
- 'Only Cuda supported duh!')
- assert(torch.typename(self.bias) == 'torch.CudaTensor',
- 'Only Cuda supported duh!')
+ assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!')
+ assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!')
-- create filterDescriptor for weight
self.weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
errcheck('cudnnCreateFilterDescriptor', self.weightDesc)
local desc = torch.IntTensor({self.nOutputPlane, self.nInputPlane,
self.kT, self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 5,
+ cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', 5,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -87,7 +85,7 @@ function VolumetricConvolution:createIODescriptors(input)
errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
3, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- 'CUDNN_DATA_FLOAT');
+ cudnn.configmap(torch.type(self.weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end