Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-03-15 21:36:10 +0300
committerGitHub <noreply@github.com>2017-03-15 21:36:10 +0300
commitb610a353c19baba213c2fdd6f0cc3a9dcc87a466 (patch)
treea9f755e48af060842cb1527a2187070827a553ef
parent7293be03f00e902179fcb6518b7c43bca2f5790c (diff)
parent008f4b9e8d16d07dbf4a073115730a3531d43406 (diff)
Merge pull request #449 from gchanan/precision_testing
Improve precision testing
-rw-r--r--test.lua51
1 files changed, 29 insertions, 22 deletions
diff --git a/test.lua b/test.lua
index 40b25e4..f8c88f7 100644
--- a/test.lua
+++ b/test.lua
@@ -129,7 +129,8 @@ local function pointwise_backward(proto_module, name, max_error)
local error = rescuda:double() - groundgrad:double()
- mytester:assertlt(error:abs():max(), precision_backward_type(max_error, typename),
+ mytester:assertlt(error:abs():max(),
+ precision_backward_type(max_error, typename, rescuda:abs():max()),
string.format('error on state (backward) with %s', typename))
end
end
@@ -3588,11 +3589,13 @@ function cunntest.mse()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename),
- string.format('error on output with %s', typename))
+ mytester:assertlt(math.abs(fout-cout),
+ precision_forward_type(0.03, typename, math.abs(fout)),
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
- mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ mytester:assertlt(gerr:abs():max(),
+ precision_forward_type(precision_forward, typename),
+ string.format('error on gradInput with %s', typename))
end
end
end
@@ -3619,10 +3622,12 @@ function cunntest.SmoothL1()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename))
+ mytester:assertlt(math.abs(fout-cout),
+ math.max(precision_forward_type(precision_forward, typename, math.abs(fout)), 0.01),
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
end
@@ -3648,10 +3653,10 @@ function cunntest.SoftMarginCriterion()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output')
+ mytester:assertlt(math.abs(fout-cout), 0.01, 'error on output')
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
end
@@ -3680,10 +3685,10 @@ function cunntest.distkldiv()
local cgin = cmod:backward(cinput,ctarget)
mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
- string.format('error on output with %s', typename))
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_backward_type(precision_backward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
end
@@ -4450,11 +4455,13 @@ function cunntest.l1cost()
local cout = cmod:forward(cinput)
local cgin = cmod:backward(cinput)
- mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
- string.format('error on output with %s', typename))
+ mytester:assertlt(math.abs(fout-cout),
+ precision_forward_type(precision_forward, typename, math.abs(fout)),
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
- mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ mytester:assertlt(gerr:abs():max(),
+ precision_forward_type(precision_forward, typename),
+ string.format('error on gradInput with %s', typename))
end
end
@@ -4481,10 +4488,10 @@ function cunntest.ClassNLLCriterionSingleTarget()
mytester:assertlt(
math.abs(fout-cout), precision_forward_type(precision_forward, typename),
- string.format('error on output with %s', typename))
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
@@ -4513,10 +4520,10 @@ function cunntest.ClassNLLCriterionSingleTargetWeights()
mytester:assertlt(
math.abs(fout-cout), precision_forward_type(precision_forward, typename),
- string.format('error on output with %s', typename))
+ string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
@@ -4547,7 +4554,7 @@ function cunntest.ClassNLLCriterionMultipleTarget()
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
@@ -4581,7 +4588,7 @@ function cunntest.SpatialClassNLLCriterion()
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end
@@ -4615,7 +4622,7 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights()
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on gradInput with %s', typename))
+ string.format('error on gradInput with %s', typename))
end
end