Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'Pooling3D.lua')
-rw-r--r--Pooling3D.lua9
1 files changed, 3 insertions, 6 deletions
diff --git a/Pooling3D.lua b/Pooling3D.lua
index cce67c3..e4c0218 100644
--- a/Pooling3D.lua
+++ b/Pooling3D.lua
@@ -58,8 +58,6 @@ function Pooling:createIODescriptors(input)
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4]
or input:size(5) ~= self.iSize[5] then
self.iSize = input:size()
- -- resize gradInput
- self.gradInput:resizeAs(input)
-- resize output
local oW, oH, oT
if self.ceil_mode then
@@ -77,10 +75,6 @@ function Pooling:createIODescriptors(input)
self.iDesc = cudnn.toDescriptor(input)
self.oDesc = cudnn.toDescriptor(self.output)
if not batch then
- self.gradInput = self.gradInput:view(self.gradInput:size(2),
- self.gradInput:size(3),
- self.gradInput:size(4),
- self.gradInput:size(5))
self.output = self.output:view(self.output:size(2),
self.output:size(3),
self.output:size(4),
@@ -105,6 +99,9 @@ function Pooling:updateOutput(input)
end
function Pooling:updateGradInput(input, gradOutput)
+ if not self.gradInput then return end
+ self.gradInput:resizeAs(input)
+
assert(gradOutput:dim() == 4 or gradOutput:dim() == 5);
if not gradOutput:isContiguous() then
self._gradOutput = self._gradOutput or gradOutput.new()