Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua12
1 files changed, 10 insertions, 2 deletions
diff --git a/init.lua b/init.lua
index 13a77fd..1920ff0 100644
--- a/init.lua
+++ b/init.lua
@@ -150,8 +150,12 @@ function cudnn.getHandle()
end
function cudnn.call(f, ...)
- C.cudnnSetStream(cudnn.getHandle(),
+--context might be destroyed by the time gc calls destructors, in which case cudnnSetStream call would fail
+--and it is not necessary for cudnn destructors anyway
+ if not string.find(f, 'cudnnDestroy') then
+ C.cudnnSetStream(cudnn.getHandle(),
thc.THCState_getCurrentStream(cutorch.getState()))
+ end
return C[f](...)
end
@@ -212,7 +216,9 @@ function cudnn.setConvolutionDescriptor(data, desc)
if not data.arrayLength then data.arrayLength = #data.padA end
if not data.dilationA then data.dilationA = {1,1,1 } end -- assume maximum length==3
if not data.mode then data.mode = 'CUDNN_CROSS_CORRELATION' end
-
+ if not data.mathType then data.mathType = 'CUDNN_DEFAULT_MATH' end
+ if not data.groupCount then data.groupCount = 1 end
+
local myDesc = desc or cudnn.createDescriptors(
1, 'struct cudnnConvolutionStruct*[?]',
'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor')
@@ -227,6 +233,8 @@ function cudnn.setConvolutionDescriptor(data, desc)
upscaleATensor:data(),
data.mode,
data.dataType)
+ errcheck('cudnnSetConvolutionMathType', myDesc[0], data.mathType)
+ errcheck('cudnnSetConvolutionGroupCount', myDesc[0], data.groupCount)
return myDesc
end