diff options
author | Aysegul Dundar <adundar@purdue.edu> | 2014-10-30 18:12:59 +0300 |
---|---|---|
committer | Aysegul Dundar <adundar@purdue.edu> | 2014-11-10 21:11:58 +0300 |
commit | d057c86aaa21e11816d7d2ff6cf09dc76551bd51 (patch) | |
tree | 08c94beb9c9d13a8f9bcf2641d63477e4cbe9719 /test | |
parent | 704684a27efd82da3f4ac05cc9ecb6f44aa6d510 (diff) |
Batchmode is added to SpatialConvmap
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 0ae0e61..c7416ea 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1003,6 +1003,44 @@ function nntest.SpatialConvolutionMap() local ferr, berr = jac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') + + + + -- batch + + --verbose = true + local batch = math.random(2,6) + module = nn.SpatialConvolutionMap(nn.tables.random(from, to, fanin), ki, kj, si, sj) + input = torch.Tensor(batch,from,inj,ini):zero() + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'batch error on state ') + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err , precision, 'batch error on weight ') + + local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) + mytester:assertlt(err , precision, 'batch error on bias ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err , precision, 'batch error on weight [direct update] ') + + local err = jac.testJacobianUpdateParameters(module, input, module.bias) + mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'batch error on bias [%s]', t)) + end + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') end |