Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-04-10 05:13:50 +0300
committersoumith <soumith@fb.com>2015-04-10 05:13:50 +0300
commit35d4f5df368415c27dda955130bcc01d6234ffe6 (patch)
tree24ea2d025ba13791d85f97a4cbb2adaab627278f /SpatialConvolution.lua
parent48b3b6df88198c28086de74a3d74d4745d507f76 (diff)
using the new streams API (cudnn does not ovelap compute yet, weird)
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r--SpatialConvolution.lua18
1 files changed, 9 insertions, 9 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 861d62f..78aa056 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -105,14 +105,14 @@ function SpatialConvolution:createIODescriptors(input)
local algWorkspaceLimit = self.nInputPlane * self.kH * self.kW * 4 -- 4 = sizeof int.
if self.fastest_mode then algSearchMode = 'CUDNN_CONVOLUTION_FWD_PREFER_FASTEST' end
errcheck('cudnnGetConvolutionForwardAlgorithm',
- cudnn.handle[cutorch.getDevice()-1],
+ cudnn.getHandle(),
self.iDesc[0], self.weightDesc[0],
self.convDesc[0], self.oDesc[0],
algSearchMode, algWorkspaceLimit, algType)
self.algType = algType
local bufSize = torch.LongTensor(1)
errcheck('cudnnGetConvolutionForwardWorkspaceSize',
- cudnn.handle[cutorch.getDevice()-1],
+ cudnn.getHandle(),
self.iDesc[0], self.weightDesc[0],
self.convDesc[0], self.oDesc[0],
algType[0], bufSize:data())
@@ -144,7 +144,7 @@ function SpatialConvolution:updateOutput(input)
if not self.weightDesc then self:resetWeightDescriptors() end
self:createIODescriptors(input)
for g=0,self.groups-1 do
- errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnConvolutionForward', cudnn.getHandle(),
one:data(),
self.iDesc[0], input:data() + g*self.input_offset,
self.weightDesc[0], self.weight:data() + g*self.weight_offset,
@@ -152,7 +152,7 @@ function SpatialConvolution:updateOutput(input)
self.extraBuffer:data(), self.extraBuffer:nElement(),
zero:data(),
self.oDesc[0], self.output:data() + g*self.output_offset);
- errcheck('cudnnAddTensor', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnAddTensor', cudnn.getHandle(),
'CUDNN_ADD_SAME_C',
one:data(), self.biasDesc[0], self.bias:data() + g*self.bias_offset,
one:data(), self.oDesc[0], self.output:data() + g*self.output_offset);
@@ -167,7 +167,7 @@ function SpatialConvolution:updateGradInput(input, gradOutput)
if not self.weightDesc then self:resetWeightDescriptors() end
self:createIODescriptors(input)
for g=0,self.groups-1 do
- errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
one:data(),
self.weightDesc[0], self.weight:data() + g*self.weight_offset,
self.oDesc[0], gradOutput:data() + g*self.output_offset,
@@ -190,13 +190,13 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
if not self.weightDesc then self:resetWeightDescriptors() end
for g=0,self.groups-1 do
-- gradBias
- errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
self.scaleT:data(),
self.oDesc[0], gradOutput:data() + g*self.output_offset,
one:data(),
self.biasDesc[0], self.gradBias:data() + g*self.bias_offset);
-- gradWeight
- errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
self.scaleT:data(),
self.iDesc[0], input:data() + g*self.input_offset,
self.oDesc[0], gradOutput:data() + g*self.output_offset,
@@ -209,9 +209,9 @@ end
--[[
function SpatialConvolution:zeroGradParameters()
-- gradWeight, gradBias to zero
- errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnSetTensor', cudnn.getHandle(),
self.weightDesc, self.gradWeight:data(), zero:data());
- errcheck('cudnnSetTensor', cudnn.handle[cutorch.getDevice()-1],
+ errcheck('cudnnSetTensor', cudnn.getHandle(),
self.biasDesc, self.gradBias:data(), zero:data());
end
]]--