1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
require 'cutorch'
require 'nn'
cudnn = require 'cudnn.env'
require('cudnn.ffi')
local C = cudnn.C
local ffi = require 'ffi'
-- Algorithm-selection flags exposed on the module table; both start disabled.
cudnn.fastest = false
cudnn.benchmark = false

-- One handle slot is reserved per (device, stream) pair; cutorch stream ids
-- 0 .. maxStreamsPerDevice - 1 are representable.
local maxStreamsPerDevice = 1024
local numDevices = cutorch.getDeviceCount()
local totalHandleSlots = maxStreamsPerDevice * numDevices

-- handleStatus[device][stream + 1] is set to 1 once the handle for that
-- (device, stream) pair has been created; all entries start at 0.
local handleStatus = torch.ByteTensor(numDevices, maxStreamsPerDevice)
handleStatus:zero()

-- Flat array of cuDNN handle pointers (struct cudnnContext*), one per
-- (device, stream) slot, addressed as (device - 1) * maxStreamsPerDevice + stream.
cudnn.handle = ffi.new('struct cudnnContext*[?]', totalHandleSlots)
--- Finalizer for cudnn.handle: destroys every cuDNN handle whose slot is
-- marked as created in handleStatus. It visits each device in turn and
-- restores the device that was current on entry.
--
-- cudnnDestroy is called directly rather than through errcheck for two reasons.
-- First, errcheck is a local declared further down this file, so it is not
-- visible here and the name would resolve to a nil global.
-- Second, errcheck fetches (and may lazily create) a handle and sets its
-- stream, which is not wanted while the handles are being torn down.
-- Failures are recorded instead of raised immediately, so the remaining
-- handles are still destroyed and the current device is restored. The first
-- failure is then reported.
-- @param handle the cdata array this finalizer is attached to (cudnn.handle)
local function destroy(handle)
   local currentDevice = cutorch.getDevice()
   local firstError
   for i = 1, numDevices do
      cutorch.setDevice(i)
      -- streams go from 0 to maxStreamsPerDevice - 1
      for j = 0, maxStreamsPerDevice - 1 do
         if handleStatus[i][j + 1] == 1 then -- if handle was created
            local status = C.cudnnDestroy(handle[((i - 1) * maxStreamsPerDevice) + j])
            if status ~= ffi.C.CUDNN_STATUS_SUCCESS and not firstError then
               firstError = ffi.string(C.cudnnGetErrorString(status))
            end
         end
      end
   end
   cutorch.setDevice(currentDevice)
   if firstError then
      error('Error in CuDNN: ' .. firstError .. ' (cudnnDestroy)')
   end
end
ffi.gc(cudnn.handle, destroy)
--- Return the cuDNN handle for the current cutorch device and stream.
-- The handle is created with cudnnCreate the first time a given
-- (device, stream) pair asks for it. After that, the stored pointer is
-- returned from cudnn.handle.
-- Raises an assertion error when the stream id is not below
-- maxStreamsPerDevice. Raises 'Error in CuDNN: ...' if cudnnCreate fails.
-- @return struct cudnnContext* for the current device/stream
function cudnn.getHandle()
   local device = cutorch.getDevice()
   local stream = cutorch.getStream() -- starts from 0
   assert(stream < maxStreamsPerDevice, 'cudnn bindings only support max of : '
          .. maxStreamsPerDevice .. ' streams per device')

   local slot = (device - 1) * maxStreamsPerDevice + stream

   -- fast path: the handle for this slot already exists
   if handleStatus[device][stream + 1] ~= 0 then
      return cudnn.handle[slot]
   end

   -- lazy initialization: create the handle in place inside the array
   local status = C.cudnnCreate(cudnn.handle + slot)
   if status ~= ffi.C.CUDNN_STATUS_SUCCESS then
      error('Error in CuDNN: ' .. ffi.string(C.cudnnGetErrorString(status)))
   end
   handleStatus[device][stream + 1] = 1 -- mark handle as initialized
   return cudnn.handle[slot]
end
--- Raise a Lua error if a cuDNN status is not CUDNN_STATUS_SUCCESS.
-- @param status value returned by a cuDNN C function
-- @param f name of the cuDNN function, appended to the message for diagnosis
local function checkCudnnStatus(status, f)
   if status ~= ffi.C.CUDNN_STATUS_SUCCESS then
      local str = ffi.string(C.cudnnGetErrorString(status))
      error('Error in CuDNN: ' .. str .. ' (' .. f .. ')')
   end
end

--- Call the cuDNN function named `f` on the current device/stream handle.
-- First binds that handle (from cudnn.getHandle) to cutorch's current stream
-- via cudnnSetStream. Then calls C[f](...) with the arguments forwarded
-- unchanged. Both statuses are checked: previously a failing cudnnSetStream
-- was silently ignored.
-- @param f name of a function in the cudnn C namespace, e.g. 'cudnnCreateTensorDescriptor'
-- @param ... arguments forwarded to C[f]
local errcheck = function(f, ...)
   checkCudnnStatus(C.cudnnSetStream(cudnn.getHandle(),
                                     ffi.C.THCState_getCurrentStream(cutorch.getState())),
                    'cudnnSetStream')
   checkCudnnStatus(C[f](...), f)
end
cudnn.errcheck = errcheck
--- Build a cuDNN tensor descriptor describing the sizes and strides of `t`.
-- 2D and 3D tensors are first viewed as 4D by appending trailing singleton
-- dimensions.
-- @param t torch.CudaTensor (asserted)
-- @return cdata `struct cudnnTensorStruct*[1]` whose element [0] is the
--   descriptor. The descriptor is destroyed with cudnnDestroyTensorDescriptor
--   when this cdata is garbage-collected.
function cudnn.toDescriptor(t)
   assert(torch.typename(t) == 'torch.CudaTensor')
   local descriptor = ffi.new('struct cudnnTensorStruct*[1]')
   -- create descriptor
   errcheck('cudnnCreateTensorDescriptor', descriptor)
   -- tie the cuDNN descriptor's lifetime to the returned cdata
   local function destroy(d)
      errcheck('cudnnDestroyTensorDescriptor', d[0]);
   end
   ffi.gc(descriptor, destroy)
   -- view 2D and 3D as 4D
   if t:dim() == 2 then
      t = t:view(t:size(1), t:size(2), 1, 1)
   elseif t:dim() == 3 then
      t = t:view(t:size(1), t:size(2), t:size(3), 1)
   end
   -- cudnnSetTensorNdDescriptor takes sizes and strides as C int arrays
   local size = torch.LongTensor(t:size()):int()
   local stride = torch.LongTensor(t:stride()):int()
   -- torch.CudaTensor stores 32-bit floats, so the descriptor must declare
   -- CUDNN_DATA_FLOAT. Declaring HALF would make cuDNN reinterpret the
   -- buffer as 16-bit values.
   errcheck('cudnnSetTensorNdDescriptor', descriptor[0], 'CUDNN_DATA_FLOAT',
            t:dim(), size:data(), stride:data())
   return descriptor
end
-- sharedBuffer[device][stream] caches one workspace tensor per cutorch
-- (device, stream) pair. Stream ids start at 0, which is a valid hash key.
local sharedBuffer = {}
for dev = 1, numDevices do
   sharedBuffer[dev] = {}
end

--- Return the workspace tensor cached for the current device and stream.
-- The first call for a given (device, stream) pair allocates a 1-element
-- torch.CudaTensor. Later calls return that same tensor object.
-- NOTE(review): callers presumably resize it to the scratch size they need.
function cudnn.getSharedWorkspace()
   local device = cutorch.getDevice()
   local stream = cutorch.getStream() -- starts from 0
   local perStream = sharedBuffer[device]
   local workspace = perStream[stream]
   if not workspace then
      workspace = torch.CudaTensor(1)
      perStream[stream] = workspace
   end
   return workspace
end
-- Load the cudnn.* submodules in a fixed order.
-- NOTE(review): the order is kept as before because later modules (e.g. the
-- pooling variants after 'Pooling') may depend on earlier ones — confirm
-- before reordering.
local submodules = {
   'SpatialConvolution',
   'VolumetricConvolution',
   'Pooling',
   'SpatialMaxPooling',
   'SpatialAveragePooling',
   'Pooling3D',
   'VolumetricMaxPooling',
   'VolumetricAveragePooling',
   'Pointwise',
   'ReLU',
   'Tanh',
   'Sigmoid',
   'SpatialSoftMax',
   'SpatialLogSoftMax',
   'SoftMax',
   'LogSoftMax',
   'SpatialCrossMapLRN',
   'BatchNormalization',
   'SpatialBatchNormalization',
   'VolumetricBatchNormalization',
   'SpatialCrossEntropyCriterion',
   'TemporalConvolution',
   'functional',
   'convert',
}
for _, name in ipairs(submodules) do
   require('cudnn.' .. name)
end
return cudnn
|