Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@fb.com>2014-12-21 05:41:55 +0300
committerSoumith Chintala <soumith@fb.com>2014-12-21 05:41:55 +0300
commite757d336b3934793bfc52c7c2995aad612a569b8 (patch)
tree97eda33d77af6e9376eaf8a8fb4c5037c0424f60 /test.lua
parent1e9ca31233f2e3914c8342356e642a89ce0554d2 (diff)
VolumetricConvolution batch mode + test
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua24
1 files changed, 22 insertions, 2 deletions
diff --git a/test.lua b/test.lua
index a3c1c85..29de8bc 100644
--- a/test.lua
+++ b/test.lua
@@ -1253,9 +1253,9 @@ function nntest.SpatialFullConvolutionCompare()
end
local function batchcompare(smod, sin, plist)
- local bs = torch.LongStorage(sin:size():size()+1)
+ local bs = torch.LongStorage(sin:dim()+1)
bs[1] = 1
- for i=1,sin:size():size() do bs[i+1] = sin:size()[i] end
+ for i=1,sin:dim() do bs[i+1] = sin:size()[i] end
local bin = torch.Tensor(bs):copy(sin)
local bmod = smod:clone()
@@ -1780,6 +1780,26 @@ function nntest.VolumetricConvolution()
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.VolumetricConvolutionBatchCompare()
+ local from = math.random(2,3)
+ local to = math.random(2,3)
+ local kt = math.random(3,4)
+ local ki = math.random(3,4)
+ local kj = math.random(3,4)
+ local st = math.random(2,3)
+ local si = math.random(2,3)
+ local sj = math.random(2,3)
+ local outt = math.random(3,4)
+ local outi = math.random(3,4)
+ local outj = math.random(3,4)
+ local int = (outt-1)*st+kt
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj)
+ local input = torch.randn(from, int, inj, ini)
+ batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
+end
+
function nntest.VolumetricMaxPooling()
local from = math.random(2,3)
local kt = math.random(3,4)