Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-08 20:34:19 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-12 00:23:00 +0300
commite86088034ee106250dd18554ba50c415f4216368 (patch)
tree6efdeb2dc449883532b2e1d8a9acca92fa7c8355 /test
parentcb3eb45848cf445b36f95c13161ff1c7e3e545f5 (diff)
[cutorch rand2gen] move multinomial to generic
Diffstat (limited to 'test')
-rw-r--r--test/test.lua99
1 files changed, 56 insertions, 43 deletions
diff --git a/test/test.lua b/test/test.lua
index f75c628..04dce9b 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2692,16 +2692,19 @@ function test.multinomial_with_replacement()
local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
local n_sample = torch.random(n_col - 1)
- local sample_indices = torch.multinomial(prob_dist, n_sample, true)
- tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
- tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
- for i = 1, n_row do
- for j = 1, n_sample do
- local val = sample_indices[{i,j}]
- tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
- "sampled an invalid index: " .. val)
- end
+ for _, typename in ipairs(float_typenames) do
+ local pd = prob_dist:type(t2cpu[typename])
+ local sample_indices = torch.multinomial(pd, n_sample, true)
+ tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+ tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+ for i = 1, n_row do
+ for j = 1, n_sample do
+ local val = sample_indices[{i,j}]
+ tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
+ "sampled an invalid index: " .. val)
+ end
+ end
end
end
end
@@ -2715,24 +2718,27 @@ function test.multinomial_without_replacement()
local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
local n_sample = torch.random(n_col - 1)
- local sample_indices = torch.multinomial(prob_dist, n_sample, false)
- tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
- tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
- sample_indices = sample_indices:float()
-
- for i = 1, n_row do
- local row_samples = {}
- for j = 1, n_sample do
- local sample_idx = sample_indices[{i,j}]
- tester:assert(
- sample_idx ~= n_col, "sampled an index with zero probability"
- )
- tester:assert(
- not row_samples[sample_idx], "sampled an index twice"
- )
- row_samples[sample_idx] = true
- end
+ for _, typename in ipairs(float_typenames) do
+ local pd = prob_dist:type(t2cpu[typename])
+ local sample_indices = torch.multinomial(pd, n_sample, false)
+ tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+ tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+ sample_indices = sample_indices:float()
+
+ for i = 1, n_row do
+ local row_samples = {}
+ for j = 1, n_sample do
+ local sample_idx = sample_indices[{i,j}]
+ tester:assert(
+ sample_idx ~= n_col, "sampled an index with zero probability"
+ )
+ tester:assert(
+ not row_samples[sample_idx], "sampled an index twice"
+ )
+ row_samples[sample_idx] = true
+ end
+ end
end
end
end
@@ -2748,17 +2754,21 @@ function test.multinomial_without_replacement_gets_all()
t[dist] = linear
end
- local orig = t:clone()
+ local orig = t:clone():long()
- -- Sample without replacement
- local result = torch.multinomial(t, distSize)
- tester:assert(result:size(1) == distributions)
- tester:assert(result:size(2) == distSize)
+ for _, typename in ipairs(float_typenames) do
+ local x = t:type(t2cpu[typename])
- -- Sort, and we should have the original results, since without replacement
- -- sampling everything, we should have chosen every value uniquely
- result = result:sort(2)
- tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all")
+ -- Sample without replacement
+ local result = torch.multinomial(x, distSize)
+ tester:assert(result:size(1) == distributions)
+ tester:assert(result:size(2) == distSize)
+
+ -- Sort, and we should have the original results, since without replacement
+ -- sampling everything, we should have chosen every value uniquely
+ result = result:sort(2)
+ tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all")
+ end
end
end
@@ -2766,12 +2776,15 @@ function test.multinomial_vector()
local n_col = torch.random(100)
local prob_dist = torch.CudaTensor(n_col):uniform()
local n_sample = n_col
- local sample_indices = torch.multinomial(prob_dist, n_sample, true)
- tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
- -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
- -- was undone
- tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
- tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
+ for _, typename in ipairs(float_typenames) do
+ local pd = prob_dist:type(t2cpu[typename])
+ local sample_indices = torch.multinomial(pd, n_sample, true)
+ tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
+ -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
+ -- was undone
+ tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
+ tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
+ end
end
function test.get_device()