1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
|
-- Declare the torch class 'cudnn.SpatialConvolution' as a subclass of
-- 'nn.SpatialConvolution'; `parent` is used below to chain to the base
-- constructor.
local SpatialConvolution, parent =
torch.class('cudnn.SpatialConvolution', 'nn.SpatialConvolution')
-- FFI library, used below to allocate the cuDNN descriptor handles and to
-- attach finalizers to them.
local ffi = require 'ffi'
-- Shorthand for cudnn.errcheck; every cuDNN call in this file goes through it
-- as errcheck('<cudnn function name>', args...).
local errcheck = cudnn.errcheck
-- Constructor.
-- nInputPlane, nOutputPlane, kW, kH, dW, dH are forwarded unchanged to the
-- nn.SpatialConvolution constructor; padW / padH are the horizontal / vertical
-- padding and default to 0 when omitted.
function SpatialConvolution:__init(nInputPlane, nOutputPlane,
                                   kW, kH, dW, dH, padW, padH)
    parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH)

    self.padW, self.padH = padW or 0, padH or 0
    self:reset()

    -- Cached size of the last input the IO descriptors were built for.
    -- Starting at all zeros makes the first createIODescriptors() call
    -- see a size mismatch and build everything.
    self.iSize = torch.LongStorage(4):fill(0)
end
-- (Re)creates the cuDNN descriptors for the weight (filter) and the bias.
-- Call this yourself after manually changing the module's configuration;
-- the forward/backward methods only call it when self.weightDesc is missing.
-- Asserts that both self.weight and self.bias are torch.CudaTensor.
function SpatialConvolution:resetWeightDescriptors()
    assert(torch.typename(self.weight) == 'torch.CudaTensor',
           'Only Cuda supported duh!')
    assert(torch.typename(self.bias) == 'torch.CudaTensor',
           'Only Cuda supported duh!')

    -- Filter descriptor: a 4-d float filter laid out as
    -- {nOutputPlane, nInputPlane, kH, kW}.
    local filterDesc = ffi.new('struct cudnnFilterStruct*[1]')
    self.weightDesc = filterDesc
    errcheck('cudnnCreateFilterDescriptor', filterDesc)

    local filterDims = torch.IntTensor({self.nOutputPlane, self.nInputPlane,
                                        self.kH, self.kW})
    errcheck('cudnnSetFilterNdDescriptor', filterDesc[0],
             'CUDNN_DATA_FLOAT', 4,
             filterDims:data());

    -- Destroy the cuDNN descriptor when the handle is garbage collected.
    ffi.gc(filterDesc, function(d)
        errcheck('cudnnDestroyFilterDescriptor', d[0]);
    end)

    -- Bias descriptor: the bias viewed as a 1 x nOutputPlane x 1 x 1 tensor.
    local biasView = self.bias:view(1, self.nOutputPlane, 1, 1)
    self.biasDesc = cudnn.toDescriptor(biasView)
end
-- Makes sure the cuDNN input / convolution / output descriptors, the forward
-- algorithm choice and the workspace buffer match `input`.
--
-- input: a contiguous 3-d (single sample) or 4-d (batch) tensor; a 3-d input
--        is viewed as a batch of one for the purpose of building descriptors.
--
-- Everything is rebuilt only when self.iDesc / self.oDesc are missing or the
-- (4-d) input size differs from the cached self.iSize. On rebuild this also
-- resizes self.output and, when present, self.gradInput. For 3-d input both
-- are viewed back to 3-d afterwards.
--
-- Requires self.weightDesc to exist already (see resetWeightDescriptors).
-- Raises on a non-contiguous input or an input that is neither 3-d nor 4-d.
function SpatialConvolution:createIODescriptors(input)
    local batch = true
    if input:dim() == 3 then
        input = input:view(1, input:size(1), input:size(2), input:size(3))
        batch = false
    end
    assert(input:dim() == 4 and input:isContiguous());
    if not self.iDesc or not self.oDesc or
        input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
    or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
        self.iSize = input:size()
        -- resize gradInput (it may have been set to nil to skip its computation)
        if self.gradInput then self.gradInput:resizeAs(input); end
        -- create input descriptor
        self.iDesc = cudnn.toDescriptor(input)
        -- create conv descriptor: 2 spatial dims, {padH, padW} padding,
        -- {dH, dW} stride, no upscaling, cross-correlation mode
        self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
        errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
        local pad = torch.IntTensor({self.padH, self.padW})
        local stride = torch.IntTensor({self.dH, self.dW})
        local upscale = torch.IntTensor({1,1})
        errcheck('cudnnSetConvolutionNdDescriptor', self.convDesc[0],
                 2, pad:data(),
                 stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION');
        -- destroy the cuDNN descriptor when the handle is garbage collected
        local function destroyConvDesc(d)
            errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
        end
        ffi.gc(self.convDesc, destroyConvDesc)
        -- ask cuDNN for the 4-d output size and resize output accordingly
        local oSize = torch.IntTensor(4)
        local oSizeD = oSize:data()
        errcheck('cudnnGetConvolutionNdForwardOutputDim',
                 self.convDesc[0], self.iDesc[0],
                 self.weightDesc[0], 4, oSizeD)
        self.output:resize(oSize:long():storage())
        -- create descriptor for output
        self.oDesc = cudnn.toDescriptor(self.output)
        -- select the forward algorithm for this configuration
        local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
        errcheck('cudnnGetConvolutionForwardAlgorithm',
                 cudnn.handle[cutorch.getDevice()-1],
                 self.iDesc[0], self.weightDesc[0],
                 self.convDesc[0], self.oDesc[0],
                 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType)
        self.algType = algType
        -- query the workspace the chosen algorithm needs and grow the buffer
        local bufSize = torch.LongTensor(1)
        errcheck('cudnnGetConvolutionForwardWorkspaceSize',
                 cudnn.handle[cutorch.getDevice()-1],
                 self.iDesc[0], self.weightDesc[0],
                 self.convDesc[0], self.oDesc[0],
                 algType[0], bufSize:data())
        self.extraBuffer = self.extraBuffer or input.new(1)
        -- NOTE(review): cuDNN documents this size in bytes while resize()
        -- takes an element count, so this probably over-allocates. Confirm
        -- before tightening: updateOutput passes extraBuffer:nElement() as
        -- the workspace size, so both places would have to change together.
        if bufSize[1] ~= 0 then self.extraBuffer:resize(bufSize[1]) end
        if not batch then
            -- Non-batch mode: drop the leading singleton dimension again.
            -- gradInput is optional (see the guard above), so only reshape
            -- it when it exists instead of indexing nil.
            if self.gradInput then
                self.gradInput = self.gradInput:view(self.gradInput:size(2),
                                                     self.gradInput:size(3),
                                                     self.gradInput:size(4))
            end
            self.output = self.output:view(self.output:size(2),
                                           self.output:size(3),
                                           self.output:size(4))
        end
    end
end
-- Single-element CPU float tensors holding 1 and 0. Their data pointers are
-- passed to the cuDNN calls below as scaling factors (the alpha / beta
-- arguments in cuDNN's API).
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
-- Forward pass.
-- Lazily builds the weight/bias descriptors, refreshes the IO descriptors for
-- `input`, runs cudnnConvolutionForward into self.output (scaling factors
-- `one` and `zero`) and then adds the bias with cudnnAddTensor.
-- Returns self.output.
function SpatialConvolution:updateOutput(input)
    if not self.weightDesc then
        self:resetWeightDescriptors()
    end
    self:createIODescriptors(input)

    -- cuDNN handle of the currently selected device
    local handle = cudnn.handle[cutorch.getDevice()-1]
    local workspace = self.extraBuffer

    errcheck('cudnnConvolutionForward', handle,
             one:data(),
             self.iDesc[0], input:data(),
             self.weightDesc[0], self.weight:data(),
             self.convDesc[0], self.algType[0],
             workspace:data(), workspace:nElement(),
             zero:data(),
             self.oDesc[0], self.output:data());

    errcheck('cudnnAddTensor', handle,
             'CUDNN_ADD_SAME_C',
             one:data(), self.biasDesc[0], self.bias:data(),
             one:data(), self.oDesc[0], self.output:data());

    return self.output
end
-- Backward pass with respect to the input.
-- Returns nothing when self.gradInput has been disabled (set to nil).
-- Otherwise gradOutput must be a contiguous 3-d or 4-d tensor; the result of
-- cudnnConvolutionBackwardData is written into self.gradInput (scaling
-- factors `one` and `zero`), which is returned.
function SpatialConvolution:updateGradInput(input, gradOutput)
    if not self.gradInput then
        return
    end

    local nDim = gradOutput:dim()
    assert((nDim == 3 or nDim == 4) and gradOutput:isContiguous());

    if not self.weightDesc then
        self:resetWeightDescriptors()
    end
    self:createIODescriptors(input)

    -- cuDNN handle of the currently selected device
    local handle = cudnn.handle[cutorch.getDevice()-1]
    errcheck('cudnnConvolutionBackwardData', handle,
             one:data(),
             self.weightDesc[0], self.weight:data(),
             self.oDesc[0], gradOutput:data(),
             self.convDesc[0],
             zero:data(),
             self.iDesc[0], self.gradInput:data());

    return self.gradInput
end
-- Computes the gradients with respect to the bias and the weight.
--
-- input:      the tensor that was fed to the forward pass
-- gradOutput: contiguous 3-d or 4-d gradient w.r.t. the module output
-- scale:      multiplier applied to the new gradients (defaults to 1.0)
--
-- `one` is passed as the trailing scaling factor of both cuDNN calls, so
-- (per cuDNN's alpha/beta convention) the scaled results are added to the
-- existing self.gradBias / self.gradWeight rather than overwriting them.
function SpatialConvolution:accGradParameters(input, gradOutput, scale)
    self.scaleT = self.scaleT or torch.FloatTensor(1):fill(1.0)
    -- this line forces this member to always be on CPU (needed for cudnn)
    self.scaleT = self.scaleT:float()
    scale = scale or 1.0
    self.scaleT[1] = scale
    assert((gradOutput:dim() == 3 or gradOutput:dim() == 4)
           and gradOutput:isContiguous());
    -- The weight descriptor must exist before the IO descriptors are
    -- (re)built: createIODescriptors reads self.weightDesc[0] to compute the
    -- output size. Same order as updateOutput / updateGradInput.
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
    -- gradBias
    errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1],
             self.scaleT:data(),
             self.oDesc[0], gradOutput:data(),
             one:data(),
             self.biasDesc[0], self.gradBias:data());
    -- gradWeight
    errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1],
             self.scaleT:data(),
             self.iDesc[0], input:data(),
             self.oDesc[0], gradOutput:data(),
             self.convDesc[0],
             one:data(),
             self.weightDesc[0], self.gradWeight:data());
end
--[[
function SpatialConvolution:zeroGradParameters()
-- gradWeight, gradBias to zero
errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1],
self.weightDesc, self.gradWeight:data(), zero:data());
errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1],
self.biasDesc, self.gradBias:data(), zero:data());
end
]]--
|