diff options
author | Amartya Sanyal <amartya18x@gmail.com> | 2017-07-11 20:23:35 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-11 20:23:35 +0300 |
commit | 4e085cd2dfb5aba7bb959fd1be9616ee1e35ddf5 (patch) | |
tree | 6b9faf7527255f56fe63ae3d4b6283d08ec4803d /test | |
parent | 0ea5868b68cb12f5d8d3e61a1237b46711c68a84 (diff) |
Alias multinomial sampling in Cuda (#784)
* Support Multinomial Alias sampling in cuda
Moving benchmark file
* Review changes
Diffstat (limited to 'test')
-rw-r--r-- | test/benchmarks/multinomial_alias_compare.lua | 92 | ||||
-rw-r--r-- | test/test.lua | 44 |
2 files changed, 136 insertions, 0 deletions
diff --git a/test/benchmarks/multinomial_alias_compare.lua b/test/benchmarks/multinomial_alias_compare.lua
new file mode 100644
index 0000000..9386743
--- /dev/null
+++ b/test/benchmarks/multinomial_alias_compare.lua
@@ -0,0 +1,92 @@
+local tester = torch.Tester()
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text()
+cmd:text('Testing alias multinomial on cuda')
+cmd:text()
+cmd:text('Options')
+cmd:option('--compare',false,'compare with cutorch multinomial')
+cmd:text()
+
+-- parse input params
+local params = cmd:parse(arg)
+
+-- Benchmark: times alias-table setup and draws on CPU and CUDA, then sanity-checks the results.
+local function aliasMultinomial()
+   local n_class = 10000
+   local n_sample = 100000
+   print("")
+   print("Benchmarking multinomial with "..n_class.." classes and "..n_sample.." samples")
+   torch.seed()
+   local probs = torch.DoubleTensor(n_class):uniform(0,1)
+   probs:div(probs:sum())
+   local a = torch.Timer()
+   local state = torch.multinomialAliasSetup(probs)
+   local cold_time = a:time().real
+   a:reset()
+   state = torch.multinomialAliasSetup(probs, state)
+   print("[C] torch.multinomialAliasSetup: "..cold_time.." seconds (cold) and "..a:time().real.." seconds (hot)")
+   a:reset()
+
+   local output = torch.LongTensor(n_sample)
+   torch.multinomialAlias(output, state)
+   print("[C] : torch.multinomialAlias: "..a:time().real.." seconds (hot)")
+
+   -- CUDA side: every timed step ends with cutorch.synchronize() before the timer is read.
+   require 'cutorch'
+   local cuda_prob = torch.CudaTensor(n_class):copy(probs)
+   cutorch.synchronize()
+   a:reset()
+   local cuda_state
+   for _ = 1, 5 do
+      cuda_state = torch.multinomialAliasSetup(cuda_prob)
+      cutorch.synchronize()
+   end
+   local cuda_cold_time = a:time().real/5
+   a:reset()
+   for _ = 1, 10 do
+      cuda_state = torch.multinomialAliasSetup(cuda_prob, cuda_state)
+      cutorch.synchronize()
+   end
+   print("[CUDA] : torch.multinomialAliasSetup: "..cuda_cold_time.." seconds (cold) and "..(a:time().real/10).." seconds (hot)")
+   tester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
+   tester:assert(output:max() <= n_class, "indices has an index exceeding num_class")
+   local cuda_output = torch.CudaLongTensor(n_sample)
+   cutorch.synchronize()
+   if params['compare'] then
+      a:reset()
+      for _ = 1, 10 do
+         cuda_prob.multinomial(cuda_output, cuda_prob, n_sample, true)
+         cutorch.synchronize()
+      end
+      print("[CUDA] : torch.multinomial draw: "..(a:time().real/10).." seconds (hot)")
+   end
+   a:reset()
+   for _ = 1, 10 do
+      torch.multinomialAlias(cuda_output:view(-1), cuda_state)
+      cutorch.synchronize()
+   end
+   print("[CUDA] : torch.multinomialAlias draw: "..(a:time().real/10).." seconds (hot)")
+
+   -- Draws must lie in [1, n_class]; alias-table entries must lie in [0, n_class - 1].
+   tester:assert(cuda_output:min() > 0, "sampled indices has an index below or equal to 0")
+   tester:assert(cuda_output:max() <= n_class, "indices has an index exceeding num_class")
+   tester:assert(cuda_state[1]:min() >= 0, "alias indices has an index below 0")
+   tester:assert(cuda_state[1]:max() < n_class, cuda_state[1]:max().." alias indices has an index exceeding num_class")
+   -- Copy the CPU-built tables to CUDA tensors so they can be compared with the CUDA-built ones.
+   state[1] = torch.CudaLongTensor(state[1]:size()):copy(state[1])
+   state[2] = torch.CudaTensor(state[2]:size()):copy(state[2])
+   tester:eq(cuda_state[1], state[1], 0.1, "Alias table should be equal")
+   tester:eq(cuda_state[2], state[2], 0.1, "Alias prob table should be equal")
+   local counts = torch.Tensor(n_class):zero()
+   cuda_output:long():apply(function(x) counts[x] = counts[x] + 1 end)
+   counts:div(counts:sum())
+   tester:eq(probs, counts, 0.001, "probs and counts should be approximately equal")
+
+end
+
+
+
+tester:add(aliasMultinomial)
+tester:run()
diff --git a/test/test.lua b/test/test.lua
index bd78a4f..92b6e51 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3233,6 +3233,50 @@ function test.multinomial_vector()
    end
 end
 
+-- Alias multinomial: validates setup tables (fresh and reused) and sampled index ranges per float type.
+function test.multinomial_alias()
+   for tries = 1, 10 do
+      local n_class = torch.random(100)
+      local prob_dist = torch.CudaTensor(n_class):uniform()
+      local n_sample = torch.random(100)
+      local dim_1 = torch.random(10)
+      for _, typename in ipairs(float_typenames) do
+         if typename ~= 'torch.CudaHalfTensor' then
+            -- Get the probability distribution
+            local pd = prob_dist:type(typename)
+            local state = torch.multinomialAliasSetup(pd)
+
+            -- Checking the validity of the setup tables
+            tester:assert(state[1]:min() >= 0, "alias indices has an index below 0(cold)")
+            tester:assert(state[1]:max() < n_class, state[1]:max().." alias indices has an index exceeding num_class(cold)")
+
+            --Checking the same things if the memory is already allocated
+            state = torch.multinomialAliasSetup(pd, state)
+            tester:assert(state[1]:min() >= 0, "alias indices has an index below 0(hot)")
+            tester:assert(state[1]:max() < n_class, state[1]:max().." alias indices has an index exceeding num_class(hot)")
+
+            --Generating a 1d and a 2d long tensor to be filled with indices
+            local sample_indices = torch.CudaLongTensor(n_sample)
+            local n_rows = math.max(1, math.floor(n_sample / dim_1)) -- whole row count, never zero
+            local sample_indices_dim2 = torch.CudaLongTensor(n_rows, dim_1)
+            local cuda_state = {state[1], state[2]:type('torch.CudaTensor')}
+            cutorch.synchronize()
+            torch.multinomialAlias(sample_indices, cuda_state)
+            cutorch.synchronize()
+            torch.multinomialAlias(sample_indices_dim2:view(-1), cuda_state)
+
+            --Checking the validity of the sampled indices
+            tester:assert(sample_indices_dim2:dim() == 2, "wrong sample_indices dim")
+            tester:assert(sample_indices_dim2:size(2) == dim_1, "wrong number of samples")
+            tester:assert(sample_indices:min() > 0, sample_indices:min().." sampled indices has an index below or equal to 0")
+            tester:assert(sample_indices:max() <= n_class, sample_indices:max().." indices has an index exceeding num_class")
+            tester:assert(sample_indices_dim2:min() > 0, sample_indices_dim2:min().." sampled indices has an index below or equal to 0")
+            tester:assert(sample_indices_dim2:max() <= n_class, sample_indices_dim2:max().." indices has an index exceeding num_class")
+         end
+      end
+   end
+end
+
 function test.get_device()
    local device_count = cutorch.getDeviceCount()
    local tensors = { }