diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-07 21:28:19 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-12 00:22:56 +0300 |
commit | b675b47c0f3548dcca8a8fcaf88e92ef346d3977 (patch) | |
tree | c68c15ccff5169fea389b1725b6c195928e123cc | |
parent | a9910be01c19d9c914c7eb0bded002f1cf299f79 (diff) |
[cutorch rand2gen] make sampleMultinomialWithReplacement utility function generic
-rw-r--r-- | lib/THC/THCTensorRandom.cu | 39 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cuh | 42 |
2 files changed, 42 insertions, 39 deletions
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu index f2471ca..f51b8bb 100644 --- a/lib/THC/THCTensorRandom.cu +++ b/lib/THC/THCTensorRandom.cu @@ -364,45 +364,6 @@ void THCudaTensor_renormRows(struct THCState* state, } __global__ void -sampleMultinomialWithReplacement(curandStateMtgp32* state, - int totalSamples, - float* dest, - long distributions, - int categories, - float* normDistPrefixSum) { - // At the moment, each warp computes one sample value in the binary - // search due to divergence. It seems possible to compute multiple - // values and limit divergence though later on. However, no matter - // what, all block threads must participate in the curand_uniform - // call to update the generator state. - - // The block determines the distribution for which we generate a point - for (long curDist = blockIdx.x; - curDist < distributions; - curDist += gridDim.x) { - for (int sampleBase = 0; - sampleBase < totalSamples; sampleBase += blockDim.y) { - // The warp determines the sample - int sample = sampleBase + threadIdx.y; - - // All threads participate in this - float r = curand_uniform(&state[blockIdx.x]); - - if (threadIdx.x == 0 && sample < totalSamples) { - // Find the bucket that a uniform sample lies in - int choice = binarySearchForMultinomial<float>( - normDistPrefixSum + curDist * categories, - categories, - r); - - // Torch indices are 1-based - dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE; - } - } - } -} - -__global__ void sampleMultinomialWithoutReplacement(curandStateMtgp32* state, int totalSamples, int sample, diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh index f7387a4..8ce64fb 100644 --- a/lib/THC/THCTensorRandom.cuh +++ b/lib/THC/THCTensorRandom.cuh @@ -5,6 +5,8 @@ #include "THCReduceApplyUtils.cuh" #include "THCTensorMathReduce.cuh" +#include <curand_kernel.h> + // Normalizes the L1 norm of every row to 1; used by multinomial template <typename T> __global__ void renormRowsL1(T* dist, long rows, long cols) { @@ -152,4 +154,44 @@ __device__ int binarySearchForMultinomial(T* dist, return start; } +template <typename T> +__global__ void +sampleMultinomialWithReplacement(curandStateMtgp32* state, + int totalSamples, + T* dest, + long distributions, + int categories, + T* normDistPrefixSum) { + // At the moment, each warp computes one sample value in the binary + // search due to divergence. It seems possible to compute multiple + // values and limit divergence though later on. However, no matter + // what, all block threads must participate in the curand_uniform + // call to update the generator state. + + // The block determines the distribution for which we generate a point + for (long curDist = blockIdx.x; + curDist < distributions; + curDist += gridDim.x) { + for (int sampleBase = 0; + sampleBase < totalSamples; sampleBase += blockDim.y) { + // The warp determines the sample + int sample = sampleBase + threadIdx.y; + + // All threads participate in this + float r = curand_uniform(&state[blockIdx.x]); + + if (threadIdx.x == 0 && sample < totalSamples) { + // Find the bucket that a uniform sample lies in + int choice = binarySearchForMultinomial<T>( + normDistPrefixSum + curDist * categories, + categories, + r); + + // Torch indices are 1-based + dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE); + } + } + } +} + #endif // THC_TENSOR_RANDOM_CUH |