diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 60 |
1 files changed, 33 insertions, 27 deletions
@@ -1926,38 +1926,44 @@ function nntest.SpatialSubSampling() end function nntest.SpatialMaxPooling() - local from = math.random(1,5) - local ki = math.random(1,4) - local kj = math.random(1,4) - local si = math.random(1,3) - local sj = math.random(1,3) - local outi = math.random(4,5) - local outj = math.random(4,5) - local ini = (outi-1)*si+ki - local inj = (outj-1)*sj+kj - - local module = nn.SpatialMaxPooling(ki,kj,si,sj) - local input = torch.rand(from,ini,inj) + for _,ceil_mode in pairs({true,false}) do + local from = math.random(1,5) + local ki = math.random(1,4) + local kj = math.random(1,4) + local si = math.random(1,3) + local sj = math.random(1,3) + local outi = math.random(4,5) + local outj = math.random(4,5) + local padW = math.min(math.random(0,1),math.floor(ki/2)) + local padH = math.min(math.random(0,1),math.floor(kj/2)) + local ini = (outi-1)*si+ki-2*padW + local inj = (outj-1)*sj+kj-2*padH + + local ceil_string = ceil_mode and 'ceil' or 'floor' + local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end + local input = torch.rand(from,inj,ini) - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state ') - - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ') - -- batch - local nbatch = math.random(2,5) - input = torch.rand(nbatch,from,ini,inj) - module = nn.SpatialMaxPooling(ki,kj,si,sj) + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state (Batch) ') + -- batch + local nbatch = math.random(2,5) + input = torch.rand(nbatch,from,inj,ini) + module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)') + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + end end function nntest.SpatialAveragePooling() |