diff options
-rw-r--r-- | TemporalConvolution.lua | 2 | ||||
-rw-r--r-- | test/test.lua | 23 |
2 files changed, 24 insertions, 1 deletions
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua index f59a5bd..a9e6470 100644 --- a/TemporalConvolution.lua +++ b/TemporalConvolution.lua @@ -57,7 +57,7 @@ function TemporalConvolution:updateOutput(input) self.buffer = self.buffer or torch.CudaTensor() self._output = self._output or torch.CudaTensor() if self.output:storage() then self._output:set(self.output:storage()) else self._output = self.output end - if self.buffer:storage() then self.output:set(self.buffer:storage()) else self.output = self.buffer end + if self.buffer:storage() then self.output:set(self.buffer:storage(), 1, self.output:size()) else self.output = self.buffer end cudnn.SpatialConvolution.updateOutput(self,_input) self.buffer = self.output:view(self.oSize):transpose(2,3) self.output = self._output:resize(self.buffer:size()):copy(self.buffer) diff --git a/test/test.lua b/test/test.lua index 1410b8c..e2ac33c 100644 --- a/test/test.lua +++ b/test/test.lua @@ -448,6 +448,29 @@ function cudnntest.TemporalConvolution_single() mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ') end +function cudnntest.TemporalConvolution_reduceBatchSize() + local inputFrameSize = math.random(1,64) + local outputFrameSize = math.random(1,64) + local ki = math.random(1,15) + local si = math.random(1,ki) + local outi = math.random(1,15) + local ini = (outi-1)*si+ki + local batchSize = 128 + local smallerBatchSize = batchSize/2 + + local input + input = torch.randn(batchSize,ini,inputFrameSize):cuda() + local conv = cudnn.TemporalConvolution(inputFrameSize,outputFrameSize,ki,si):cuda() + local o1 = conv:updateOutput(input) + mytester:asserteq(o1:size(1), batchSize, 'batch size didn\'t match') + + input = torch.randn(smallerBatchSize,ini,inputFrameSize):cuda() + local o2 = conv:updateOutput(input) + mytester:asserteq(o2:size(1), smallerBatchSize, 'batch size didn\'t match') + -- do this again to check it doesn't crash + local o2 = conv:updateOutput(input) + mytester:asserteq(o2:size(1), smallerBatchSize, 'batch size didn\'t match') +end function cudnntest.VolumetricConvolution_forward_single() |