From 70433d6359cdae6833c315bb8151038ed9f75a1c Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Sat, 20 Sep 2014 14:55:33 -0400 Subject: Multi-GPU support --- SpatialMaxPooling.lua | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'SpatialMaxPooling.lua') diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index e3f9fe4..bd9039d 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -43,7 +43,7 @@ end function SpatialMaxPooling:updateOutput(input) assert(input:dim() == 4 and input:isContiguous()); self:createIODescriptors(input) - errcheck('cudnnPoolingForward', cudnn.handle[0], self.poolDesc[0], + errcheck('cudnnPoolingForward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], self.iDesc[0], input:data(), self.oDesc[0], self.output:data()); return self.output @@ -52,7 +52,7 @@ end function SpatialMaxPooling:updateGradInput(input, gradOutput) assert(input:dim() == 4 and input:isContiguous()); assert(gradOutput:dim() == 4 and gradOutput:isContiguous()); - errcheck('cudnnPoolingBackward', cudnn.handle[0], self.poolDesc[0], + errcheck('cudnnPoolingBackward', cudnn.handle[cutorch.getDevice()-1], self.poolDesc[0], self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), self.iDesc[0], input:data(), -- cgit v1.2.3