author | soumith <soumith@fb.com> | 2016-02-26 07:26:31 +0300
committer | soumith <soumith@fb.com> | 2016-02-26 07:26:31 +0300
commit | af7240b9ce4058b7490014987803b077e9191566 (patch)
tree | 80ac88833073bd1b10619f6d834d67b5500a9cf6
parent | 756111325aa68030240bc0267e7c2864876bde6f (diff)
allow non-contiguous inputs for Spatial and Volumetric convolution
-rw-r--r-- | SpatialConvolution.lua | 20
-rw-r--r-- | VolumetricConvolution.lua | 25
2 files changed, 39 insertions, 6 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index dd27fc3..09e421d 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -344,8 +344,23 @@ end
 local one = torch.FloatTensor({1});
 local zero = torch.FloatTensor({0});
 
+local function makeContiguous(self, input, gradOutput)
+   if not input:isContiguous() then
+      self._input = self._input or input.new()
+      self._input:typeAs(input):resizeAs(input):copy(input)
+      input = self._input
+   end
+   if gradOutput and not gradOutput:isContiguous() then
+      self._gradOutput = self._gradOutput or gradOutput.new()
+      self._gradOutput:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
+      gradOutput = self._gradOutput
+   end
+   return input, gradOutput
+end
+
 function SpatialConvolution:updateOutput(input)
    if not self.weightDesc then self:resetWeightDescriptors() end
+   input = makeContiguous(self, input)
    self:createIODescriptors(input)
    for g = 0, self.groups - 1 do
@@ -372,8 +387,8 @@ end
 
 function SpatialConvolution:updateGradInput(input, gradOutput)
    if not self.gradInput then return end
+   input, gradOutput = makeContiguous(self, input, gradOutput)
    assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
-   assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
@@ -398,8 +413,9 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
    scale = scale or 1.0
    self.scaleT[1] = scale
+   input, gradOutput = makeContiguous(self, input, gradOutput)
+
    assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D');
-   assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index db352a5..9081e61 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -199,8 +199,23 @@ end
 local one = torch.FloatTensor({1});
 local zero = torch.FloatTensor({0});
 
+local function makeContiguous(self, input, gradOutput)
+   if not input:isContiguous() then
+      self._input = self._input or input.new()
+      self._input:typeAs(input):resizeAs(input):copy(input)
+      input = self._input
+   end
+   if gradOutput and not gradOutput:isContiguous() then
+      self._gradOutput = self._gradOutput or gradOutput.new()
+      self._gradOutput:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
+      gradOutput = self._gradOutput
+   end
+   return input, gradOutput
+end
+
 function VolumetricConvolution:updateOutput(input)
    if not self.weightDesc then self:resetWeightDescriptors() end
+   input = makeContiguous(self, input)
    self:createIODescriptors(input)
    errcheck('cudnnConvolutionForward', cudnn.getHandle(),
             one:data(),
@@ -219,8 +234,9 @@ end
 
 function VolumetricConvolution:updateGradInput(input, gradOutput)
    if not self.gradInput then return end
-   assert((gradOutput:dim() == 4 or gradOutput:dim() == 5)
-          and gradOutput:isContiguous());
+   input, gradOutput = makeContiguous(self, input, gradOutput)
+   assert(gradOutput:dim() == 4 or gradOutput:dim() == 5,
+          'gradOutput has to be a 4D or 5D tensor');
    if not self.weightDesc then self:resetWeightDescriptors() end
    self:createIODescriptors(input)
    errcheck('cudnnConvolutionBackwardData_v3', cudnn.getHandle(),
@@ -242,8 +258,9 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
    scale = scale or 1.0
    self.scaleT[1] = scale
-   assert((gradOutput:dim() == 4 or gradOutput:dim() == 5)
-          and gradOutput:isContiguous());
+   input, gradOutput = makeContiguous(self, input, gradOutput)
+   assert(gradOutput:dim() == 4 or gradOutput:dim() == 5,
+          'gradOutput has to be a 4D or 5D tensor');
    self:createIODescriptors(input)
    if not self.weightDesc then self:resetWeightDescriptors() end
    -- gradBias
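
A minimal usage sketch (not part of the commit) of what this change enables: a strided, non-contiguous tensor can now be passed directly to the cudnn convolution modules, and the new makeContiguous helper copies it into an internal contiguous buffer (self._input / self._gradOutput) before the cuDNN descriptors are created. The layer sizes and the transpose trick below are illustrative assumptions, not taken from the commit.

require 'cudnn'

-- Hypothetical layer sizes, chosen only for illustration.
local conv = cudnn.SpatialConvolution(3, 16, 5, 5, 1, 1, 2, 2):cuda()

-- Build a non-contiguous 4D input by transposing the two spatial dims;
-- this is a strided view over the same storage, not a copy.
local input = torch.CudaTensor(8, 3, 32, 32):normal():transpose(3, 4)
assert(not input:isContiguous())

-- Forward: makeContiguous copies the view into conv._input internally.
local output = conv:forward(input)

-- Backward with a non-contiguous gradOutput exercises the same path;
-- before this commit it tripped the 'gradOutput has to be contiguous' assert.
local gradOutput = torch.CudaTensor(8, 16, 32, 32):normal():transpose(3, 4)
conv:backward(input, gradOutput)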