Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2016-01-11 20:19:40 +0300
committerSam Gross <sgross@fb.com>2016-01-11 20:19:40 +0300
commita3c65377739c1dc81bf3065902c6122859dde129 (patch)
tree6511441499b2714e59009ca5ea76fe80855722f3 /SpatialConvolution.lua
parentb6175136131bbe9cdcd419927c21025e109d1c3f (diff)
Make bias optional in cudnn.SpatialConvolution
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r--SpatialConvolution.lua26
1 files changed, 16 insertions, 10 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 1d6dbf1..f7348d6 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -33,7 +33,7 @@ end
function SpatialConvolution:resetWeightDescriptors()
assert(torch.typename(self.weight) == 'torch.CudaTensor',
'Only Cuda supported duh!')
- assert(torch.typename(self.bias) == 'torch.CudaTensor',
+ assert(torch.typename(self.bias) == 'torch.CudaTensor' or not self.bias,
'Only Cuda supported duh!')
-- for compatibility
self.groups = self.groups or 1
@@ -52,7 +52,9 @@ function SpatialConvolution:resetWeightDescriptors()
ffi.gc(self.weightDesc, destroyWDesc)
-- create descriptor for bias
- self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1))
+ if self.bias then
+ self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1))
+ end
end
function SpatialConvolution:fastest(mode)
@@ -382,9 +384,11 @@ function SpatialConvolution:updateOutput(input)
end
-- add bias
- errcheck('cudnnAddTensor', cudnn.getHandle(),
- one:data(), self.biasDesc[0], self.bias:data(),
- one:data(), self.oDescForBias[0], self.output:data())
+ if self.bias then
+ errcheck('cudnnAddTensor', cudnn.getHandle(),
+ one:data(), self.biasDesc[0], self.bias:data(),
+ one:data(), self.oDescForBias[0], self.output:data())
+ end
return self.output
end
@@ -424,11 +428,13 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
self:createIODescriptors(input)
-- gradBias
- errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
- self.scaleT:data(),
- self.oDescForBias[0], gradOutput:data(),
- one:data(),
- self.biasDesc[0], self.gradBias:data())
+ if self.bias then
+ errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
+ self.scaleT:data(),
+ self.oDescForBias[0], gradOutput:data(),
+ one:data(),
+ self.biasDesc[0], self.gradBias:data())
+ end
for g = 0, self.groups - 1 do
-- gradWeight