diff options
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- | SpatialConvolution.lua | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index d691081..99d4acd 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -83,17 +83,19 @@ function SpatialConvolution:resetMode() end function SpatialConvolution:createIODescriptors(input) - local batch = true - if input:dim() == 3 then - input = input:view(1, input:size(1), input:size(2), input:size(3)) - batch = false - end - assert(input:dim() == 4 and input:isContiguous()); - self.iSize = self.iSize or torch.LongStorage(4):fill(0) + assert((input:dim() == 3 or input:dim() == 4) and input:isContiguous()); + if not self.iDesc or not self.oDesc or - input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] - or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then + input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or + input:size(3) ~= self.iSize[3] or + (#input:size() == 4 and #self.iSize == 4 and + input:size(4) ~= self.iSize[4]) then self.iSize = input:size() + local batch = true + if input:dim() == 3 then + input = input:view(1, input:size(1), input:size(2), input:size(3)) + batch = false + end -- resize gradInput if self.gradInput then self.gradInput:resizeAs(input); end |