github.com/soumith/cudnn.torch.git
Diffstat (limited to 'VolumetricConvolution.lua')
 VolumetricConvolution.lua | 39 +++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index b255467..3f32c3d 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -275,10 +275,12 @@ function VolumetricConvolution:createIODescriptors(input)
maxBufSize = math.max(maxBufSize, bufSize[1])
self.extraBuffer = self.extraBuffer or cudnn.getSharedWorkspace()
- self.extraBufferSizeInBytes = self.extraBuffer:nElement() * 4 -- float
+ self.extraBuffer = self.extraBuffer:cuda() -- always force float
+ self.extraBufferSizeInBytes =
+ self.extraBuffer:nElement() * 4 -- extraBuffer is always float
if maxBufSize > self.extraBufferSizeInBytes then
- self.extraBuffer:resize(math.ceil(maxBufSize/4))
- self.extraBufferSizeInBytes = maxBufSize
+ self.extraBuffer:resize(math.ceil(maxBufSize / 4))
+ self.extraBufferSizeInBytes = maxBufSize
end
-----------------------------------------------------------------------
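Note on the hunk above: all cudnn modules share one workspace tensor and grow it on demand. A minimal standalone sketch of the same arithmetic, assuming the workspace is a torch.CudaTensor whose elements are 4-byte floats (the helper name here is illustrative, not part of the patch):

-- Sketch: grow a shared float workspace to hold at least maxBufSize bytes.
-- A torch.CudaTensor element is 4 bytes, so byte capacity = nElement() * 4.
local function ensureWorkspace(buf, maxBufSize)
   local sizeInBytes = buf:nElement() * 4
   if maxBufSize > sizeInBytes then
      buf:resize(math.ceil(maxBufSize / 4)) -- round bytes up to whole elements
      sizeInBytes = maxBufSize
   end
   return buf, sizeInBytes
end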
@@ -291,8 +293,8 @@ function VolumetricConvolution:createIODescriptors(input)
end
end
-local one = torch.FloatTensor({1});
-local zero = torch.FloatTensor({0});
+
+
local function makeContiguous(self, input, gradOutput)
if not input:isContiguous() then
@@ -313,16 +315,16 @@ function VolumetricConvolution:updateOutput(input)
input = makeContiguous(self, input)
self:createIODescriptors(input)
errcheck('cudnnConvolutionForward', cudnn.getHandle(),
- one:data(),
+ cudnn.scalar(input, 1),
self.iDesc[0], input:data(),
self.weightDesc[0], self.weight:data(),
self.convDesc[0], self.fwdAlgType[0],
self.extraBuffer:data(), self.extraBufferSizeInBytes,
- zero:data(),
+ cudnn.scalar(input, 0),
self.oDesc[0], self.output:data());
errcheck('cudnnAddTensor', cudnn.getHandle(),
- one:data(),
- self.biasDesc[0], self.bias:data(), one:data(),
+ cudnn.scalar(input, 1),
+ self.biasDesc[0], self.bias:data(), cudnn.scalar(input, 1),
self.oDescBias[0], self.output:data());
return self.output
end
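Note: the patch replaces the CPU-resident `one`/`zero` FloatTensor pointers with cudnn.scalar(input, v) throughout. cuDNN takes the alpha/beta scaling factors as host pointers whose C type must match the computation type: double for double-precision tensors, float for everything else (half included, since half convolutions scale in float). The patch does not show cudnn.scalar itself; a hypothetical sketch under that assumption, using LuaJIT's FFI:

-- Hypothetical sketch only; the real cudnn.scalar is defined elsewhere in
-- the library. cuDNN reads alpha/beta from host memory, and the pointer
-- type has to match the tensor's computation type.
local ffi = require 'ffi'
local floatBuf, doubleBuf = ffi.new('float[1]'), ffi.new('double[1]')
local function scalar(input, value)
   if torch.type(input) == 'torch.CudaDoubleTensor' then
      doubleBuf[0] = value
      return doubleBuf -- handed to cuDNN as a const double*
   end
   floatBuf[0] = value
   return floatBuf -- handed to cuDNN as a const float*
end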
@@ -337,24 +339,25 @@ function VolumetricConvolution:updateGradInput(input, gradOutput)
if not self.weightDesc then self:resetWeightDescriptors() end
self:createIODescriptors(input)
errcheck('cudnnConvolutionBackwardData', cudnn.getHandle(),
- one:data(),
+ cudnn.scalar(input, 1),
self.weightDesc[0], self.weight:data(),
self.oDesc[0], gradOutput:data(),
self.convDesc[0],
self.bwdDataAlgType[0],
self.extraBuffer:data(), self.extraBufferSizeInBytes,
- zero:data(),
+ cudnn.scalar(input, 0),
self.iDesc[0], self.gradInput:data());
return self.gradInput
end
function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
- self.scaleT = self.scaleT or torch.FloatTensor(1):fill(1.0)
- -- this line forces this member to always be on CPU (needed for cudnn)
- self.scaleT = self.scaleT:float()
+ self.scaleT = self.scaleT or self.weight.new(1)
+ -- this line forces this member to always be on CPU (needed for cudnn)
+ self.scaleT = torch.type(self.weight) == 'torch.CudaDoubleTensor'
+ and self.scaleT:double() or self.scaleT:float()
scale = scale or 1.0
self.scaleT[1] = scale
input, gradOutput = makeContiguous(self, input, gradOutput)
assert(gradOutput:dim() == 4 or gradOutput:dim() == 5,
'gradOutput has to be a 4D or 5D tensor');
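Note: scaleT stays in a one-element CPU tensor because the backward calls below take the scale factor by host pointer; the new code merely widens it to double when the weights are torch.CudaDoubleTensor. The `cond and a or b` expression is Lua's ternary idiom, safe here because neither branch evaluates to false or nil. An equivalent standalone form, as a sketch:

-- Sketch: pick the CPU scalar type that matches the weight precision.
local function makeScaleT(weight, scale)
   local scaleT = torch.type(weight) == 'torch.CudaDoubleTensor'
      and torch.DoubleTensor(1) -- double weights need a double scale
      or torch.FloatTensor(1)   -- float/half weights use a float scale
   scaleT[1] = scale or 1.0
   return scaleT
end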
@@ -364,7 +367,7 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
self.scaleT:data(),
self.oDescBias[0], gradOutput:data(),
- one:data(),
+ cudnn.scalar(input, 1),
self.biasDesc[0], self.gradBias:data());
-- gradWeight
errcheck('cudnnConvolutionBackwardFilter', cudnn.getHandle(),
@@ -374,7 +377,7 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
self.convDesc[0],
self.bwdFilterAlgType[0],
self.extraBuffer:data(), self.extraBufferSizeInBytes,
- one:data(),
+ cudnn.scalar(input, 1),
self.weightDesc[0], self.gradWeight:data());
end