Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAmartya Sanyal <amartya18x@gmail.com>2017-07-11 20:22:36 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-11 20:22:36 +0300
commite17f93ae8976a611cbda9f619ed129614e1c443d (patch)
tree5444df5041ecadd587a1463b73057cdfef852001 /test/test.lua
parent01bcef96bd8d3b77004ca9b6f2718e8f223e8600 (diff)
Implementation of Alias Multinomial for faster Multinomial sampling (#1046)
Diffstat (limited to 'test/test.lua')
-rw-r--r--test/test.lua23
1 files changed, 23 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 2abf016..7b83b9d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1811,6 +1811,29 @@ function torchtest.multinomialwithoutreplacement()
end
end
end
+function torchtest.aliasMultinomial()
+ for i =1,5 do
+ local n_class = 5
+ local t=os.time()
+ torch.manualSeed(t)
+ local probs = torch.Tensor(n_class):uniform(0,1)
+ probs:div(probs:sum())
+ local output = torch.LongTensor(1000, 10000)
+ local n_samples = output:nElement()
+ local prob_state = torch.multinomialAliasSetup(probs)
+ mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0")
+ mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max().." alias indices has an index exceeding num_class")
+ local prob_state = torch.multinomialAliasSetup(probs, prob_state)
+ mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0(cold)")
+ mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max()..","..prob_state[1]:min().." alias indices has an index exceeding num_class(cold)")
+ local output = torch.LongTensor(n_samples)
+ output = torch.multinomialAlias(output, prob_state)
+ mytester:assert(output:nElement() == n_samples, "wrong number of samples")
+ mytester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
+ mytester:assert(output:max() <= n_class, "indices has an index exceeding num_class")
+ end
+
+end
function torchtest.multinomialvector()
local n_col = 4
local t=os.time()