From 929cfc57c88952b597bec77046582b90d1122380 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 28 Jun 2015 23:51:47 +0200 Subject: SpatialMaxPooling supports padding and ceil mode - changes the way the max indexes are stored, saving memory --- test.lua | 60 +++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 27 deletions(-) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 55818e1..8a29e29 100644 --- a/test.lua +++ b/test.lua @@ -1926,38 +1926,44 @@ function nntest.SpatialSubSampling() end function nntest.SpatialMaxPooling() - local from = math.random(1,5) - local ki = math.random(1,4) - local kj = math.random(1,4) - local si = math.random(1,3) - local sj = math.random(1,3) - local outi = math.random(4,5) - local outj = math.random(4,5) - local ini = (outi-1)*si+ki - local inj = (outj-1)*sj+kj - - local module = nn.SpatialMaxPooling(ki,kj,si,sj) - local input = torch.rand(from,ini,inj) + for _,ceil_mode in pairs({true,false}) do + local from = math.random(1,5) + local ki = math.random(1,4) + local kj = math.random(1,4) + local si = math.random(1,3) + local sj = math.random(1,3) + local outi = math.random(4,5) + local outj = math.random(4,5) + local padW = math.min(math.random(0,1),math.floor(ki/2)) + local padH = math.min(math.random(0,1),math.floor(kj/2)) + local ini = (outi-1)*si+ki-2*padW + local inj = (outj-1)*sj+kj-2*padH + + local ceil_string = ceil_mode and 'ceil' or 'floor' + local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end + local input = torch.rand(from,inj,ini) - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state ') - - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ') - -- batch - local nbatch = math.random(2,5) - input = torch.rand(nbatch,from,ini,inj) - module = nn.SpatialMaxPooling(ki,kj,si,sj) + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state (Batch) ') + -- batch + local nbatch = math.random(2,5) + input = torch.rand(nbatch,from,inj,ini) + module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)') + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + end end function nntest.SpatialAveragePooling() -- cgit v1.2.3