diff options
author | Max Losch <mmlosch@kth.se> | 2015-03-03 14:48:53 +0300 |
---|---|---|
committer | Max Losch <mmlosch@kth.se> | 2015-03-03 14:48:53 +0300 |
commit | 5b3d27fa72c7b8731d41c87362ad98b7ebfea245 (patch) | |
tree | 0d6d3e4e57642a2e0ed668368bd099047237edc0 /test.lua | |
parent | 24a1715cd5095b3b92ec10b5f4764c13c7522ec1 (diff) |
Add batch mode to VolumetricMaxPooling
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 12 |
1 files changed, 12 insertions, 0 deletions
@@ -2096,6 +2096,18 @@ function nntest.VolumetricMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') + + -- batch + local nbatch = math.random(2,3) + module = nn.VolumetricMaxPooling(kt, ki, kj, st, si, sj) + input = torch.Tensor(nbatch, from, int, inj, ini):zero() + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state (Batch) ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') end function nntest.Module_getParameters_1() |