diff options
-rw-r--r-- | Pointwise.lua | 4 | ||||
-rw-r--r-- | Pooling.lua | 4 | ||||
-rw-r--r-- | SpatialConvolution.lua | 18 | ||||
-rw-r--r-- | SpatialSoftMax.lua | 4 | ||||
-rw-r--r-- | VolumetricConvolution.lua | 14 | ||||
-rw-r--r-- | init.lua | 26 |
6 files changed, 45 insertions, 25 deletions
diff --git a/Pointwise.lua b/Pointwise.lua index bf1358a..9b3e71a 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -28,7 +28,7 @@ function Pointwise:updateOutput(input) self:createIODescriptors(input) if self.inplace then self.output = input end errcheck('cudnnActivationForward', - cudnn.handle[cutorch.getDevice()-1], self.mode, + cudnn.getHandle(), self.mode, one:data(), self.iDesc[0], input:data(), zero:data(), @@ -46,7 +46,7 @@ function Pointwise:updateGradInput(input, gradOutput) self:createIODescriptors(input) if self.inplace then self.output = input; self.gradInput = gradOutput end errcheck('cudnnActivationBackward', - cudnn.handle[cutorch.getDevice()-1], self.mode, + cudnn.getHandle(), self.mode, one:data(), self.iDesc[0], self.output:data(), self.iDesc[0], gradOutput:data(), diff --git a/Pooling.lua b/Pooling.lua index 085e3f3..7190059 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -86,7 +86,7 @@ local zero = torch.FloatTensor({0}); function Pooling:updateOutput(input) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) - errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnPoolingForward', cudnn.getHandle(), self.poolDesc[0], one:data(), self.iDesc[0], input:data(), @@ -105,7 +105,7 @@ function Pooling:updateGradInput(input, gradOutput) if not self.poolDesc then self:resetPoolDescriptors() end self:createIODescriptors(input) errcheck('cudnnPoolingBackward', - cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], + cudnn.getHandle(), self.poolDesc[0], one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 861d62f..78aa056 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -105,14 +105,14 @@ function SpatialConvolution:createIODescriptors(input) local algWorkspaceLimit = self.nInputPlane * self.kH * self.kW * 4 -- 4 = sizeof int. 
if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end errcheck('cudnnGetConvolutionForwardAlgorithm', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], algSearchMode, algWorkspaceLimit, algType) self.algType = algType local bufSize = torch.LongTensor(1) errcheck('cudnnGetConvolutionForwardWorkspaceSize', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], algType[0], bufSize:data()) @@ -144,7 +144,7 @@ function SpatialConvolution:updateOutput(input) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) for g=0,self.groups-1 do - errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionForward', cudnn.getHandle(), one:data(), self.iDesc[0], input:data() + g*self.input_offset, self.weightDesc[0], self.weight:data() + g*self.weight_offset, @@ -152,7 +152,7 @@ function SpatialConvolution:updateOutput(input) self.extraBuffer:data(), self.extraBuffer:nElement(), zero:data(), self.oDesc[0], self.output:data() + g*self.output_offset); - errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnAddTensor', cudnn.getHandle(), 'CUDNN_ADD_SAME_C', one:data(), self.biasDesc[0], self.bias:data() + g*self.bias_offset, one:data(), self.oDesc[0], self.output:data() + g*self.output_offset); @@ -167,7 +167,7 @@ function SpatialConvolution:updateGradInput(input, gradOutput) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) for g=0,self.groups-1 do - errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(), one:data(), self.weightDesc[0], self.weight:data() + g*self.weight_offset, self.oDesc[0], gradOutput:data() + g*self.output_offset, @@ -190,13 +190,13 @@ function 
SpatialConvolution:accGradParameters(input, gradOutput, scale) if not self.weightDesc then self:resetWeightDescriptors() end for g=0,self.groups-1 do -- gradBias - errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(), self.scaleT:data(), self.oDesc[0], gradOutput:data() + g*self.output_offset, one:data(), self.biasDesc[0], self.gradBias:data() + g*self.bias_offset); -- gradWeight - errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(), self.scaleT:data(), self.iDesc[0], input:data() + g*self.input_offset, self.oDesc[0], gradOutput:data() + g*self.output_offset, @@ -209,9 +209,9 @@ end --[[ function SpatialConvolution:zeroGradParameters() -- gradWeight, gradBias to zero - errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnSetTensor', cudnn.getHandle(), self.weightDesc, self.gradWeight:data(), zero:data()); - errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnSetTensor', cudnn.getHandle(), self.biasDesc, self.gradBias:data(), zero:data()); end ]]-- diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 97c1e38..adbb044 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -58,7 +58,7 @@ local zero = torch.FloatTensor({0}); function SpatialSoftMax:updateOutput(input) self:createIODescriptors(input) errcheck('cudnnSoftmaxForward', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.algorithm, self.mode, one:data(), self.iDesc[0], input:data(), @@ -71,7 +71,7 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) assert(gradOutput:isContiguous()); self:createIODescriptors(input) errcheck('cudnnSoftmaxBackward', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.algorithm, self.mode, one:data(), self.oDesc[0], self.output:data(), diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua 
index 74857e2..1c71477 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -88,14 +88,14 @@ function VolumetricConvolution:createIODescriptors(input) -- create forwardAlgorithm descriptors for local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1) errcheck('cudnnGetConvolutionForwardAlgorithm', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST', -1, algType) self.algType = algType local bufSize = torch.LongTensor(1) errcheck('cudnnGetConvolutionForwardWorkspaceSize', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.iDesc[0], self.weightDesc[0], self.convDesc[0], self.oDesc[0], algType[0], bufSize:data()) @@ -121,7 +121,7 @@ local zero = torch.FloatTensor({0}); function VolumetricConvolution:updateOutput(input) if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) - errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionForward', cudnn.getHandle(), one:data(), self.iDesc[0], input:data(), self.weightDesc[0], self.weight:data(), @@ -129,7 +129,7 @@ function VolumetricConvolution:updateOutput(input) self.extraBuffer:data(), self.extraBuffer:nElement(), zero:data(), self.oDesc[0], self.output:data()); - errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnAddTensor', cudnn.getHandle(), 'CUDNN_ADD_SAME_C', one:data(), self.biasDesc[0], self.bias:data(), one:data(), self.oDescBias[0], self.output:data()); @@ -142,7 +142,7 @@ function VolumetricConvolution:updateGradInput(input, gradOutput) and gradOutput:isContiguous()); if not self.weightDesc then self:resetWeightDescriptors() end self:createIODescriptors(input) - errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(), one:data(), self.weightDesc[0], self.weight:data(), 
self.oDesc[0], gradOutput:data(), @@ -164,14 +164,14 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale) self:createIODescriptors(input) if not self.weightDesc then self:resetWeightDescriptors() end -- gradBias - errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], + errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(), self.scaleT:data(), self.oDescBias[0], gradOutput:data(), one:data(), self.biasDesc[0], self.gradBias:data()); -- gradWeight errcheck('cudnnConvolutionBackwardFilter', - cudnn.handle[cutorch.getDevice()-1], + cudnn.getHandle(), self.scaleT:data(), self.iDesc[0], input:data(), self.oDesc[0], gradOutput:data(), diff --git a/init.lua b/init.lua @@ -5,7 +5,21 @@ include 'ffi.lua' local C = cudnn.C local ffi = require 'ffi' +local initialized = false +local maxStreamsPerDevice = 100 + +function cudnn.getHandle() + local curStream = cutorch.getStream() + assert(curStream < maxStreamsPerDevice, 'cudnn bindings only support max of : ' + .. maxStreamsPerDevice .. ' streams per device') + return cudnn.handle[(((cutorch.getDevice()-1)*maxStreamsPerDevice) + curStream)] +end + local errcheck = function(f, ...) + if initialized then + C.cudnnSetStream(cudnn.getHandle(), + ffi.C.THCState_getCurrentStream(cutorch.getState())) + end local status = C[f](...) 
if status ~= 'CUDNN_STATUS_SUCCESS' then local str = ffi.string(C.cudnnGetErrorString(status)) @@ -16,11 +30,13 @@ cudnn.errcheck = errcheck local numDevices = cutorch.getDeviceCount() local currentDevice = cutorch.getDevice() -cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices) +cudnn.handle = ffi.new('struct cudnnContext*[?]', numDevices*maxStreamsPerDevice) -- create handle for i=1,numDevices do cutorch.setDevice(i) - errcheck('cudnnCreate', cudnn.handle+i-1) + for j=0,maxStreamsPerDevice-1 do + errcheck('cudnnCreate', cudnn.handle+(((i-1)*maxStreamsPerDevice) + j)) + end end cutorch.setDevice(currentDevice) @@ -28,12 +44,16 @@ local function destroy(handle) local currentDevice = cutorch.getDevice() for i=1,numDevices do cutorch.setDevice(i) - errcheck('cudnnDestroy', handle[i-1]); + for j=0,maxStreamsPerDevice-1 do + errcheck('cudnnDestroy', handle[(((i-1)*maxStreamsPerDevice) + j)]); + end end cutorch.setDevice(currentDevice) end ffi.gc(cudnn.handle, destroy) +initialized = true + function cudnn.toDescriptor(t) assert(torch.typename(t) == 'torch.CudaTensor') local descriptor = ffi.new('struct cudnnTensorStruct*[1]') |