Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-20 22:55:33 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-21 03:51:16 +0400
commit70433d6359cdae6833c315bb8151038ed9f75a1c (patch)
tree9d5fc34dfa6e63a920d80ea0e7090b368986eb45 /init.lua
parentae62e2be0dc9cf7500972b7355ccfd46e3d2b1a8 (diff)
Multi-GPU support
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua20
1 files changed, 16 insertions, 4 deletions
diff --git a/init.lua b/init.lua
index 6aebcbb..a435864 100644
--- a/init.lua
+++ b/init.lua
@@ -13,11 +13,23 @@ local errcheck = function(f, ...)
end
cudnn.errcheck = errcheck
-cudnn.handle = ffi.new('struct cudnnContext*[1]')
+local numDevices = cutorch.getDeviceCount()
+local currentDevice = cutorch.getDevice()
+cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices)
-- create handle
-errcheck('cudnnCreate', cudnn.handle)
-local function destroy(handle)
- errcheck('cudnnDestroy', handle[0]);
+for i=1,numDevices do
+ cutorch.setDevice(i)
+ errcheck('cudnnCreate', cudnn.handle+i-1)
+end
+cutorch.setDevice(currentDevice)
+
+local function destroy(handle)
+ local currentDevice = cutorch.getDevice()
+ for i=1,numDevices do
+ cutorch.setDevice(i)
+ errcheck('cudnnDestroy', handle[i-1]);
+ end
+ cutorch.setDevice(currentDevice)
end
ffi.gc(cudnn.handle, destroy)