diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-20 22:55:33 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-21 03:51:16 +0400 |
commit | 70433d6359cdae6833c315bb8151038ed9f75a1c (patch) | |
tree | 9d5fc34dfa6e63a920d80ea0e7090b368986eb45 /SpatialConvolution.lua | |
parent | ae62e2be0dc9cf7500972b7355ccfd46e3d2b1a8 (diff) |
Multi-GPU support
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- | SpatialConvolution.lua | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 6d96cdd..ceb42cd 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -62,13 +62,13 @@ end function SpatialConvolution:updateOutput(input) assert(input:dim() == 4 and input:isContiguous()); self:createIODescriptors(input) - errcheck('cudnnConvolutionForward', cudnn.handle[0], + errcheck('cudnnConvolutionForward', cudnn.handle[cutorch.getDevice()-1], self.iDesc[0], input:data(), self.weightDesc[0], self.weight:data(), self.convDesc[0], self.oDesc[0], self.output:data(), 'CUDNN_RESULT_NO_ACCUMULATE'); local alpha = torch.FloatTensor({1}); - errcheck('cudnnAddTensor4d', cudnn.handle[0], 'CUDNN_ADD_SAME_C', + errcheck('cudnnAddTensor4d', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ADD_SAME_C', alpha:data(), self.biasDesc[0], self.bias:data(), self.oDesc[0], self.output:data()); return self.output @@ -77,7 +77,7 @@ end function SpatialConvolution:updateGradInput(input, gradOutput) assert(input:dim() == 4 and input:isContiguous()); assert(gradOutput:dim() == 4 and gradOutput:isContiguous()); - errcheck('cudnnConvolutionBackwardData', cudnn.handle[0], + errcheck('cudnnConvolutionBackwardData', cudnn.handle[cutorch.getDevice()-1], self.weightDesc[0], self.weight:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], @@ -91,12 +91,12 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) assert(input:dim() == 4 and input:isContiguous()); assert(gradOutput:dim() == 4 and gradOutput:isContiguous()); -- gradBias - errcheck('cudnnConvolutionBackwardBias', cudnn.handle[0], + errcheck('cudnnConvolutionBackwardBias', cudnn.handle[cutorch.getDevice()-1], self.oDesc[0], gradOutput:data(), self.biasDesc[0], self.gradBias:data(), 'CUDNN_RESULT_ACCUMULATE'); -- gradWeight - errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[0], + errcheck('cudnnConvolutionBackwardFilter', cudnn.handle[cutorch.getDevice()-1], self.iDesc[0], input:data(), self.oDesc[0], gradOutput:data(), self.convDesc[0], |