Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-03-30 02:37:02 +0300
committerSoumith Chintala <soumith@gmail.com>2016-03-30 02:37:02 +0300
commitc1b9af7efa559607c1801f24e83b7f1c27c31533 (patch)
tree6079be921a8e5d97d6062aa0e5399165df96d70b
parent317ec406bf3d459de86e0953fb580dda12567f4b (diff)
parent02bc67c446d0c42b7831a83ba15bff9677a372d2 (diff)
Merge pull request #152 from gheinrich/fix-temporal-convolution
Fix TemporalConvolution output size
-rw-r--r--TemporalConvolution.lua2
-rw-r--r--test/test.lua23
2 files changed, 24 insertions, 1 deletions
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua
index f59a5bd..a9e6470 100644
--- a/TemporalConvolution.lua
+++ b/TemporalConvolution.lua
@@ -57,7 +57,7 @@ function TemporalConvolution:updateOutput(input)
self.buffer = self.buffer or torch.CudaTensor()
self._output = self._output or torch.CudaTensor()
if self.output:storage() then self._output:set(self.output:storage()) else self._output = self.output end
- if self.buffer:storage() then self.output:set(self.buffer:storage()) else self.output = self.buffer end
+ if self.buffer:storage() then self.output:set(self.buffer:storage(), 1, self.output:size()) else self.output = self.buffer end
cudnn.SpatialConvolution.updateOutput(self,_input)
self.buffer = self.output:view(self.oSize):transpose(2,3)
self.output = self._output:resize(self.buffer:size()):copy(self.buffer)
diff --git a/test/test.lua b/test/test.lua
index 1410b8c..e2ac33c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -448,6 +448,29 @@ function cudnntest.TemporalConvolution_single()
mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
end
+function cudnntest.TemporalConvolution_reduceBatchSize()
+ local inputFrameSize = math.random(1,64)
+ local outputFrameSize = math.random(1,64)
+ local ki = math.random(1,15)
+ local si = math.random(1,ki)
+ local outi = math.random(1,15)
+ local ini = (outi-1)*si+ki
+ local batchSize = 128
+ local smallerBatchSize = batchSize/2
+
+ local input
+ input = torch.randn(batchSize,ini,inputFrameSize):cuda()
+ local conv = cudnn.TemporalConvolution(inputFrameSize,outputFrameSize,ki,si):cuda()
+ local o1 = conv:updateOutput(input)
+ mytester:asserteq(o1:size(1), batchSize, 'batch size didn\'t match')
+
+ input = torch.randn(smallerBatchSize,ini,inputFrameSize):cuda()
+ local o2 = conv:updateOutput(input)
+ mytester:asserteq(o2:size(1), smallerBatchSize, 'batch size didn\'t match')
+ -- do this again to check it doesn't crash
+ local o2 = conv:updateOutput(input)
+ mytester:asserteq(o2:size(1), smallerBatchSize, 'batch size didn\'t match')
+end
function cudnntest.VolumetricConvolution_forward_single()