Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-14 19:34:04 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-14 19:34:04 +0300
commitedcd5d6413ec7241e20643a6fd7f1f489b2a222c (patch)
treedc7e74c52082eadcbfedde759f19aa31fc11c000
parent31c4148a5cdcee88e393ca34748ee8d4eddade0a (diff)
[cutorch rand2gen] make geometric test less flakey
-rw-r--r--test/test.lua10
1 files changed, 8 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua
index 40aa5c3..33147b1 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2591,14 +2591,20 @@ end
function test.geometric()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
- local p = torch.uniform()
+
+ -- unlike other tests, we pick a large p-value to lower the variance, so
+ -- that its highly unlikely the mean falls outside the bounds of the
+ -- specified tolerance
+ local p = 0.8
+ local tolerance = 0.2
+
local t = torch.CudaTensor(sz1, sz2)
local mean = (1 / p)
for _, typename in ipairs(float_typenames) do
local x = t:type(typename)
x:geometric(p)
- tester:assertalmosteq(x:mean(), mean, 0.2, "mean is wrong")
+ tester:assertalmosteq(x:mean(), mean, tolerance, "mean is wrong")
end
checkMultiDevice(t, 'geometric', p)
end