diff options
author | soumith <soumith@fb.com> | 2015-04-10 05:13:50 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-04-10 05:13:50 +0300 |
commit | 35d4f5df368415c27dda955130bcc01d6234ffe6 (patch) | |
tree | 24ea2d025ba13791d85f97a4cbb2adaab627278f /init.lua | |
parent | 48b3b6df88198c28086de74a3d74d4745d507f76 (diff) |
using the new streams API (cudnn does not ovelap compute yet, weird)
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 26 |
1 files changed, 23 insertions, 3 deletions
@@ -5,7 +5,21 @@ include 'ffi.lua' local C = cudnn.C local ffi = require 'ffi' +local initialized = false +local maxStreamsPerDevice = 100 + +function cudnn.getHandle() + local curStream = cutorch.getStream() + assert(curStream < maxStreamsPerDevice, 'cudnn bindings only support max of : ' + .. maxStreamsPerDevice .. ' streams per device') + return cudnn.handle[(((cutorch.getDevice()-1)*maxStreamsPerDevice) + curStream)] +end + local errcheck = function(f, ...) + if initialized then + C.cudnnSetStream(cudnn.getHandle(), + ffi.C.THCState_getCurrentStream(cutorch.getState())) + end local status = C[f](...) if status ~= 'CUDNN_STATUS_SUCCESS' then local str = ffi.string(C.cudnnGetErrorString(status)) @@ -16,11 +30,13 @@ cudnn.errcheck = errcheck local numDevices = cutorch.getDeviceCount() local currentDevice = cutorch.getDevice() -cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices) +cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices*maxStreamsPerDevice) -- create handle for i=1,numDevices do cutorch.setDevice(i) - errcheck('cudnnCreate', cudnn.handle+i-1) + for j=0,maxStreamsPerDevice-1 do + errcheck('cudnnCreate', cudnn.handle+(((i-1)*maxStreamsPerDevice) + j)) + end end cutorch.setDevice(currentDevice) @@ -28,12 +44,16 @@ local function destroy(handle) local currentDevice = cutorch.getDevice() for i=1,numDevices do cutorch.setDevice(i) - errcheck('cudnnDestroy', handle[i-1]); + for j=0,maxStreamsPerDevice-1 do + errcheck('cudnnDestroy', handle[(((i-1)*maxStreamsPerDevice) + j)]); + end end cutorch.setDevice(currentDevice) end ffi.gc(cudnn.handle, destroy) +initialized = true + function cudnn.toDescriptor(t) assert(torch.typename(t) == 'torch.CudaTensor') local descriptor = ffi.new('struct cudnnTensorStruct*[1]') |