diff options
author | soumith <soumith@fb.com> | 2015-01-19 11:01:49 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-01-19 11:01:49 +0300 |
commit | abee8f6de6294f2aafc3af128fe85009a1264076 (patch) | |
tree | 3d2bd1225225f9c42043ce0fc278102b4482f552 /test.lua | |
parent | a9768c444f4a455fc8a7654981350c9ecc6e7aee (diff) | |
parent | 98cf6f146be4e032396daa0da12d5f181537d6a9 (diff) |
Merge https://github.com/fmassa/nn
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 31 |
1 files changed, 31 insertions, 0 deletions
@@ -1678,6 +1678,37 @@ function nntest.SpatialAveragePooling() mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ') end +function nntest.SpatialAdaptiveMaxPooling() + local from = math.random(1,5) + local ki = math.random(1,12) + local kj = math.random(1,12) + local ini = math.random(1,64) + local inj = math.random(1,64) + + local module = nn.SpatialAdaptiveMaxPooling(ki,kj) + local input = torch.rand(from,ini,inj) + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch + local nbatch = math.random(2,5) + input = torch.rand(nbatch,from,ini,inj) + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state (Batch) ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + +end + function nntest.SpatialLPPooling() local fanin = math.random(1,4) local osizex = math.random(1,4) |