diff options
author | Ruotian Luo <rluo@ttic.edu> | 2017-01-15 00:27:52 +0300 |
---|---|---|
committer | Ruotian Luo <rluo@ttic.edu> | 2017-01-15 00:27:52 +0300 |
commit | 2374627df385b68fd919a573e09517e26fa3c254 (patch) | |
tree | 4b870c74beb5c82b655f2c4b48b58e8c450828ab /test.lua | |
parent | c489620118f335d83086ebcfcac4532f4ed760e2 (diff) |
Add SpatialAdaptiveAveragePooling.
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 56 |
1 files changed, 56 insertions, 0 deletions
@@ -4009,6 +4009,62 @@ function nntest.SpatialAdaptiveMaxPooling() end +function nntest.SpatialAdaptiveAveragePooling() + local from = math.random(1,5) + local ki = math.random(1,5) + local kj = math.random(1,5) + local ini = math.random(1,16) + local inj = math.random(1,16) + + local module = nn.SpatialAdaptiveAveragePooling(ki,kj) + local input = torch.rand(from,ini,inj) + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state ') + + local ferr, berr = jac.testIO(module, input) + mytester:eq(ferr, 0, torch.typename(module) .. ' - i/o forward err ', precision) + mytester:eq(berr, 0, torch.typename(module) .. ' - i/o backward err ', precision) + + -- batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj) + module = nn.SpatialAdaptiveAveragePooling(ki,kj) + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state (Batch) ') + + local ferr, berr = jac.testIO(module, input) + mytester:eq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ', precision) + mytester:eq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ', precision) + + -- non-contiguous + + input = torch.rand(from,ini,inj):transpose(2,3) + module = nn.SpatialAdaptiveAveragePooling(ki,kj) + local inputc = input:contiguous() -- contiguous + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + + -- non-contiguous batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4) + local inputc = input:contiguous() -- contiguous + module = nn.SpatialAdaptiveAveragePooling(ki,kj) + + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + +end + function nntest.SpatialLPPooling() local fanin = math.random(1,4) local osizex = math.random(1,4) |