diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-03-03 15:58:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-03-03 15:58:30 +0300 |
commit | b5821697ad079258f0943a907d4172c434974030 (patch) | |
tree | 0d6d3e4e57642a2e0ed668368bd099047237edc0 /test.lua | |
parent | 24a1715cd5095b3b92ec10b5f4764c13c7522ec1 (diff) | |
parent | 5b3d27fa72c7b8731d41c87362ad98b7ebfea245 (diff) |
Merge pull request #176 from mlosch/volmaxpoolbatch
Add batch mode for VolumetricMaxPooling
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 12 |
1 files changed, 12 insertions, 0 deletions
@@ -2096,6 +2096,18 @@ function nntest.VolumetricMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') + + -- batch + local nbatch = math.random(2,3) + module = nn.VolumetricMaxPooling(kt, ki, kj, st, si, sj) + input = torch.Tensor(nbatch, from, int, inj, ini):zero() + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state (Batch) ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') end function nntest.Module_getParameters_1() |