Welcome to mirror list, hosted at ThFree Co, Russian Federation.

VolumetricConvolution.lua - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 6ec33029ff4ac95bc1ad07f4ceb86b234e0e104a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
-- Register the class 'cudnn.VolumetricConvolution' with
-- 'nn.VolumetricConvolution' named as its parent class. torch.class hands
-- back two values, kept here as the class table and `parent`
-- (`parent` is not referenced in the part of the file visible here).
local VolumetricConvolution, parent
   = torch.class('cudnn.VolumetricConvolution', 'nn.VolumetricConvolution')
-- FFI is used below to allocate descriptor handles and attach finalizers.
local ffi = require 'ffi'
-- Shorthand for the cudnn call wrapper: every call below passes the cuDNN
-- function name as a string followed by its arguments.
local errcheck = cudnn.errcheck

-- Rebuilds the cuDNN descriptors for the weight and the bias.
-- Call this yourself whenever you change the configuration of the module
-- manually (nOutputPlane, nInputPlane, kT, kH, kW).
-- Raises if weight or bias is not a torch.CudaTensor.
function VolumetricConvolution:resetWeightDescriptors()
   local cudaType = 'torch.CudaTensor'
   assert(torch.typename(self.weight) == cudaType,
          'Only Cuda supported duh!')
   assert(torch.typename(self.bias) == cudaType,
          'Only Cuda supported duh!')

   -- filter descriptor: a 5-entry dimension list
   -- (nOutputPlane, nInputPlane, kT, kH, kW)
   local filterDesc = ffi.new('struct cudnnFilterStruct*[1]')
   self.weightDesc = filterDesc
   errcheck('cudnnCreateFilterDescriptor', filterDesc)
   local filterDims = torch.IntTensor({self.nOutputPlane, self.nInputPlane,
                                       self.kT, self.kH, self.kW})
   errcheck('cudnnSetFilterNdDescriptor', filterDesc[0],
            'CUDNN_DATA_FLOAT', 5,
            filterDims:data());
   -- destroy the cuDNN descriptor when the handle is garbage collected
   ffi.gc(filterDesc, function(d)
      errcheck('cudnnDestroyFilterDescriptor', d[0]);
   end)

   -- bias descriptor: the bias is described through a
   -- (1, nOutputPlane, 1, 1) view
   local biasView = self.bias:view(1, self.nOutputPlane, 1, 1)
   self.biasDesc = cudnn.toDescriptor(biasView)
end

-- Switches the "prefer fastest" algorithm search on or off for this module.
--
-- mode: boolean; nil is treated as true.
-- Returns self so calls can be chained.
--
-- The cached input size is zeroed so that the next call to
-- createIODescriptors sees a size mismatch and re-selects the algorithms.
function VolumetricConvolution:fastest(mode)
   if mode == nil then mode = true end
   self.fastest_mode = mode
   -- The module tracks 5-D inputs and createIODescriptors reads iSize[5],
   -- so the cache needs five entries (it used to be allocated with four).
   self.iSize = self.iSize or torch.LongStorage(5)
   self.iSize:fill(0)
   return self
end

-- Overrides the algorithms chosen by the automatic search.
--
-- fmode:  forward algorithm override (stored in self.fmode)
-- bdmode: backward-data algorithm override (stored in self.bdmode)
-- bwmode: backward-filter algorithm override (stored in self.bwmode)
-- A nil argument leaves the corresponding override untouched.
-- Returns self so calls can be chained.
--
-- The cached input size is zeroed so that the next call to
-- createIODescriptors sees a size mismatch and applies the overrides.
function VolumetricConvolution:setMode(fmode, bdmode, bwmode)
   if fmode ~= nil then
      self.fmode = fmode
   end
   if bdmode ~= nil then
      self.bdmode = bdmode
   end
   if bwmode ~= nil then
      self.bwmode = bwmode
   end
   -- The module tracks 5-D inputs and createIODescriptors reads iSize[5],
   -- so the cache needs five entries (it used to be allocated with four).
   self.iSize = self.iSize or torch.LongStorage(5)
   self.iSize:fill(0)
   return self
end

-- Drops the forward / backward-data / backward-filter algorithm overrides
-- stored by setMode. Returns self so calls can be chained.
-- NOTE(review): unlike setMode this does not zero self.iSize, so it looks
-- like already-selected algorithms stay in use until createIODescriptors
-- rebuilds for another reason -- confirm this is intended.
function VolumetricConvolution:resetMode()
   for _, field in ipairs({'fmode', 'bdmode', 'bwmode'}) do
      self[field] = nil
   end
   return self
end

-- Builds (or rebuilds) everything that depends on the input geometry:
-- the input / convolution / output descriptors, the sizes of self.output
-- and self.gradInput, the three algorithm choices (forward, backward-filter,
-- backward-data) and the shared workspace buffer.
--
-- input: a contiguous 5D tensor, or a 4D tensor which is handled as a single
--        sample by viewing it with a leading dimension of size 1.
--
-- Nothing is rebuilt when the descriptors exist and the input size equals
-- the cached self.iSize. self.weightDesc must already exist when a rebuild
-- happens (it is dereferenced below).
function VolumetricConvolution:createIODescriptors(input)
   local batch = true
   if input:dim() == 4 then
      input = input:view(1, input:size(1), input:size(2),
                         input:size(3), input:size(4))
      batch = false
   end
   assert(input:dim() == 5 and input:isContiguous());

   -- The size cache must hold five entries because iSize[5] is read below.
   -- A missing or wrongly sized cache is replaced by zeros, which never
   -- match a real input size and therefore force a rebuild.
   if not self.iSize or self.iSize:size() ~= 5 then
      self.iSize = torch.LongStorage(5):fill(0)
   end

   if not self.iDesc or not self.oDesc or
      input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
      or input:size(5) ~= self.iSize[5] then
      self.iSize = input:size()
      -- resize gradInput
      if self.gradInput then self.gradInput:resizeAs(input); end
      -- create input descriptor
      self.iDesc = cudnn.toDescriptor(input)
      -- create conv descriptor
      self.convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
      errcheck('cudnnCreateConvolutionDescriptor', self.convDesc)
      local pad = torch.IntTensor({self.padT, self.padH, self.padW})
      local stride = torch.IntTensor({self.dT, self.dH, self.dW})
      local upscale = torch.IntTensor({1,1,1})
      errcheck('cudnnSetConvolutionNdDescriptor_v3', self.convDesc[0],
               3, pad:data(),
               stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
               'CUDNN_DATA_FLOAT');
      -- destroy the cuDNN descriptor when the handle is garbage collected
      local function destroyConvDesc(d)
         errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
      end
      ffi.gc(self.convDesc, destroyConvDesc)

      -- query the output size from cuDNN and resize output accordingly
      local oSize = torch.IntTensor(5)
      local oSizeD = oSize:data()
      errcheck('cudnnGetConvolutionNdForwardOutputDim',
               self.convDesc[0], self.iDesc[0],
               self.weightDesc[0], 5, oSizeD)
      self.output:resize(oSize:long():storage())
      -- create descriptor for output
      self.oDesc = cudnn.toDescriptor(self.output)
      -- descriptor used by the bias calls: output viewed as 4D with
      -- dimensions 3 and 4 merged
      self.oDescBias = cudnn.toDescriptor(
         self.output:view(self.output:size(1),
                          self.output:size(2),
                          self.output:size(3)*self.output:size(4),
                          self.output:size(5)))

      -----------------------------------------------------------------
      -- Algorithm selection. The same workspace limit and the same
      -- "prefer fastest" switch are used for all three searches.
      local algWorkspaceLimit = self.workspace_limit
         or (self.nInputPlane * self.kT * self.kH * self.kW * 4) -- 4 = sizeof int/float.
      local preferFastest = self.fastest_mode or cudnn.fastest == true
      -- largest workspace (in bytes) requested by any of the three algorithms
      local maxBufSize = 0

      -- forward algorithm (self.fmode overrides the search result)
      local fwdAlgType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
      local fwdSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
      if preferFastest then
         fwdSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST'
      end
      errcheck('cudnnGetConvolutionForwardAlgorithm',
               cudnn.getHandle(),
               self.iDesc[0], self.weightDesc[0],
               self.convDesc[0], self.oDesc[0],
               fwdSearchMode, algWorkspaceLimit, fwdAlgType)
      fwdAlgType[0] = self.fmode or fwdAlgType[0]
      self.fwdAlgType = fwdAlgType
      local fwdBufSize = torch.LongTensor(1)
      errcheck('cudnnGetConvolutionForwardWorkspaceSize',
               cudnn.getHandle(),
               self.iDesc[0], self.weightDesc[0],
               self.convDesc[0], self.oDesc[0],
               fwdAlgType[0], fwdBufSize:data())
      maxBufSize = math.max(maxBufSize, fwdBufSize[1])

      -- backward-filter algorithm (self.bwmode overrides the search result)
      local bwdFilterAlgType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
      local bwdFilterSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
      if preferFastest then
         bwdFilterSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST'
      end
      errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
               cudnn.getHandle(),
               self.iDesc[0], self.oDesc[0],
               self.convDesc[0], self.weightDesc[0],
               bwdFilterSearchMode, algWorkspaceLimit, bwdFilterAlgType)
      bwdFilterAlgType[0] = self.bwmode or bwdFilterAlgType[0]
      self.bwdFilterAlgType = bwdFilterAlgType
      local bwdFilterBufSize = torch.LongTensor(1)
      errcheck('cudnnGetConvolutionBackwardFilterWorkspaceSize',
               cudnn.getHandle(),
               self.iDesc[0], self.oDesc[0],
               self.convDesc[0], self.weightDesc[0],
               bwdFilterAlgType[0], bwdFilterBufSize:data())
      maxBufSize = math.max(maxBufSize, bwdFilterBufSize[1])

      -- backward-data algorithm (self.bdmode overrides the search result)
      local bwdDataAlgType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
      local bwdDataSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
      if preferFastest then
         bwdDataSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST'
      end
      errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
               cudnn.getHandle(),
               self.weightDesc[0], self.oDesc[0],
               self.convDesc[0], self.iDesc[0],
               bwdDataSearchMode, algWorkspaceLimit, bwdDataAlgType)
      bwdDataAlgType[0] = self.bdmode or bwdDataAlgType[0]
      self.bwdDataAlgType = bwdDataAlgType
      local bwdDataBufSize = torch.LongTensor(1)
      errcheck('cudnnGetConvolutionBackwardDataWorkspaceSize',
               cudnn.getHandle(),
               self.weightDesc[0], self.oDesc[0],
               self.convDesc[0], self.iDesc[0],
               bwdDataAlgType[0], bwdDataBufSize:data())
      maxBufSize = math.max(maxBufSize, bwdDataBufSize[1])

      -- grow the workspace if it is smaller than the largest request;
      -- the buffer is sized in 4-byte elements
      self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace()
      self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
      if maxBufSize > self.extraBufferSizeInBytes then
         self.extraBuffer:resize(math.ceil(maxBufSize/4))
         self.extraBufferSizeInBytes = maxBufSize
      end

      -----------------------------------------------------------------
      if not batch then
         -- a 4D input gets 4D results back: drop the leading dimension
         -- gradInput may have been set to nil by the user (the resize above
         -- and updateGradInput both allow for that), so guard the view too
         if self.gradInput then
            self.gradInput = self.gradInput:view(self.gradInput:size(2),
                                                 self.gradInput:size(3),
                                                 self.gradInput:size(4),
                                                 self.gradInput:size(5))
         end
         self.output = self.output:view(self.output:size(2),
                                        self.output:size(3),
                                        self.output:size(4),
                                        self.output:size(5))
      end
   end
end

-- One-element float tensors holding 1 and 0. The cuDNN calls below receive
-- pointers to them (one:data() / zero:data()) as their scaling-factor
-- arguments.
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});

-- Returns contiguous stand-ins for input and (optionally) gradOutput.
-- A tensor that is already contiguous is returned unchanged; otherwise it is
-- copied into a buffer cached on the module (self._input / self._gradOutput)
-- and that buffer is returned instead.
-- gradOutput may be nil, in which case nil is returned in its place.
-- NOTE(review): the result of typeAs is not stored back on self; if typeAs
-- can return a different tensor when the types differ, the cached buffer
-- would not receive the copy -- confirm against the torch tensor docs.
local function makeContiguous(self, input, gradOutput)
   if not input:isContiguous() then
      local inputBuf = self._input or input.new()
      self._input = inputBuf
      inputBuf:typeAs(input):resizeAs(input):copy(input)
      input = inputBuf
   end

   local gradNeedsCopy = gradOutput and not gradOutput:isContiguous()
   if gradNeedsCopy then
      local gradBuf = self._gradOutput or gradOutput.new()
      self._gradOutput = gradBuf
      gradBuf:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
      gradOutput = gradBuf
   end

   return input, gradOutput
end

-- Forward pass: runs the cuDNN convolution into self.output and then adds
-- the bias on top of it.
-- input: 4D or 5D tensor (a non-contiguous input is copied first).
-- Returns self.output.
function VolumetricConvolution:updateOutput(input)
   if not self.weightDesc then
      self:resetWeightDescriptors()
   end
   local src = makeContiguous(self, input)
   self:createIODescriptors(src)

   -- scaling factors handed to cuDNN by pointer
   local onePtr, zeroPtr = one:data(), zero:data()

   -- convolution, written into self.output
   errcheck('cudnnConvolutionForward', cudnn.getHandle(),
            onePtr,
            self.iDesc[0], src:data(),
            self.weightDesc[0], self.weight:data(),
            self.convDesc[0], self.fwdAlgType[0],
            self.extraBuffer:data(), self.extraBufferSizeInBytes,
            zeroPtr,
            self.oDesc[0], self.output:data());

   -- bias, added through the 4D view descriptor of the output
   errcheck('cudnnAddTensor', cudnn.getHandle(),
            onePtr,
            self.biasDesc[0], self.bias:data(), onePtr,
            self.oDescBias[0], self.output:data());

   return self.output
end

-- Backward pass with respect to the input.
-- input, gradOutput: 4D or 5D tensors (non-contiguous ones are copied first).
-- Returns self.gradInput, or nothing when self.gradInput has been unset.
function VolumetricConvolution:updateGradInput(input, gradOutput)
   if not self.gradInput then return end

   local src, gradOut = makeContiguous(self, input, gradOutput)
   local nDim = gradOut:dim()
   assert(nDim == 4 or nDim == 5,
          'gradOutput has to be a 4D or 5D tensor');

   if not self.weightDesc then
      self:resetWeightDescriptors()
   end
   self:createIODescriptors(src)

   errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
            one:data(),
            self.weightDesc[0], self.weight:data(),
            self.oDesc[0], gradOut:data(),
            self.convDesc[0],
            self.bwdDataAlgType[0],
            self.extraBuffer:data(), self.extraBufferSizeInBytes,
            zero:data(),
            self.iDesc[0], self.gradInput:data());

   return self.gradInput
end

-- Computes the bias and weight gradients into self.gradBias and
-- self.gradWeight.
--
-- input, gradOutput: 4D or 5D tensors (non-contiguous ones are copied first).
-- scale: multiplier passed to cuDNN for both gradients; defaults to 1.0.
--
-- Both cuDNN calls receive `one` as their second scaling factor, so the
-- results are presumably added to the existing gradients rather than
-- overwriting them (cuDNN alpha/beta convention -- confirm).
function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
   self.scaleT = self.scaleT or torch.FloatTensor(1):fill(1.0)
   -- this line forces this member to always be on CPU (needed for cudnn)
   self.scaleT = self.scaleT:float()

   scale = scale or 1.0
   self.scaleT[1] = scale
   input, gradOutput = makeContiguous(self, input, gradOutput)
   assert(gradOutput:dim() == 4 or gradOutput:dim() == 5,
          'gradOutput has to be a 4D or 5D tensor');
   -- The weight descriptor has to exist before createIODescriptors runs,
   -- because a rebuild there dereferences self.weightDesc[0]. This is the
   -- same order updateOutput and updateGradInput use.
   if not self.weightDesc then self:resetWeightDescriptors() end
   self:createIODescriptors(input)
   -- gradBias
   errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
            self.scaleT:data(),
            self.oDescBias[0], gradOutput:data(),
            one:data(),
            self.biasDesc[0], self.gradBias:data());
   -- gradWeight
   errcheck('cudnnConvolutionBackwardFilter_v3', cudnn.getHandle(),
            self.scaleT:data(),
            self.iDesc[0], input:data(),
            self.oDesc[0], gradOutput:data(),
            self.convDesc[0],
            self.bwdFilterAlgType[0],
            self.extraBuffer:data(), self.extraBufferSizeInBytes,
            one:data(),
            self.weightDesc[0], self.gradWeight:data());
end

-- Drops every cuDNN descriptor, algorithm choice and workspace reference
-- held by the module, together with the scale tensor. The forward/backward
-- methods recreate them when they find the descriptors missing.
function VolumetricConvolution:clearDesc()
   self.weightDesc = nil
   self.biasDesc = nil
   self.convDesc = nil
   self.iDesc = nil
   self.oDesc = nil
   self.oDescBias = nil
   self.fwdAlgType = nil
   self.bwdDataAlgType = nil
   self.bwdFilterAlgType = nil
   self.extraBuffer = nil
   -- The field written by createIODescriptors is extraBufferSizeInBytes;
   -- this method used to clear a misspelled name (extraBufferInBytes), which
   -- left the stale size on the module.
   self.extraBufferSizeInBytes = nil
   self.scaleT = nil
end

-- Serialization hook: clears the cuDNN state via clearDesc, then writes a
-- shallow copy of the module's remaining fields through f:writeObject.
-- f: the object doing the writing (must provide writeObject).
function VolumetricConvolution:write(f)
   self:clearDesc()

   local snapshot = {}
   for field, value in pairs(self) do
      snapshot[field] = value
   end

   f:writeObject(snapshot)
end

-- Releases temporary state: the cuDNN descriptors (via clearDesc), the
-- workspace reference and the contiguous-copy buffers, then defers to
-- nn.Module.clearState and returns whatever that returns.
function VolumetricConvolution:clearState()
   self:clearDesc()
   nn.utils.clear(self, 'extraBuffer', '_input', '_gradOutput')
   local base = nn.Module
   return base.clearState(self)
end