diff options
-rw-r--r-- | init.lua | 14 |
1 files changed, 10 insertions, 4 deletions
@@ -216,11 +216,15 @@ function cudnn.setConvolutionDescriptor(data, desc) local myDesc = desc or cudnn.createDescriptors( 1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') + -- make sure we have references to these tensors so gc doesn't clean them up + local padATensor = torch.IntTensor(data.padA) + local filterStrideATensor = torch.IntTensor(data.filterStrideA) + local upscaleATensor = torch.IntTensor(data.upscaleA) errcheck('cudnnSetConvolutionNdDescriptor', myDesc[0], data.arrayLength, - torch.IntTensor(data.padA):data(), - torch.IntTensor(data.filterStrideA):data(), - torch.IntTensor(data.upscaleA):data(), + padATensor:data(), + filterStrideATensor:data(), + upscaleATensor:data(), data.mode, data.dataType) return myDesc @@ -231,9 +235,11 @@ function cudnn.setFilterDescriptor(data, filterDesc) 1, 'struct cudnnFilterStruct*[?]', 'cudnnCreateFilterDescriptor', 'cudnnDestroyFilterDescriptor') local dims = data.nbDims or #data.filterDimA + -- make sure we have references to these tensors so gc doesn't clean them up + local filterDimATensor = torch.IntTensor(data.filterDimA) errcheck('cudnnSetFilterNdDescriptor', myDesc[0], data.dataType, data.format or 'CUDNN_TENSOR_NCHW', - dims, torch.IntTensor(data.filterDimA):data()); + dims, filterDimATensor:data()); return myDesc end |