diff options
author | Amartya Sanyal <amartya18x@gmail.com> | 2017-07-11 20:22:36 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-11 20:22:36 +0300 |
commit | e17f93ae8976a611cbda9f619ed129614e1c443d (patch) | |
tree | 5444df5041ecadd587a1463b73057cdfef852001 /test/test.lua | |
parent | 01bcef96bd8d3b77004ca9b6f2718e8f223e8600 (diff) |
Implementation of Alias Multinomial for faster Multinomial sampling (#1046)
Diffstat (limited to 'test/test.lua')
-rw-r--r-- | test/test.lua | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 2abf016..7b83b9d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1811,6 +1811,29 @@ function torchtest.multinomialwithoutreplacement() end end end +function torchtest.aliasMultinomial() + for i =1,5 do + local n_class = 5 + local t=os.time() + torch.manualSeed(t) + local probs = torch.Tensor(n_class):uniform(0,1) + probs:div(probs:sum()) + local output = torch.LongTensor(1000, 10000) + local n_samples = output:nElement() + local prob_state = torch.multinomialAliasSetup(probs) + mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0") + mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max().." alias indices has an index exceeding num_class") + local prob_state = torch.multinomialAliasSetup(probs, prob_state) + mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0(cold)") + mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max()..","..prob_state[1]:min().." alias indices has an index exceeding num_class(cold)") + local output = torch.LongTensor(n_samples) + output = torch.multinomialAlias(output, prob_state) + mytester:assert(output:nElement() == n_samples, "wrong number of samples") + mytester:assert(output:min() > 0, "sampled indices has an index below or equal to 0") + mytester:assert(output:max() <= n_class, "indices has an index exceeding num_class") + end + +end function torchtest.multinomialvector() local n_col = 4 local t=os.time() |