Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-12-13 22:10:28 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-12-13 22:10:28 +0300
commit6fdd58c414dea8afeb97b736ed5f4f1a86906df1 (patch)
tree387107a879505ad0a115a462569a2e2002c5c77f /test
parente00f7d4c0f70e3583ff0a5359095ad7afcaa7009 (diff)
Implement bernoulli with element-wise probabilities for all types
Diffstat (limited to 'test')
-rw-r--r--test/test.lua26
1 files changed, 18 insertions, 8 deletions
diff --git a/test/test.lua b/test/test.lua
index bce8109..53c9563 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2652,18 +2652,28 @@ function test.bernoulli()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
local p = torch.uniform()
+ local p_fl = torch.rand(sz1, sz2):cuda()
+ local p_dbl = torch.rand(sz1, sz2):cudaDouble()
local t = torch.CudaTensor(sz1, sz2)
for _, typename in ipairs(typenames) do
local x = t:type(typename)
- x:bernoulli(p)
- local mean = x:sum() / (sz1 * sz2)
- tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p")
- local f = x:float()
- tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
- torch.FloatTensor(sz1, sz2):fill(1),
- 1e-6,
- "each value must be either 0 or 1")
+ local expected_mean
+ for i, p in ipairs({p, p_fl, p_dbl}) do
+ x:bernoulli(p)
+ local mean = x:sum() / (sz1 * sz2)
+ if torch.type(p) == 'number' then
+ expected_mean = p
+ else
+ expected_mean = p:mean()
+ end
+ tester:assertalmosteq(mean, expected_mean, 0.1, "mean is not equal to the expected value")
+ local f = x:float()
+ tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
+ torch.FloatTensor(sz1, sz2):fill(1),
+ 1e-6,
+ "each value must be either 0 or 1")
+ end
end
checkMultiDevice(t, 'bernoulli', p)
end