From 9465aae4f41734c8218adaf2d50c7b3f5c9e80f7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 22 Sep 2016 02:19:15 -0700 Subject: Revamped workspace handling in find.lua Retired functional.lua: impossible to maintain consistently with Find. Simplified FindEx state machine: replaced with warmup iterations concept, controllable by user. FindEx still needs some work. Improved cache handling and debug print --- SpatialConvolution.lua | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) (limited to 'SpatialConvolution.lua') diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index f2ab112..512e7c2 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -126,12 +126,12 @@ function SpatialConvolution:createIODescriptors(input) self.convDesc = cudnn.createDescriptors(1, 'struct cudnnConvolutionStruct*[?]', 'cudnnCreateConvolutionDescriptor', 'cudnnDestroyConvolutionDescriptor') self.padH, self.padW = self.padH or 0, self.padW or 0 - local pad = torch.IntTensor({self.padH, self.padW}) - local stride = torch.IntTensor({self.dH, self.dW}) + self.pad = torch.IntTensor({self.padH, self.padW}) + self.stride = torch.IntTensor({self.dH, self.dW}) local upscale = torch.IntTensor({1,1}) errcheck(self,'cudnnSetConvolutionNdDescriptor', self.convDesc[0], - 2, pad:data(), - stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', + 2, self.pad:data(), + self.stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', cudnn.configmap(torch.type(self.weight))); @@ -188,9 +188,8 @@ function SpatialConvolution:updateOutput(input) self:createIODescriptors(input) local finder = find.get() -- force recalculation - if not (self.fmode and finder.useCalculatedWorkspaceSize) then - self.fmode = finder:forwardAlgorithm(self, { self.iDesc[0], self.input_slice, self.weightDesc[0], self.weight, self.convDesc[0], self.oDesc[0], self.output_slice}) - end + self.fmode = finder:forwardAlgorithm(self, { self.iDesc[0], self.input_slice, 
self.weightDesc[0], self.weight, self.convDesc[0], self.oDesc[0], self.output_slice}) + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() for g = 0, self.groups - 1 do errcheck(self,'cudnnConvolutionForward', cudnn.getHandle(), @@ -221,9 +220,8 @@ function SpatialConvolution:updateGradInput(input, gradOutput) input, gradOutput = makeContiguous(self, input, gradOutput) self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bdmode) then - self.bdmode = finder:backwardDataAlgorithm(self, { self.weightDesc[0], self.weight, self.oDesc[0], self.output_slice, self.convDesc[0], self.iDesc[0], self.input_slice }) - end + self.bdmode = finder:backwardDataAlgorithm(self, { self.weightDesc[0], self.weight, self.oDesc[0], self.output_slice, self.convDesc[0], self.iDesc[0], self.input_slice }) + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() for g = 0,self.groups - 1 do errcheck(self,'cudnnConvolutionBackwardData', cudnn.getHandle(), @@ -249,10 +247,8 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) input, gradOutput = makeContiguous(self, input, gradOutput) self:createIODescriptors(input) local finder = find.get() - if not (finder.useCalculatedWorkspaceSize and self.bmode) then - self.bmode=finder:backwardFilterAlgorithm(self, { self.iDesc[0], self.input_slice, self.oDesc[0], self.output_slice, self.convDesc[0], self.weightDesc[0], self.weight}) - end - + self.bmode=finder:backwardFilterAlgorithm(self, { self.iDesc[0], self.input_slice, self.oDesc[0], self.output_slice, self.convDesc[0], self.weightDesc[0], self.weight}) + finder:setCalculatedWorkspaceSize(true) -- gradBias if self.bias then errcheck(self,'cudnnConvolutionBackwardBias', cudnn.getHandle(), @@ -261,6 +257,7 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) cudnn.scalar(input, 1), self.biasDesc[0], 
self.gradBias:data()) end + finder:setCalculatedWorkspaceSize(true) local extraBuffer, extraBufferSize = cudnn.getSharedWorkspace() for g = 0, self.groups - 1 do -- gradWeight -- cgit v1.2.3