Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: soumith <soumith@fb.com>, 2015-04-10 05:13:50 +0300
committer: soumith <soumith@fb.com>, 2015-04-10 05:13:50 +0300
commit: 35d4f5df368415c27dda955130bcc01d6234ffe6 (patch)
tree: 24ea2d025ba13791d85f97a4cbb2adaab627278f /init.lua
parent: 48b3b6df88198c28086de74a3d74d4745d507f76 (diff)
using the new streams API (cudnn does not overlap compute yet, weird)
Diffstat (limited to 'init.lua')
-rw-r--r-- init.lua 26
1 file changed, 23 insertions, 3 deletions
diff --git a/init.lua b/init.lua
index 5867a99..6eeef67 100644
--- a/init.lua
+++ b/init.lua
@@ -5,7 +5,21 @@ include 'ffi.lua'
local C = cudnn.C
local ffi = require 'ffi'
+local initialized = false
+local maxStreamsPerDevice = 100
+
+function cudnn.getHandle()
+ local curStream = cutorch.getStream()
+ assert(curStream < maxStreamsPerDevice, 'cudnn bindings only support max of : '
+ .. maxStreamsPerDevice .. ' streams per device')
+ return cudnn.handle[(((cutorch.getDevice()-1)*maxStreamsPerDevice) + curStream)]
+end
+
local errcheck = function(f, ...)
+ if initialized then
+ C.cudnnSetStream(cudnn.getHandle(),
+ ffi.C.THCState_getCurrentStream(cutorch.getState()))
+ end
local status = C[f](...)
if status ~= 'CUDNN_STATUS_SUCCESS' then
local str = ffi.string(C.cudnnGetErrorString(status))
@@ -16,11 +30,13 @@ cudnn.errcheck = errcheck
local numDevices = cutorch.getDeviceCount()
local currentDevice = cutorch.getDevice()
-cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices)
+cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices*maxStreamsPerDevice)
-- create handle
for i=1,numDevices do
cutorch.setDevice(i)
- errcheck('cudnnCreate', cudnn.handle+i-1)
+ for j=0,maxStreamsPerDevice-1 do
+ errcheck('cudnnCreate', cudnn.handle+(((i-1)*maxStreamsPerDevice) + j))
+ end
end
cutorch.setDevice(currentDevice)
@@ -28,12 +44,16 @@ local function destroy(handle)
local currentDevice = cutorch.getDevice()
for i=1,numDevices do
cutorch.setDevice(i)
- errcheck('cudnnDestroy', handle[i-1]);
+ for j=0,maxStreamsPerDevice-1 do
+ errcheck('cudnnDestroy', handle[(((i-1)*maxStreamsPerDevice) + j)]);
+ end
end
cutorch.setDevice(currentDevice)
end
ffi.gc(cudnn.handle, destroy)
+initialized = true
+
function cudnn.toDescriptor(t)
assert(torch.typename(t) == 'torch.CudaTensor')
local descriptor = ffi.new('struct cudnnTensorStruct*[1]')