Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-02-17 07:16:51 +0300
committerGitHub <noreply@github.com>2017-02-17 07:16:51 +0300
commit7f3e2b22c50d12c8583f33ff792c88d692bcef49 (patch)
tree642afc147974032fb8ef2d0558006deb66a440ac
parentc412ec1818722bb710e87cf7eb6bae1e17532373 (diff)
parent45c4c7b0cfe5b7ff63525dabd5ed5265a744ef77 (diff)
Merge pull request #331 from gchanan/gcfix
Fix CUDNN_STATUS_BAD_PARAM due to garbage collection reclaiming
-rw-r--r--init.lua14
1 files changed, 10 insertions, 4 deletions
diff --git a/init.lua b/init.lua
index 1519212..246583b 100644
--- a/init.lua
+++ b/init.lua
@@ -216,11 +216,15 @@ function cudnn.setConvolutionDescriptor(data, desc)
local myDesc = desc or cudnn.createDescriptors(
1, 'struct cudnnConvolutionStruct*[?]',
'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor')
+ -- make sure we have references to these tensors so gc doesn't clean them up
+ local padATensor = torch.IntTensor(data.padA)
+ local filterStrideATensor = torch.IntTensor(data.filterStrideA)
+ local upscaleATensor = torch.IntTensor(data.upscaleA)
errcheck('cudnnSetConvolutionNdDescriptor', myDesc[0],
data.arrayLength,
- torch.IntTensor(data.padA):data(),
- torch.IntTensor(data.filterStrideA):data(),
- torch.IntTensor(data.upscaleA):data(),
+ padATensor:data(),
+ filterStrideATensor:data(),
+ upscaleATensor:data(),
data.mode,
data.dataType)
return myDesc
@@ -231,9 +235,11 @@ function cudnn.setFilterDescriptor(data, filterDesc)
1, 'struct cudnnFilterStruct*[?]',
'cudnnCreateFilterDescriptor', 'cudnnDestroyFilterDescriptor')
local dims = data.nbDims or #data.filterDimA
+ -- make sure we have references to these tensors so gc doesn't clean them up
+ local filterDimATensor = torch.IntTensor(data.filterDimA)
errcheck('cudnnSetFilterNdDescriptor', myDesc[0],
data.dataType, data.format or 'CUDNN_TENSOR_NCHW',
- dims, torch.IntTensor(data.filterDimA):data());
+ dims, filterDimATensor:data());
return myDesc
end