diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-11 02:10:12 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-12 00:23:03 +0300 |
commit | f7350de131b71fc5175dd6c07dbdf1b0e0c73486 (patch) | |
tree | ce3451f81a0e6a313c91d4ce96a9a80c3057be44 | |
parent | 6da426b51f3557045d90442c95bf82de6a8caf82 (diff) |
[cutorch rand2gen] fix illegal memory access in multinomial code, update unit tests
-rw-r--r-- | lib/THC/generic/THCTensorRandom.cu | 4 | ||||
-rw-r--r-- | test/test.lua | 10 |
2 files changed, 7 insertions, 7 deletions
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index 28af274..50aa2e8 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -96,7 +96,7 @@ void THCTensor_(renormRows)(struct THCState* state, dim3 block(cols < maxThreads ? cols : maxThreads); renormRowsL1<real> - <<<grid, block, block.x * sizeof(float), + <<<grid, block, block.x * sizeof(real), THCState_getCurrentStream(state)>>>(THCTensor_(data)(state, t), rows, cols); } @@ -164,7 +164,7 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4); sampleMultinomialOnce - <<<grid, block, block.x * sizeof(float), + <<<grid, block, block.x * sizeof(real), THCState_getCurrentStream(state)>>>( THCTensor_(data)(state, self), numDist, diff --git a/test/test.lua b/test/test.lua index 7747c3a..724d5ff 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2700,7 +2700,7 @@ function test.multinomial_with_replacement() prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled local n_sample = torch.random(n_col - 1) for _, typename in ipairs(float_typenames) do - local pd = prob_dist:type(t2cpu[typename]) + local pd = prob_dist:type(typename) local sample_indices = torch.multinomial(pd, n_sample, true) tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") @@ -2726,7 +2726,7 @@ function test.multinomial_without_replacement() prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled local n_sample = torch.random(n_col - 1) for _, typename in ipairs(float_typenames) do - local pd = prob_dist:type(t2cpu[typename]) + local pd = prob_dist:type(typename) local sample_indices = torch.multinomial(pd, n_sample, false) tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim") tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples") @@ -2764,7 +2764,7 @@ function test.multinomial_without_replacement_gets_all() local orig = t:clone():long() for _, typename in ipairs(float_typenames) do - local x = t:type(t2cpu[typename]) + local x = t:type(typename) -- Sample without replacement local result = torch.multinomial(x, distSize) @@ -2774,7 +2774,7 @@ function test.multinomial_without_replacement_gets_all() -- Sort, and we should have the original results, since without replacement -- sampling everything, we should have chosen every value uniquely result = result:sort(2) - tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all") + tester:assertTensorEq(orig:type(typename), result, 0, "error in multinomial_without_replacement_gets_all") end end end @@ -2784,7 +2784,7 @@ function test.multinomial_vector() local prob_dist = torch.CudaTensor(n_col):uniform() local n_sample = n_col for _, typename in ipairs(float_typenames) do - local pd = prob_dist:type(t2cpu[typename]) + local pd = prob_dist:type(typename) local sample_indices = torch.multinomial(pd, n_sample, true) tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim") -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize |