diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-15 21:36:10 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-15 21:36:10 +0300 |
commit | b610a353c19baba213c2fdd6f0cc3a9dcc87a466 (patch) | |
tree | a9f755e48af060842cb1527a2187070827a553ef | |
parent | 7293be03f00e902179fcb6518b7c43bca2f5790c (diff) | |
parent | 008f4b9e8d16d07dbf4a073115730a3531d43406 (diff) |
Merge pull request #449 from gchanan/precision_testing
Improve precision testing
-rw-r--r-- | test.lua | 51 |
1 files changed, 29 insertions, 22 deletions
@@ -129,7 +129,8 @@ local function pointwise_backward(proto_module, name, max_error) local error = rescuda:double() - groundgrad:double() - mytester:assertlt(error:abs():max(), precision_backward_type(max_error, typename), + mytester:assertlt(error:abs():max(), + precision_backward_type(max_error, typename, rescuda:abs():max()), string.format('error on state (backward) with %s', typename)) end end @@ -3588,11 +3589,13 @@ function cunntest.mse() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename), - string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + precision_forward_type(0.03, typename, math.abs(fout)), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() - mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + mytester:assertlt(gerr:abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on gradInput with %s', typename)) end end end @@ -3619,10 +3622,12 @@ function cunntest.SmoothL1() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + math.max(precision_forward_type(precision_forward, typename, math.abs(fout)), 0.01), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -3648,10 +3653,10 @@ function cunntest.SoftMarginCriterion() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output') + mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output') local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -3680,10 +3685,10 @@ function cunntest.distkldiv() local cgin = cmod:backward(cinput,ctarget) mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_backward_type(precision_backward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -4450,11 +4455,13 @@ function cunntest.l1cost() local cout = cmod:forward(cinput) local cgin = cmod:backward(cinput) - mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + precision_forward_type(precision_forward, typename, math.abs(fout)), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() - mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + mytester:assertlt(gerr:abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on gradInput with %s', typename)) end end @@ -4481,10 +4488,10 @@ function cunntest.ClassNLLCriterionSingleTarget() mytester:assertlt( math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4513,10 +4520,10 @@ function cunntest.ClassNLLCriterionSingleTargetWeights() mytester:assertlt( math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4547,7 +4554,7 @@ function cunntest.ClassNLLCriterionMultipleTarget() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4581,7 +4588,7 @@ function cunntest.SpatialClassNLLCriterion() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4615,7 +4622,7 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end |