diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-08 20:34:19 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-12 00:23:00 +0300 |
commit | e86088034ee106250dd18554ba50c415f4216368 (patch) | |
tree | 6efdeb2dc449883532b2e1d8a9acca92fa7c8355 /test | |
parent | cb3eb45848cf445b36f95c13161ff1c7e3e545f5 (diff) |
[cutorch rand2gen] move multinomial to generic
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 99 |
1 files changed, 56 insertions, 43 deletions
diff --git a/test/test.lua b/test/test.lua index f75c628..04dce9b 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2692,16 +2692,19 @@ function test.multinomial_with_replacement() local prob_dist = torch.CudaTensor(n_row, n_col):uniform() prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled local n_sample = torch.random(n_col - 1) - local sample_indices = torch.multinomial(prob_dist, n_sample, true) - tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") - tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") - - for i = 1, n_row do - for j = 1, n_sample do - local val = sample_indices[{i,j}] - tester:assert(val == math.floor(val) and val >= 1 and val < n_col, - "sampled an invalid index: " .. val) - end + for _, typename in ipairs(float_typenames) do + local pd = prob_dist:type(t2cpu[typename]) + local sample_indices = torch.multinomial(pd, n_sample, true) + tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") + tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") + + for i = 1, n_row do + for j = 1, n_sample do + local val = sample_indices[{i,j}] + tester:assert(val == math.floor(val) and val >= 1 and val < n_col, + "sampled an invalid index: " .. val) + end + end end end end @@ -2715,24 +2718,27 @@ function test.multinomial_without_replacement() local prob_dist = torch.CudaTensor(n_row, n_col):uniform() prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled local n_sample = torch.random(n_col - 1) - local sample_indices = torch.multinomial(prob_dist, n_sample, false) - tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") - tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") - - sample_indices = sample_indices:float() - - for i = 1, n_row do - local row_samples = {} - for j = 1, n_sample do - local sample_idx = sample_indices[{i,j}] - tester:assert( - sample_idx ~= n_col, "sampled an index with zero probability" - ) - tester:assert( - not row_samples[sample_idx], "sampled an index twice" - ) - row_samples[sample_idx] = true - end + for _, typename in ipairs(float_typenames) do + local pd = prob_dist:type(t2cpu[typename]) + local sample_indices = torch.multinomial(pd, n_sample, false) + tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") + tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") + + sample_indices = sample_indices:float() + + for i = 1, n_row do + local row_samples = {} + for j = 1, n_sample do + local sample_idx = sample_indices[{i,j}] + tester:assert( + sample_idx ~= n_col, "sampled an index with zero probability" + ) + tester:assert( + not row_samples[sample_idx], "sampled an index twice" + ) + row_samples[sample_idx] = true + end + end end end end @@ -2748,17 +2754,21 @@ function test.multinomial_without_replacement_gets_all() t[dist] = linear end - local orig = t:clone() + local orig = t:clone():long() - -- Sample without replacement - local result = torch.multinomial(t, distSize) - tester:assert(result:size(1) == distributions) - tester:assert(result:size(2) == distSize) + for _, typename in ipairs(float_typenames) do + local x = t:type(t2cpu[typename]) - -- Sort, and we should have the original results, since without replacement - -- sampling everything, we should have chosen every value uniquely - result = result:sort(2) - tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all") + -- Sample without replacement + local result = torch.multinomial(x, distSize) + tester:assert(result:size(1) == distributions) + tester:assert(result:size(2) == distSize) + + -- Sort, and we should have the original results, since without replacement + -- sampling everything, we should have chosen every value uniquely + result = result:sort(2) + tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all") + end end end @@ -2766,12 +2776,15 @@ function test.multinomial_vector() local n_col = torch.random(100) local prob_dist = torch.CudaTensor(n_col):uniform() local n_sample = n_col - local sample_indices = torch.multinomial(prob_dist, n_sample, true) - tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim") - -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize - -- was undone - tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions") - tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples") + for _, typename in ipairs(float_typenames) do + local pd = prob_dist:type(t2cpu[typename]) + local sample_indices = torch.multinomial(pd, n_sample, true) + tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim") + -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize + -- was undone + tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions") + tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples") + end end function test.get_device() |