diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-01-12 01:04:06 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-01-12 01:04:06 +0300 |
commit | 056ed8965ce9557db146a475e0ef4772b3afda77 (patch) | |
tree | 6511441499b2714e59009ca5ea76fe80855722f3 | |
parent | b6175136131bbe9cdcd419927c21025e109d1c3f (diff) | |
parent | a3c65377739c1dc81bf3065902c6122859dde129 (diff) |
Merge pull request #89 from colesbury/R4
Make bias optional in cudnn.SpatialConvolution
-rw-r--r-- | SpatialConvolution.lua | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 1d6dbf1..f7348d6 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -33,7 +33,7 @@ end function SpatialConvolution:resetWeightDescriptors() assert(torch.typename(self.weight) == 'torch.CudaTensor', 'Only Cuda supported duh!') - assert(torch.typename(self.bias) == 'torch.CudaTensor', + assert(torch.typename(self.bias) == 'torch.CudaTensor' or not self.bias, 'Only Cuda supported duh!') -- for compatibility self.groups = self.groups or 1 @@ -52,7 +52,9 @@ function SpatialConvolution:resetWeightDescriptors() ffi.gc(self.weightDesc, destroyWDesc) -- create descriptor for bias - self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1)) + if self.bias then + self.biasDesc = cudnn.toDescriptor(self.bias:view(1, self.nOutputPlane,1,1)) + end end function SpatialConvolution:fastest(mode) @@ -382,9 +384,11 @@ function SpatialConvolution:updateOutput(input) end -- add bias - errcheck('cudnnAddTensor', cudnn.getHandle(), - one:data(), self.biasDesc[0], self.bias:data(), - one:data(), self.oDescForBias[0], self.output:data()) + if self.bias then + errcheck('cudnnAddTensor', cudnn.getHandle(), + one:data(), self.biasDesc[0], self.bias:data(), + one:data(), self.oDescForBias[0], self.output:data()) + end return self.output end @@ -424,11 +428,13 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) self:createIODescriptors(input) -- gradBias - errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(), - self.scaleT:data(), - self.oDescForBias[0], gradOutput:data(), - one:data(), - self.biasDesc[0], self.gradBias:data()) + if self.bias then + errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(), + self.scaleT:data(), + self.oDescForBias[0], gradOutput:data(), + one:data(), + self.biasDesc[0], self.gradBias:data()) + end for g = 0, self.groups - 1 do -- gradWeight |