diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 83 |
1 files changed, 83 insertions, 0 deletions
@@ -2,6 +2,8 @@ -- th -lnn -e "nn.test{'LookupTable'}" -- th -lnn -e "nn.test{'LookupTable', 'Add'}" +nn.hessian.enable() + local mytester = torch.Tester() local jac local sjac @@ -484,6 +486,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diagHessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -513,6 +524,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1384,6 +1404,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1424,6 +1453,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1555,6 +1593,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) mytester:assertlt(err , precision, 'error on bias ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1593,6 +1640,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1637,6 +1693,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1669,6 +1734,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1714,6 +1788,15 @@ function nntest.SpatialFullConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) |