Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Boris Fomitchev <bfomitchev@nvidia.com> 2016-08-04 12:12:54 +0300
committer: Boris Fomitchev <bfomitchev@nvidia.com> 2016-08-04 12:12:54 +0300
commit: fb1bec17939eb26f94da6a22f410ad316730b9e4 (patch)
tree: 951b8203eeee55ded736943365300acae19771ae /SpatialConvolution.lua
parent: a33739d6346adb3ea262c03a4ff900cef999d8c8 (diff)
Completing cudnnFind refactoring; addressing code review notes
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- SpatialConvolution.lua 54
1 file changed, 26 insertions, 28 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 5295bd5..1656154 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -22,10 +22,11 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane,
self:reset()
-- should nil for serialization, the reset will still work
self.reset = nil
+ return self
end
function SpatialConvolution:createWeightDescriptors()
- assert(cudnn.typemap[torch.typename(self.weight)], 'Only Cuda supported duh!')
+ assert(cudnn.typemap[torch.typename(self.weight)] or not self.weight, 'Only Cuda supported duh!')
assert(cudnn.typemap[torch.typename(self.bias)] or not self.bias, 'Only Cuda supported duh!')
-- create descriptor for bias
if self.bias then
@@ -37,23 +38,22 @@ function SpatialConvolution:createWeightDescriptors()
end
-- if you change the configuration of the module manually, call this
-function SpatialConvolution:resetWeightDescriptors()
+function SpatialConvolution:resetWeightDescriptors(desc)
-- for compatibility
self.groups = self.groups or 1
self.weightDesc = SpatialConvolution.createWeightDescriptors(self)
- local desc = torch.IntTensor({self.nOutputPlane/self.groups,
- self.nInputPlane/self.groups,
- self.kH, self.kW})
+ desc = desc or torch.IntTensor({self.nOutputPlane/self.groups,
+ self.nInputPlane/self.groups,
+ self.kH, self.kW})
errcheck('cudnnSetFilterNdDescriptor', self.weightDesc[0],
- cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', 4,
+ cudnn.typemap[torch.typename(self.weight)], 'CUDNN_TENSOR_NCHW', self.nDim,
desc:data());
end
function SpatialConvolution:fastest(mode)
if mode == nil then mode = true end
self.fastest_mode = mode
- self.iSize = self.iSize or torch.LongStorage(4)
- self.iSize:fill(0)
+ self.iDesc = nil
return self
end
@@ -67,8 +67,7 @@ function SpatialConvolution:setMode(fmode, bdmode, bwmode)
if bwmode ~= nil then
self.bwmode = bwmode
end
- self.iSize = self.iSize or torch.LongStorage(4)
- self.iSize:fill(0)
+ self.iDesc = nil
return self
end
@@ -87,10 +86,14 @@ end
function SpatialConvolution:checkInputChanged(input)
- assert(input:dim() == 4 and input:isContiguous());
- self.iSize = self.iSize or torch.LongStorage(4):fill(0)
+ self.nDim = self.nDim or 4
+ assert(input:dim() == self.nDim)
+ assert(input:isContiguous())
+ self.iSize = self.iSize or torch.LongStorage(self.nDim):fill(0)
+ self.groups = self.groups or 1
+ if not self.weightDesc then self:resetWeightDescriptors() end
if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
- or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+ or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] or (self.nDim==5 and input:size(5) ~= self.iSize[5]) then
self.iSize = input:size()
assert(self.nInputPlane == input:size(2), 'input has to contain: '
@@ -127,11 +130,11 @@ function SpatialConvolution:createIODescriptors(input)
-- get output shape, resize output
- local oSize = torch.IntTensor(4)
+ local oSize = torch.IntTensor(self.nDim)
local oSizeD = oSize:data()
errcheck('cudnnGetConvolutionNdForwardOutputDim',
self.convDesc[0], self.iDesc[0],
- self.weightDesc[0], 4, oSizeD)
+ self.weightDesc[0], self.nDim, oSizeD)
oSize[2] = oSize[2] * self.groups
self.output:resize(oSize:long():storage())
@@ -162,7 +165,7 @@ end
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
-local function makeContiguous(self, input, gradOutput)
+function SpatialConvolution:makeContiguous(input, gradOutput)
if not input:isContiguous() then
self._input = self._input or input.new()
self._input:typeAs(input):resizeAs(input):copy(input)
@@ -177,8 +180,7 @@ local function makeContiguous(self, input, gradOutput)
end
function SpatialConvolution:updateOutput(input)
- if not self.weightDesc then self:resetWeightDescriptors() end
- input = makeContiguous(self, input)
+ input = SpatialConvolution.makeContiguous(self, input)
self:createIODescriptors(input)
if not self.fwdAlgType then
algo.setupForwardAlgorithm(self)
@@ -207,10 +209,8 @@ end
function SpatialConvolution:updateGradInput(input, gradOutput)
if not self.gradInput then return end
self.gradInput:resizeAs(input)
-
- input, gradOutput = makeContiguous(self, input, gradOutput)
- assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
- if not self.weightDesc then self:resetWeightDescriptors() end
+ input, gradOutput = SpatialConvolution.makeContiguous(self, input, gradOutput)
+ assert(gradOutput:dim() == self.nDim-1 or gradOutput:dim() == self.nDim, 'gradOutput has to be nDim or nDim-1');
self:createIODescriptors(input)
if not self.bwdDataAlgType then
algo.setupBackwardDataAlgorithm(self)
@@ -236,12 +236,10 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
self.scaleT = self.scaleT:float()
scale = scale or 1.0
self.scaleT[1] = scale
-
- input, gradOutput = makeContiguous(self, input, gradOutput)
-
- assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
- if not self.weightDesc then self:resetWeightDescriptors() end
+ input, gradOutput = SpatialConvolution.makeContiguous(self, input, gradOutput)
+ assert(gradOutput:dim() == self.nDim-1 or gradOutput:dim() == self.nDim, 'gradOutput has to be nDim or nDim-1');
self:createIODescriptors(input)
+
if not self.bwdFilterAlgType then
algo.setupBackwardFilterAlgorithm(self)
end
@@ -295,7 +293,7 @@ end
function SpatialConvolution:clearState()
self:clearDesc()
- nn.utils.clear(self, '_input', '_gradOutput')
+ nn.utils.clear(self, 'extraBuffer', '_input', '_gradOutput')
return nn.Module.clearState(self)
end