author     Boris Fomitchev <bfomitchev@nvidia.com>    2016-08-04 12:12:54 +0300
committer  Boris Fomitchev <bfomitchev@nvidia.com>    2016-08-04 12:12:54 +0300
commit     fb1bec17939eb26f94da6a22f410ad316730b9e4 (patch)
tree       951b8203eeee55ded736943365300acae19771ae /SpatialConvolution.lua
parent     a33739d6346adb3ea262c03a4ff900cef999d8c8 (diff)
Completing cudnnFind refactoring; addressing code review notes
Diffstat (limited to 'SpatialConvolution.lua')

 -rw-r--r--  SpatialConvolution.lua | 54
 1 file changed, 26 insertions(+), 28 deletions(-)
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 5295bd5..1656154 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -22,10 +22,11 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane,
    self:reset()
    -- should nil for serialization, the reset will still work
    self.reset = nil
+   return self
 end
 
 function SpatialConvolution:createWeightDescriptors()
-   assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!')
+   assert(cudnn.typemap[torch.typename(self.weight)] or not self.weight, 'Only Cuda supported duh!')
    assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!')
    -- create descriptor for bias
    if self.bias then
@@ -37,23 +38,22 @@ function SpatialConvolution:createWeightDescriptors()
 end
 
 -- if you change the configuration of the module manually, call this
-function SpatialConvolution:resetWeightDescriptors()
+function SpatialConvolution:resetWeightDescriptors(desc)
    -- for compatibility
    self.groups = self.groups or 1
    self.weightDesc = SpatialConvolution.createWeightDescriptors(self)
-   local desc = torch.IntTensor({self.nOutputPlane/self.groups,
-                                 self.nInputPlane/self.groups,
-                                 self.kH, self.kW})
+   desc = desc or torch.IntTensor({self.nOutputPlane/self.groups,
+                                   self.nInputPlane/self.groups,
+                                   self.kH, self.kW})
    errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
-            cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', 4,
+            cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', self.nDim,
             desc:data());
 end
 
 function SpatialConvolution:fastest(mode)
    if mode == nil then mode = true end
    self.fastest_mode = mode
-   self.iSize = self.iSize or torch.LongStorage(4)
-   self.iSize:fill(0)
+   self.iDesc = nil
    return self
 end
 
@@ -67,8 +67,7 @@ function SpatialConvolution:setMode(fmode, bdmode, bwmode)
    if bwmode ~= nil then
      self.bwmode = bwmode
    end
-   self.iSize = self.iSize or torch.LongStorage(4)
-   self.iSize:fill(0)
+   self.iDesc = nil
    return self
 end
 
@@ -87,10 +86,14 @@ end
 
 function SpatialConvolution:checkInputChanged(input)
-   assert(input:dim() == 4 and input:isContiguous());
-   self.iSize = self.iSize or torch.LongStorage(4):fill(0)
+   self.nDim = self.nDim or 4
+   assert(input:dim() == self.nDim)
+   assert(input:isContiguous())
+   self.iSize = self.iSize or torch.LongStorage(self.nDim):fill(0)
+   self.groups = self.groups or 1
+   if not self.weightDesc then self:resetWeightDescriptors() end
    if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1]
    or input:size(2) ~= self.iSize[2]
-   or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+   or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] or (self.nDim==5 and input:size(5) ~= self.iSize[5]) then
      self.iSize = input:size()
      assert(self.nInputPlane == input:size(2), 'input has to contain: '
@@ -127,11 +130,11 @@ function SpatialConvolution:createIODescriptors(input)
        -- get output shape, resize output
-       local oSize = torch.IntTensor(4)
+       local oSize = torch.IntTensor(self.nDim)
        local oSizeD = oSize:data()
        errcheck('cudnnGetConvolutionNdForwardOutputDim',
                 self.convDesc[0], self.iDesc[0],
-                self.weightDesc[0], 4, oSizeD)
+                self.weightDesc[0], self.nDim, oSizeD)
        oSize[2] = oSize[2] * self.groups
        self.output:resize(oSize:long():storage())
@@ -162,7 +165,7 @@ end
 local one = torch.FloatTensor({1});
 local zero = torch.FloatTensor({0});
 
-local function makeContiguous(self, input, gradOutput)
+function SpatialConvolution:makeContiguous(input, gradOutput)
    if not input:isContiguous() then
      self._input = self._input or input.new()
      self._input:typeAs(input):resizeAs(input):copy(input)
@@ -177,8 +180,7 @@ local function makeContiguous(self, input, gradOutput)
 end
 
 function SpatialConvolution:updateOutput(input)
-   if not self.weightDesc then self:resetWeightDescriptors() end
-   input = makeContiguous(self, input)
+   input = SpatialConvolution.makeContiguous(self, input)
    self:createIODescriptors(input)
    if not self.fwdAlgType then
      algo.setupForwardAlgorithm(self)
@@ -207,10 +209,8 @@ end
 function SpatialConvolution:updateGradInput(input, gradOutput)
    if not self.gradInput then return end
    self.gradInput:resizeAs(input)
-
-   input, gradOutput = makeContiguous(self, input, gradOutput)
-   assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
-   if not self.weightDesc then self:resetWeightDescriptors() end
+   input, gradOutput = SpatialConvolution.makeContiguous(self, input, gradOutput)
+   assert(gradOutput:dim() == self.nDim-1 or gradOutput:dim() == self.nDim, 'gradOutput has to be nDim or nDim-1');
    self:createIODescriptors(input)
    if not self.bwdDataAlgType then
      algo.setupBackwardDataAlgorithm(self)
@@ -236,12 +236,10 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
    self.scaleT = self.scaleT:float()
    scale = scale or 1.0
    self.scaleT[1] = scale
-
-   input, gradOutput = makeContiguous(self, input, gradOutput)
-
-   assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
-   if not self.weightDesc then self:resetWeightDescriptors() end
+   input, gradOutput = SpatialConvolution.makeContiguous(self, input, gradOutput)
+   assert(gradOutput:dim() == self.nDim-1 or gradOutput:dim() == self.nDim, 'gradOutput has to be nDim or nDim-1');
    self:createIODescriptors(input)
+
    if not self.bwdFilterAlgType then
      algo.setupBackwardFilterAlgorithm(self)
    end
@@ -295,7 +293,7 @@ end
 
 function SpatialConvolution:clearState()
    self:clearDesc()
-   nn.utils.clear(self, '_input', '_gradOutput')
+   nn.utils.clear(self, 'extraBuffer', '_input', '_gradOutput')
    return nn.Module.clearState(self)
 end
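
For orientation, a minimal usage sketch of the refactored module follows; it is not part of the commit, and the tensor shapes and the :cuda() call are illustrative assumptions. It shows the lazy-descriptor behavior introduced above: :fastest() and :setMode() now only nil out self.iDesc, and checkInputChanged() rebuilds the weight descriptors on the next forward pass, with self.nDim (defaulting to 4) as the hook that lets 5D volumetric subclasses share the same code paths.

require 'cudnn'  -- assumes the cudnn.torch package and a working CUDA device

-- 3 input planes, 16 output planes, 5x5 kernel, stride 1, padding 2
local conv = cudnn.SpatialConvolution(3, 16, 5, 5, 1, 1, 2, 2):cuda()

-- After this commit, :fastest() merely sets self.iDesc = nil; weight and IO
-- descriptors are recreated lazily by checkInputChanged() on the next forward.
conv:fastest()

local input = torch.CudaTensor(8, 3, 32, 32)  -- 4D input, matching the default self.nDim = 4
local output = conv:forward(input)            -- rebuilds descriptors, then runs the cuDNN forward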