From a53cb3cce78cb218eee7fe2f3c67a8d0fc411dcc Mon Sep 17 00:00:00 2001 From: Andrey Golovizin Date: Wed, 1 Jul 2015 10:27:41 +0200 Subject: Add unit tests for hessian.lua, fix bugs detected by the tests. * Fix initialization of diagHessianBias for nn.SpatialConvolution. * Fix computing diagHessianBias for nn.SpatialFullConvolution. * Call module:forward() with the proper input before calling accGradParameters(). Without that, accDiagHessianParameters() produces incorrect results for some convolution classes. * Move duplicate code from Module.getParameters() to Module.flatten(), which is now used by both the original Module.getParameters() in Module.lua and the replacement Module.getParameters() in hessian.lua. --- test.lua | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 55818e1..f8437dc 100644 --- a/test.lua +++ b/test.lua @@ -2,6 +2,8 @@ -- th -lnn -e "nn.test{'LookupTable'}" -- th -lnn -e "nn.test{'LookupTable', 'Add'}" +nn.hessian.enable() + local mytester = torch.Tester() local jac local sjac @@ -484,6 +486,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diagHessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -513,6 +524,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1384,6 +1404,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1424,6 +1453,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1555,6 +1593,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) mytester:assertlt(err , precision, 'error on bias ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1593,6 +1640,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1637,6 +1693,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1669,6 +1734,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1714,6 +1788,15 @@ function nntest.SpatialFullConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) -- cgit v1.2.3