Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAmartya Sanyal <amartya18x@gmail.com>2017-07-11 20:23:35 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-11 20:23:35 +0300
commit4e085cd2dfb5aba7bb959fd1be9616ee1e35ddf5 (patch)
tree6b9faf7527255f56fe63ae3d4b6283d08ec4803d /test
parent0ea5868b68cb12f5d8d3e61a1237b46711c68a84 (diff)
Alias multinomial sampling in Cuda (#784)
* Support Multinomial Alias sampling in cuda Moving benchmark file * Review changes
Diffstat (limited to 'test')
-rw-r--r--test/benchmarks/multinomial_alias_compare.lua92
-rw-r--r--test/test.lua44
2 files changed, 136 insertions, 0 deletions
diff --git a/test/benchmarks/multinomial_alias_compare.lua b/test/benchmarks/multinomial_alias_compare.lua
new file mode 100644
index 0000000..9386743
--- /dev/null
+++ b/test/benchmarks/multinomial_alias_compare.lua
@@ -0,0 +1,92 @@
+local tester = torch.Tester()
+
+cmd = torch.CmdLine()
+cmd:text()
+cmd:text()
+cmd:text('Testing alias multinomial on cuda')
+cmd:text()
+cmd:text('Options')
+cmd:option('--compare',false,'compare with cutorch multinomial')
+cmd:text()
+
+-- parse input params
+params = cmd:parse(arg)
+
+function aliasMultinomial()
+ local n_class = 10000
+ local n_sample = 100000
+ print("")
+ print("Benchmarking multinomial with "..n_class.." classes and "..n_sample.." samples")
+ torch.seed()
+ local probs = torch.DoubleTensor(n_class):uniform(0,1)
+ probs:div(probs:sum())
+ local a = torch.Timer()
+ local state = torch.multinomialAliasSetup(probs)
+ local cold_time = a:time().real
+ a:reset()
+ local state = torch.multinomialAliasSetup(probs, state)
+ print("[C] torch.aliasMultinomialSetup: "..cold_time.." seconds (cold) and "..a:time().real.." seconds (hot)")
+ a:reset()
+
+ local output = torch.LongTensor(n_sample)
+ torch.multinomialAlias(output, state)
+ print("[C] : torch.aliasMultinomial: "..a:time().real.." seconds (hot)")
+
+ require 'cutorch'
+ a:reset()
+ local cuda_prob = torch.CudaTensor(n_class):copy(probs)
+ cutorch.synchronize()
+ a:reset()
+ local cuda_state
+ for i =1,5 do
+ cuda_state = torch.multinomialAliasSetup(cuda_prob)
+ cutorch.synchronize()
+ end
+ local cold_time = a:time().real/5
+ a:reset()
+ for i = 1,10 do
+ cuda_state = torch.multinomialAliasSetup(cuda_prob, cuda_state)
+ cutorch.synchronize()
+ end
+ print("[CUDA] : torch.aliasMultinomialSetup: "..cold_time.." seconds (cold) and "..(a:time().real/10).." seconds (hot)")
+ tester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
+ tester:assert(output:max() <= n_class, "indices has an index exceeding num_class")
+ local output = torch.CudaLongTensor(n_sample)
+ local mult_output = torch.CudaTensor(n_sample)
+ cutorch.synchronize()
+ if params['compare'] then
+ a:reset()
+ for i = 1,10 do
+ cuda_prob.multinomial(output, cuda_prob, n_sample, true)
+ cutorch.synchronize()
+ end
+ print("[CUDA] : torch.multinomial draw: "..(a:time().real/10).." seconds (hot)")
+ end
+ a:reset()
+ for i = 1,10 do
+ torch.multinomialAlias(output:view(-1), cuda_state)
+ cutorch.synchronize()
+ end
+ print("[CUDA] : torch.multinomialAlias draw: "..(a:time().real/10).." seconds (hot)")
+
+
+ tester:assert(output:min() > 0, "sampled indices has an index below or equal to 0")
+ tester:assert(output:max() <= n_class, "indices has an index exceeding num_class")
+ a:reset()
+ tester:assert(cuda_state[1]:min() >= 0, "alias indices has an index below or equal to 0")
+ tester:assert(cuda_state[1]:max() < n_class, cuda_state[1]:max().." alias indices has an index exceeding num_class")
+ state[1] = torch.CudaLongTensor(state[1]:size()):copy(state[1])
+ state[2] = torch.CudaTensor(state[2]:size()):copy(state[2])
+ tester:eq(cuda_state[1], state[1], 0.1, "Alias table should be equal")
+ tester:eq(cuda_state[2], state[2], 0.1, "Alias prob table should be equal")
+ local counts = torch.Tensor(n_class):zero()
+ output:long():apply(function(x) counts[x] = counts[x] + 1 end)
+ counts:div(counts:sum())
+ tester:eq(probs, counts, 0.001, "probs and counts should be approximately equal")
+
+end
+
+
+
+tester:add(aliasMultinomial)
+tester:run()
diff --git a/test/test.lua b/test/test.lua
index bd78a4f..92b6e51 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3233,6 +3233,50 @@ function test.multinomial_vector()
end
end
+function test.multinomial_alias()
+ for tries = 1, 10 do
+ local n_class = torch.random(100)
+ local prob_dist = torch.CudaTensor(n_class):uniform()
+ local n_sample = torch.random(100)
+ local dim_1 = torch.random(10)
+ for _, typename in ipairs(float_typenames) do
+ if typename ~= 'torch.CudaHalfTensor' then
+ -- Get the probability distribution
+ local pd = prob_dist:type(typename)
+ local state = torch.multinomialAliasSetup(pd)
+
+ -- Checking the validity of the setup tables
+ tester:assert(state[1]:min() >= 0, "alias indices has an index below or equal to 0(cold)")
+ tester:assert(state[1]:max() < n_class, state[1]:max().." alias indices has an index exceeding num_class(cold)")
+
+ --Checking the same things if the memory is already allocated
+ local state = torch.multinomialAliasSetup(pd, state)
+ tester:assert(state[1]:min() >= 0, "alias indices has an index below or equal to 0(hot)")
+ tester:assert(state[1]:max() < n_class, state[1]:max().." alias indices has an index exceeding num_class(hot)")
+
+ --Generating a 1d and a 2d long tensor to be filled with indices
+ local sample_indices = torch.CudaLongTensor(n_sample)
+ local sample_indices_dim2 = torch.CudaLongTensor(n_sample/dim_1, dim_1)
+ local state = {state[1], state[2]:type('torch.CudaTensor')}
+ cutorch.synchronize()
+ torch.multinomialAlias(sample_indices, state)
+ cutorch.synchronize()
+ torch.multinomialAlias(sample_indices_dim2:view(-1), state)
+
+ --Checking the validity of the sampled indices
+ tester:assert(sample_indices_dim2:dim() == 2, "wrong sample_indices dim")
+ tester:assert(sample_indices_dim2:size(2) == dim_1, "wrong number of samples")
+ tester:assert(sample_indices:min() > 0, sample_indices:min().."sampled indices has an index below or equal to 0")
+ tester:assert(sample_indices:max() <= n_class, sample_indices:max().."indices has an index exceeding num_class")
+ tester:assert(sample_indices_dim2:min() > 0, sample_indices_dim2:min().."sampled indices has an index below or equal to 0")
+ tester:assert(sample_indices_dim2:max() <= n_class, sample_indices_dim2:max().."indices has an index exceeding num_class")
+
+ end
+ end
+ end
+end
+
+
function test.get_device()
local device_count = cutorch.getDeviceCount()
local tensors = { }