diff options
author | Koray Kavukcuoglu <koray@kavukcuoglu.org> | 2012-09-26 19:01:45 +0400 |
---|---|---|
committer | Koray Kavukcuoglu <koray@kavukcuoglu.org> | 2012-09-26 19:01:45 +0400 |
commit | 9614cd41480f7d2c1382f33924ad168c32b03828 (patch) | |
tree | cf81f552f2965e20ea7c9cadd34ad030903ece24 /test | |
parent | 4069dc4d9838936701d471fde7417a3223ac7c0e (diff) |
add batch mode to SpatialMaxPooling and openmpize.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua index 583deb2..db3d419 100644 --- a/test/test.lua +++ b/test/test.lua @@ -879,6 +879,19 @@ function nntest.SpatialMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch + local nbatch = math.random(2,5) + input = torch.rand(nbatch,from,ini,inj) + module = nn.SpatialMaxPooling(ki,kj,si,sj) + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state (Batch) ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + end function nntest.SpatialLPPooling() @@ -1226,9 +1239,9 @@ if not nn then mytester:run() else jac = nn.Jacobian - function nn.test() + function nn.test(tests) -- randomize stuff math.randomseed(os.time()) - mytester:run() + mytester:run(tests) end end |