diff options
author | Soumith Chintala <soumith@fb.com> | 2014-12-21 05:41:55 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@fb.com> | 2014-12-21 05:41:55 +0300 |
commit | e757d336b3934793bfc52c7c2995aad612a569b8 (patch) | |
tree | 97eda33d77af6e9376eaf8a8fb4c5037c0424f60 /test.lua | |
parent | 1e9ca31233f2e3914c8342356e642a89ce0554d2 (diff) |
VolumetricConvolution batch mode + test
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 24 |
1 files changed, 22 insertions, 2 deletions
@@ -1253,9 +1253,9 @@ function nntest.SpatialFullConvolutionCompare() end local function batchcompare(smod, sin, plist) - local bs = torch.LongStorage(sin:size():size()+1) + local bs = torch.LongStorage(sin:dim()+1) bs[1] = 1 - for i=1,sin:size():size() do bs[i+1] = sin:size()[i] end + for i=1,sin:dim() do bs[i+1] = sin:size()[i] end local bin = torch.Tensor(bs):copy(sin) local bmod = smod:clone() @@ -1780,6 +1780,26 @@ function nntest.VolumetricConvolution() mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') end +function nntest.VolumetricConvolutionBatchCompare() + local from = math.random(2,3) + local to = math.random(2,3) + local kt = math.random(3,4) + local ki = math.random(3,4) + local kj = math.random(3,4) + local st = math.random(2,3) + local si = math.random(2,3) + local sj = math.random(2,3) + local outt = math.random(3,4) + local outi = math.random(3,4) + local outj = math.random(3,4) + local int = (outt-1)*st+kt + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj) + local input = torch.randn(from, int, inj, ini) + batchcompare(module,input, {'weight','bias','gradWeight','gradBias'}) +end + function nntest.VolumetricMaxPooling() local from = math.random(2,3) local kt = math.random(3,4) |