diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-12-13 22:10:28 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-12-13 22:10:28 +0300 |
commit | 6fdd58c414dea8afeb97b736ed5f4f1a86906df1 (patch) | |
tree | 387107a879505ad0a115a462569a2e2002c5c77f /test | |
parent | e00f7d4c0f70e3583ff0a5359095ad7afcaa7009 (diff) |
Implement bernoulli with element-wise probabilities for all types
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/test/test.lua b/test/test.lua index bce8109..53c9563 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2652,18 +2652,28 @@ function test.bernoulli() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) local p = torch.uniform() + local p_fl = torch.rand(sz1, sz2):cuda() + local p_dbl = torch.rand(sz1, sz2):cudaDouble() local t = torch.CudaTensor(sz1, sz2) for _, typename in ipairs(typenames) do local x = t:type(typename) - x:bernoulli(p) - local mean = x:sum() / (sz1 * sz2) - tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p") - local f = x:float() - tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(), - torch.FloatTensor(sz1, sz2):fill(1), - 1e-6, - "each value must be either 0 or 1") + local expected_mean + for i, p in ipairs({p, p_fl, p_dbl}) do + x:bernoulli(p) + local mean = x:sum() / (sz1 * sz2) + if torch.type(p) == 'number' then + expected_mean = p + else + expected_mean = p:mean() + end + tester:assertalmosteq(mean, expected_mean, 0.1, "mean is not equal to the expected value") + local f = x:float() + tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(), + torch.FloatTensor(sz1, sz2):fill(1), + 1e-6, + "each value must be either 0 or 1") + end end checkMultiDevice(t, 'bernoulli', p) end |