From 008f4b9e8d16d07dbf4a073115730a3531d43406 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Mon, 6 Mar 2017 08:44:02 -0800 Subject: Improve precision testing 1) In cases where tests were failing with some regularity (SmoothL1, l1cost, mse, SoftShrink-backward) scale error bounds by (absolute) value being tested. 2) Fix some spacing issues in error messages. --- test.lua | 51 +++++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/test.lua b/test.lua index 40b25e4..f8c88f7 100644 --- a/test.lua +++ b/test.lua @@ -129,7 +129,8 @@ local function pointwise_backward(proto_module, name, max_error) local error = rescuda:double() - groundgrad:double() - mytester:assertlt(error:abs():max(), precision_backward_type(max_error, typename), + mytester:assertlt(error:abs():max(), + precision_backward_type(max_error, typename, rescuda:abs():max()), string.format('error on state (backward) with %s', typename)) end end @@ -3588,11 +3589,13 @@ function cunntest.mse() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename), - string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + precision_forward_type(0.03, typename, math.abs(fout)), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() - mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + mytester:assertlt(gerr:abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on gradInput with %s', typename)) end end end @@ -3619,10 +3622,12 @@ function cunntest.SmoothL1() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + math.max(precision_forward_type(precision_forward, typename, math.abs(fout)), 0.01), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -3648,10 +3653,10 @@ function cunntest.SoftMarginCriterion() local cout = cmod:forward(cinput,ctarget) local cgin = cmod:backward(cinput,ctarget) - mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output') + mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output') local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -3680,10 +3685,10 @@ function cunntest.distkldiv() local cgin = cmod:backward(cinput,ctarget) mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_backward_type(precision_backward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end end @@ -4450,11 +4455,13 @@ function cunntest.l1cost() local cout = cmod:forward(cinput) local cgin = cmod:backward(cinput) - mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + mytester:assertlt(math.abs(fout-cout), + precision_forward_type(precision_forward, typename, math.abs(fout)), + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() - mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + mytester:assertlt(gerr:abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on gradInput with %s', typename)) end end @@ -4481,10 +4488,10 @@ function cunntest.ClassNLLCriterionSingleTarget() mytester:assertlt( math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4513,10 +4520,10 @@ function cunntest.ClassNLLCriterionSingleTargetWeights() mytester:assertlt( math.abs(fout-cout), precision_forward_type(precision_forward, typename), - string.format('error on output with %s', typename)) + string.format('error on output with %s', typename)) local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4547,7 +4554,7 @@ function cunntest.ClassNLLCriterionMultipleTarget() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4581,7 +4588,7 @@ function cunntest.SpatialClassNLLCriterion() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end @@ -4615,7 +4622,7 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights() local gerr = cgin:double() - fgin:double() mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename), - string.format('error on gradInput with %s', typename)) + string.format('error on gradInput with %s', typename)) end end -- cgit v1.2.3