Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-21 22:07:44 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-21 22:07:44 +0300
commitb29c6fd53bbad1935a321a3ffb5a8eb26832cdbf (patch)
tree0844e2063391aef8386d92b7301d11b4dca208d5 /test.lua
parent81348688b7089c733f88f7c43e875e012db6cdbe (diff)
parenta53cb3cce78cb218eee7fe2f3c67a8d0fc411dcc (diff)
Merge pull request #306 from erosennin/fix-hessian
Add unit tests for hessian.lua, fix bugs detected by the tests
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua83
1 files changed, 83 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index ed774bd..3f37dac 100644
--- a/test.lua
+++ b/test.lua
@@ -2,6 +2,8 @@
-- th -lnn -e "nn.test{'LookupTable'}"
-- th -lnn -e "nn.test{'LookupTable', 'Add'}"
+nn.hessian.enable()
+
local mytester = torch.Tester()
local jac
local sjac
@@ -484,6 +486,15 @@ function nntest.Linear()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err,precision, 'error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -513,6 +524,15 @@ function nntest.Linear()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err,precision, 'error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1384,6 +1404,15 @@ function nntest.SpatialConvolution()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1424,6 +1453,15 @@ function nntest.SpatialConvolution()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1555,6 +1593,15 @@ function nntest.SpatialConvolutionMap()
local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
mytester:assertlt(err , precision, 'error on bias ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1593,6 +1640,15 @@ function nntest.SpatialConvolutionMap()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1637,6 +1693,15 @@ function nntest.SpatialFullConvolution()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1669,6 +1734,15 @@ function nntest.SpatialFullConvolution()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))
@@ -1714,6 +1788,15 @@ function nntest.SpatialFullConvolutionMap()
local err = jac.testJacobianUpdateParameters(module, input, module.bias)
mytester:assertlt(err , precision, 'error on bias [direct update] ')
+ local err = jac.testDiagHessianInput(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianInput')
+
+ local err = jac.testDiagHessianWeight(module, input)
+ mytester:assertlt(err , precision, 'error on diagHessianWeight')
+
+ local err = jac.testDiagHessianBias(module, input)
+ mytester:assertlt(err , precision, 'error on diag HessianBias')
+
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
'error on weight [%s]', t))