diff options
author | Amartya Sanyal <amartya18x@gmail.com> | 2017-07-11 20:22:36 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-11 20:22:36 +0300 |
commit | e17f93ae8976a611cbda9f619ed129614e1c443d (patch) | |
tree | 5444df5041ecadd587a1463b73057cdfef852001 | |
parent | 01bcef96bd8d3b77004ca9b6f2718e8f223e8600 (diff) |
Implementation of Alias Multinomial for faster Multinomial sampling (#1046)
-rw-r--r-- | TensorMath.lua | 32 | ||||
-rwxr-xr-x | doc/maths.md | 56 | ||||
-rw-r--r-- | init.lua | 16 | ||||
-rw-r--r-- | lib/TH/generic/THTensorRandom.c | 110 | ||||
-rw-r--r-- | lib/TH/generic/THTensorRandom.h | 2 | ||||
-rw-r--r-- | test/test.lua | 23 | ||||
-rw-r--r-- | test/test_aliasMultinomial.lua | 40 |
7 files changed, 269 insertions, 10 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 45e07c6..ad7c51c 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1269,16 +1269,30 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b) wrap("multinomial", cname("multinomial"), {{name="IndexTensor", default=true, returned=true, method={default='nil'}}, - {name='Generator', default=true}, - {name=Tensor}, - {name="int"}, - {name="boolean", default=false}}) - + {name='Generator', default=true}, + {name=Tensor}, + {name="int"}, + {name="boolean", default=false}}) + + wrap("multinomialAliasSetup_", + cname("multinomialAliasSetup"), + {{name=Tensor}, + {name="IndexTensor", default=true, returned=true, method={default='nil'}}, + {name=Tensor, default=true, returned=true, method={default='nil'}}}) + + wrap("multinomialAlias_", + cname("multinomialAliasDraw"), + {{name="IndexTensor", default=true, returned=true, method={default='nil'}}, + {name='Generator', default=true}, + {name="IndexTensor"}, + {name=Tensor} + }) + for _,f in ipairs({{name='uniform', a=0, b=1}, - {name='normal', a=0, b=1}, - {name='cauchy', a=0, b=1}, - {name='logNormal', a=1, b=2}}) do - + {name='normal', a=0, b=1}, + {name='cauchy', a=0, b=1}, + {name='logNormal', a=1, b=2}}) do + wrap(f.name, string.format("THRandom_%s", f.name), {{name='Generator', default=true}, diff --git a/doc/maths.md b/doc/maths.md index eb9f5cf..c7d5ec9 100755 --- a/doc/maths.md +++ b/doc/maths.md @@ -255,6 +255,62 @@ p.multinomial(res, p, n, replacement) -- p.multinomial instead of torch.multinom This is due to the fact that the result here is of a `LongTensor` type, and we do not define a `torch.multinomial` over long `Tensor`s. +<a name="torch.multinomialAlias()"></a> +### [state] torch.multinomialAliasSetup(probs) ### +### [res] torch.multinomialAlias(output, state) +`state = torch.multinomialAliasSetup(probs)` returns a table `state` consisting of two `tensors` : `probability table` and an `alias table`. 
This is required once for each `probs` vector. We can then sample from the multinomial distribution multiple times by consulting this `state` table of tensors. + +`torch.multinomialAlias(output, state)` returns `output` filled with indices drawn from the multinomial distribution `probs`. `output` itself is filled with the indices and it is not necessary to get the return value of the statement. + +The sampling is done through a technique defined in a very simple way in this blog about [The Alias Method](https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/). The paper that describes this technique is present [here](http://www.tandfonline.com/doi/abs/10.1080/00031305.1979.10482697). This can only sample with replacement. + +The `output` `Tensor` that is fed into the `multinomialAlias` method need not be contiguous. The `output` tensor can only be a 1d tensor. If you are required to fill an nd tensor, enter a 1d view of the same tensor. This method is exceptionally faster than `torch.multinomial` when you want to sample a lot of samples from the same distribution or sample from the same distribution a large number of times. `torch.multinomial` is faster for sampling few samples from a distribution once, because the `multinomialAliasSetup` method takes some time in this case. To see and compare how these two methods differ in speed run `th test/test_aliasMultinomial.lua`. + +```lua +th> state = torch.multinomialAliasSetup(probs) +th> state +{ + 1 : LongTensor - size: 4 + 2 : DoubleTensor - size: 4 +} +th> output = torch.LongTensor(2,3) +th> torch.multinomialAlias(output:view(-1), state) + 4 + 1 + 2 + 3 + 2 + 2 +[torch.LongTensor of size 6] +th> output + 4 1 2 + 3 2 2 +[torch.LongTensor of size 2x3] +``` + +You can also allocate memory and reuse it for the state table. 
+ +``` +th> state = {torch.LongTensor(), torch.DoubleTensor()} +th> probs = torch.DoubleTensor({0.2, 0.3, 0.5}) +th> state = torch.multinomialAliasSetup(probs, state) +th> state +{ + 1 : LongTensor - size: 3 + 2 : DoubleTensor - size: 3 +} +th> output = torch.LongTensor(7) +th> torch.multinomialAlias(output, state) + 2 + 2 + 3 + 1 + 2 + 2 + 2 +[torch.LongTensor of size 7] +``` + <a name="torch.ones"></a> ### [res] torch.ones([res,] m [,n...]) ### <a name="torch.ones"></a> @@ -159,7 +159,6 @@ require('torch.FFInterface') require('torch.Tester') require('torch.TestSuite') require('torch.test') - function torch.totable(obj) if torch.isTensor(obj) or torch.isStorage(obj) then return obj:totable() @@ -189,4 +188,19 @@ torch.Tensor.isTensor = torch.isTensor -- remove this line to disable automatic heap-tracking for garbage collection torch.setheaptracking(true) +function torch.multinomialAliasSetup(probs, state) + if torch.type(state) == 'table' then + state[1], state[2] = torch.multinomialAliasSetup_(probs, state[1], state[2]) + else + state = {} + state[1], state[2] = torch.multinomialAliasSetup_(probs) + end + return state +end + +function torch.multinomialAlias(output, state) + torch.DoubleTensor.multinomialAlias_(output, state[1], state[2]) + return output +end + return torch diff --git a/lib/TH/generic/THTensorRandom.c b/lib/TH/generic/THTensorRandom.c index 514d3dd..595cfa7 100644 --- a/lib/TH/generic/THTensorRandom.c +++ b/lib/TH/generic/THTensorRandom.c @@ -70,6 +70,116 @@ void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(_generator, mean, stdv);); } + +void THTensor_(multinomialAliasSetup)(THTensor *probs, THLongTensor *J, THTensor *q) +{ + long inputsize = THTensor_(nElement)(probs); + long i = 0; + THLongTensor *smaller = THLongTensor_newWithSize1d(inputsize); + THLongTensor *larger = THLongTensor_newWithSize1d(inputsize); + long small_c = 0; + long large_c = 0; + 
THLongTensor_resize1d(J, inputsize); + THTensor_(resize1d)(q, inputsize); + real *q_data = THTensor_(data)(q); + long *J_data = THLongTensor_data(J); + + for(i = 0; i < inputsize; i++) + { + THTensor_fastSet1d(J, i, 0L); + real val = THTensor_fastGet1d(probs, i); + THTensor_fastSet1d(q, i, inputsize*val); + + if (inputsize * val < 1.0) + { + THTensor_fastSet1d(smaller, small_c, i); + small_c += 1; + } + else + { + THTensor_fastSet1d(larger, large_c, i); + large_c += 1; + } + } + + // Loop through and create little binary mixtures that + // appropriately allocate the larger outcomes over the + // overall uniform mixture. + long large, small; + while(small_c > 0 && large_c > 0) + { + large = THTensor_fastGet1d(larger, large_c-1); + small = THTensor_fastGet1d(smaller, small_c-1); + + THTensor_fastSet1d(J, small, large); + q_data[large * q->stride[0]] -= 1.0 - THTensor_fastGet1d(q, small); + + if(q_data[large] < 1.0) + { + THTensor_fastSet1d(smaller, small_c-1, large); + large_c -= 1; + } + else + { + THTensor_fastSet1d(larger, large_c-1, large); + small_c -= 1; + } + } + + real q_min = THTensor_fastGet1d(q, inputsize-1); + real q_max = q_min; + real q_temp; + for(i=0; i < inputsize; i++) + { + q_temp = THTensor_fastGet1d(q, i); + if(q_temp < q_min) + q_min = q_temp; + else if(q_temp > q_max) + q_max = q_temp; + } + THArgCheckWithCleanup((q_min > 0), + THCleanup(THLongTensor_free(smaller); THLongTensor_free(larger);), 2, + "q_min is less than 0"); + + if(q_max > 1) + { + for(i=0; i < inputsize; i++) + { + q_data[i*q->stride[0]] /= q_max; + } + } + for(i=0; i<inputsize; i++) + { + // sometimes an large index isn't added to J. + // fix it by making the probability 1 so that J isn't indexed. 
+ if(J_data[i] <= 0) + q_data[i] = 1.0; + } + THLongTensor_free(smaller); + THLongTensor_free(larger); +} +void THTensor_(multinomialAliasDraw)(THLongTensor *self, THGenerator *_generator, THLongTensor *J, THTensor *q) +{ + long K = THLongTensor_nElement(J); + long output_nelem = THLongTensor_nElement(self); + + int i = 0, _mask=0; + real _q; + long rand_ind, sample_idx, J_sample, kk_sample; + for(i=0; i< output_nelem; i++) + { + rand_ind = (long)THRandom_uniform(_generator, 0, K) ; + _q = THTensor_fastGet1d(q, rand_ind); + + _mask = THRandom_bernoulli(_generator, _q); + + J_sample = THTensor_fastGet1d(J, rand_ind); + + sample_idx = J_sample*(1 -_mask) + (rand_ind+1L) * _mask; + + THTensor_fastSet1d(self, i, sample_idx-1L); + } +} void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement) { int start_dim = THTensor_(nDimension)(prob_dist); diff --git a/lib/TH/generic/THTensorRandom.h b/lib/TH/generic/THTensorRandom.h index d205142..e39d589 100644 --- a/lib/TH/generic/THTensorRandom.h +++ b/lib/TH/generic/THTensorRandom.h @@ -15,6 +15,8 @@ TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, doub TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma); TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv); TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement); +TH_API void THTensor_(multinomialAliasSetup)(THTensor *prob_dist, THLongTensor *J, THTensor *q); +TH_API void THTensor_(multinomialAliasDraw)(THLongTensor *self, THGenerator *_generator, THLongTensor *J, THTensor *q); #endif #if defined(TH_REAL_IS_BYTE) diff --git a/test/test.lua b/test/test.lua index 2abf016..7b83b9d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1811,6 +1811,29 @@ function torchtest.multinomialwithoutreplacement() end end 
end +function torchtest.aliasMultinomial() + for i =1,5 do + local n_class = 5 + local t=os.time() + torch.manualSeed(t) + local probs = torch.Tensor(n_class):uniform(0,1) + probs:div(probs:sum()) + local output = torch.LongTensor(1000, 10000) + local n_samples = output:nElement() + local prob_state = torch.multinomialAliasSetup(probs) + mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0") + mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max().." alias indices has an index exceeding num_class") + local prob_state = torch.multinomialAliasSetup(probs, prob_state) + mytester:assert(prob_state[1]:min() > 0, "Index ="..prob_state[1]:min().."alias indices has an index below or equal to 0(cold)") + mytester:assert(prob_state[1]:max() <= n_class, prob_state[1]:max()..","..prob_state[1]:min().." alias indices has an index exceeding num_class(cold)") + local output = torch.LongTensor(n_samples) + output = torch.multinomialAlias(output, prob_state) + mytester:assert(output:nElement() == n_samples, "wrong number of samples") + mytester:assert(output:min() > 0, "sampled indices has an index below or equal to 0") + mytester:assert(output:max() <= n_class, "indices has an index exceeding num_class") + end + +end function torchtest.multinomialvector() local n_col = 4 local t=os.time() diff --git a/test/test_aliasMultinomial.lua b/test/test_aliasMultinomial.lua new file mode 100644 index 0000000..d935e81 --- /dev/null +++ b/test/test_aliasMultinomial.lua @@ -0,0 +1,40 @@ +local tester = torch.Tester() + + +local function aliasMultinomial() + local n_class = 10000 + local probs = torch.Tensor(n_class):uniform(0,1) + probs:div(probs:sum()) + local a = torch.Timer() + local state = torch.multinomialAliasSetup(probs) + print("AliasMultinomial setup in "..a:time().real.." seconds(hot)") + a:reset() + state = torch.multinomialAliasSetup(probs, state) + print("AliasMultinomial setup in "..a:time().real.." 
seconds(cold)") + a:reset() + + tester:assert(state[1]:min() >= 0, "Index ="..state[1]:min().."alias indices has an index below or equal to 0") + tester:assert(state[1]:max() <= n_class, state[1]:max().." alias indices has an index exceeding num_class") + local output = torch.LongTensor(1000000) + torch.multinomialAlias(output, state) + local n_samples = output:nElement() + print("AliasMultinomial draw "..n_samples.." elements from "..n_class.." classes ".."in "..a:time().real.." seconds") + local counts = torch.Tensor(n_class):zero() + mult_output = torch.multinomial(probs, n_samples, true) + print("Multinomial draw "..n_samples.." elements from "..n_class.." classes ".." in "..a:time().real.." seconds") + tester:assert(output:min() > 0, "sampled indices has an index below or equal to 0") + tester:assert(output:max() <= n_class, "indices has an index exceeding num_class") + output:apply(function(x) + counts[x] = counts[x] + 1 + end) + a:reset() + + counts:div(counts:sum()) + + tester:assert(state[1]:min() >= 0, "Index ="..state[1]:min().."alias indices has an index below or equal to 0") + tester:assert(state[1]:max() <= n_class, state[1]:max().." alias indices has an index exceeding num_class") + tester:eq(probs, counts, 0.001, "probs and counts should be approximately equal") +end + +tester:add(aliasMultinomial) +tester:run() |