diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-07-21 22:07:44 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-07-21 22:07:44 +0300 |
commit | b29c6fd53bbad1935a321a3ffb5a8eb26832cdbf (patch) | |
tree | 0844e2063391aef8386d92b7301d11b4dca208d5 /test.lua | |
parent | 81348688b7089c733f88f7c43e875e012db6cdbe (diff) | |
parent | a53cb3cce78cb218eee7fe2f3c67a8d0fc411dcc (diff) |
Merge pull request #306 from erosennin/fix-hessian
Add unit tests for hessian.lua, fix bugs detected by the tests
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 83 |
1 files changed, 83 insertions, 0 deletions
@@ -2,6 +2,8 @@ -- th -lnn -e "nn.test{'LookupTable'}" -- th -lnn -e "nn.test{'LookupTable', 'Add'}" +nn.hessian.enable() + local mytester = torch.Tester() local jac local sjac @@ -484,6 +486,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diagHessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -513,6 +524,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1384,6 +1404,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1424,6 +1453,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1555,6 +1593,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) mytester:assertlt(err , precision, 'error on bias ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1593,6 +1640,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1637,6 +1693,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1669,6 +1734,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1714,6 +1788,15 @@ function nntest.SpatialFullConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) |