Welcome to mirror list, hosted at ThFree Co, Russian Federation.

init.lua - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
blob: 6eeef676f29e2613fb8cdf4a54ae162af4f9efb6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
require 'cutorch'
require 'nn'
cudnn = {}
include 'ffi.lua'
local C = cudnn.C
local ffi = require 'ffi'

-- Set to true further down, once every handle has been created; errcheck only
-- binds the current stream to a handle while this flag is true.
local initialized = false
-- Number of cuDNN handles allocated per device. cudnn.getHandle asserts that
-- the current cutorch stream index is strictly below this value.
local maxStreamsPerDevice = 100

-- Returns the entry of cudnn.handle that belongs to the current cutorch
-- device and the current cutorch stream.
-- The array is laid out device-major: each device owns a contiguous run of
-- maxStreamsPerDevice slots, indexed by the stream number.
-- Raises (via assert) when the current stream index is not below
-- maxStreamsPerDevice.
function cudnn.getHandle()
   local stream = cutorch.getStream()
   local limitMsg = 'cudnn bindings only support max of : '
      .. maxStreamsPerDevice .. ' streams per device'
   assert(stream < maxStreamsPerDevice, limitMsg)
   -- cutorch devices are numbered from 1, the array slots from 0
   local deviceBase = (cutorch.getDevice() - 1) * maxStreamsPerDevice
   return cudnn.handle[deviceBase + stream]
end

-- Calls the cuDNN function named `f` (looked up in the FFI namespace C) with
-- the remaining arguments and raises a Lua error when the returned status is
-- not CUDNN_STATUS_SUCCESS.
-- Once the module has finished creating its handles (`initialized`), the
-- current handle is first bound to the current THC stream; the status of
-- that binding call is not inspected.
-- Returns nothing on success.
local function errcheck(f, ...)
   if initialized then
      local handle = cudnn.getHandle()
      local stream = ffi.C.THCState_getCurrentStream(cutorch.getState())
      C.cudnnSetStream(handle, stream)
   end
   local status = C[f](...)
   if status == 'CUDNN_STATUS_SUCCESS' then
      return
   end
   -- translate the status code into cuDNN's own description
   error('Error in CuDNN: ' .. ffi.string(C.cudnnGetErrorString(status)))
end
-- exported so code outside this file can use the same checked-call helper
cudnn.errcheck = errcheck

local numDevices = cutorch.getDeviceCount()
local currentDevice = cutorch.getDevice()
-- One flat array of handle pointers: maxStreamsPerDevice consecutive slots
-- for each device.
cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices*maxStreamsPerDevice)
-- Create every handle while its owning device is the active one.
for device = 1, numDevices do
   cutorch.setDevice(device)
   local base = (device - 1) * maxStreamsPerDevice
   for stream = 0, maxStreamsPerDevice - 1 do
      -- pointer arithmetic: pass the address of the slot to be filled in
      errcheck('cudnnCreate', cudnn.handle + (base + stream))
   end
end
-- Put back the device that was active before the loop.
cutorch.setDevice(currentDevice)

-- Finalizer for the cudnn.handle array: destroys every handle while its
-- owning device is active, then restores the device that was active on entry.
--
-- cudnnDestroy is called directly instead of through errcheck. By the time
-- this runs `initialized` is true, so errcheck would first call
-- cudnnSetStream on cudnn.getHandle(); once the loop has passed the current
-- stream's slot on a device, that would be a handle which has already been
-- destroyed. The status check and error message are the same as errcheck's.
--
-- @param handle the array of `struct cudnnContext*` being collected
local function destroy(handle)
   local currentDevice = cutorch.getDevice()
   for i=1,numDevices do
      cutorch.setDevice(i)
      for j=0,maxStreamsPerDevice-1 do
         local status = C.cudnnDestroy(handle[(((i-1)*maxStreamsPerDevice) + j)])
         if status ~= 'CUDNN_STATUS_SUCCESS' then
            local str = ffi.string(C.cudnnGetErrorString(status))
            error('Error in CuDNN: ' .. str)
         end
      end
   end
   cutorch.setDevice(currentDevice)
end
-- run destroy when the handle array is garbage collected
ffi.gc(cudnn.handle, destroy)

-- All handles exist now, so from here on errcheck binds the current stream
-- to the current handle before each cuDNN call.
initialized = true

-- Builds a cuDNN tensor descriptor that mirrors the sizes and strides of `t`.
-- `t` must be a torch.CudaTensor (asserted). The descriptor is declared with
-- data type CUDNN_DATA_FLOAT and t:dim() dimensions.
-- Returns a one-element FFI array holding the descriptor pointer; the
-- descriptor is destroyed when that array is garbage collected.
function cudnn.toDescriptor(t)
   assert(torch.typename(t) == 'torch.CudaTensor')
   local desc = ffi.new('struct cudnnTensorStruct*[1]')
   errcheck('cudnnCreateTensorDescriptor', desc)
   -- attach the finalizer only after creation has succeeded
   ffi.gc(desc, function(d)
      errcheck('cudnnDestroyTensorDescriptor', d[0])
   end)
   -- sizes and strides are handed to cuDNN as int arrays
   local dims = torch.LongTensor(t:size()):int()
   local strides = torch.LongTensor(t:stride()):int()
   errcheck('cudnnSetTensorNdDescriptor', desc[0], 'CUDNN_DATA_FLOAT',
            t:dim(), dims:data(), strides:data())
   return desc
end

-- Load the layer implementations. They are included last, after
-- cudnn.getHandle, cudnn.errcheck and cudnn.toDescriptor have been defined
-- above (presumably because the included files rely on them -- confirm
-- before reordering).
include 'SpatialConvolution.lua'
include 'VolumetricConvolution.lua'
include 'Pooling.lua'
include 'SpatialMaxPooling.lua'
include 'SpatialAveragePooling.lua'
include 'Pointwise.lua'
include 'ReLU.lua'
include 'Tanh.lua'
include 'Sigmoid.lua'
include 'SpatialSoftMax.lua'
include 'SoftMax.lua'