diff options
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 12 |
1 files changed, 10 insertions, 2 deletions
@@ -150,8 +150,12 @@ function cudnn.getHandle() end function cudnn.call(f, ...) - C.cudnnSetStream(cudnn.getHandle(), +--context might be destroyed by the time gc calls destructors, in which case cudnnSetStream call would fail +--and it is not necessary for cudnn destructors anyway + if not string.find(f, 'cudnnDestroy') then + C.cudnnSetStream(cudnn.getHandle(), thc.THCState_getCurrentStream(cutorch.getState())) + end return C[f](...) end @@ -212,7 +216,9 @@ function cudnn.setConvolutionDescriptor(data, desc) if not data.arrayLength then data.arrayLength = #data.padA end if not data.dilationA then data.dilationA = {1,1,1 } end -- assume maximum length==3 if not data.mode then data.mode = 'CUDNN_CROSS_CORRELATION' end - + if not data.mathType then data.mathType = 'CUDNN_DEFAULT_MATH' end + if not data.groupCount then data.groupCount = 1 end + local myDesc = desc or cudnn.createDescriptors( 1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') @@ -227,6 +233,8 @@ function cudnn.setConvolutionDescriptor(data, desc) upscaleATensor:data(), data.mode, data.dataType) + errcheck('cudnnSetConvolutionMathType', myDesc[0], data.mathType) + errcheck('cudnnSetConvolutionGroupCount', myDesc[0], data.groupCount) return myDesc end |