Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Trevor Killeen <killeentm@gmail.com> 2016-11-07 21:28:19 +0300
committer: Trevor Killeen <killeentm@gmail.com> 2016-11-12 00:22:56 +0300
commit: b675b47c0f3548dcca8a8fcaf88e92ef346d3977 (patch)
tree: c68c15ccff5169fea389b1725b6c195928e123cc
parent: a9910be01c19d9c914c7eb0bded002f1cf299f79 (diff)
[cutorch rand2gen] make sampleMultinomialWithReplacement utility function generic
-rw-r--r-- lib/THC/THCTensorRandom.cu | 39
-rw-r--r-- lib/THC/THCTensorRandom.cuh | 42
2 files changed, 42 insertions, 39 deletions
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index f2471ca..f51b8bb 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -364,45 +364,6 @@ void THCudaTensor_renormRows(struct THCState* state,
}
__global__ void
-sampleMultinomialWithReplacement(curandStateMtgp32* state,
- int totalSamples,
- float* dest,
- long distributions,
- int categories,
- float* normDistPrefixSum) {
- // At the moment, each warp computes one sample value in the binary
- // search due to divergence. It seems possible to compute multiple
- // values and limit divergence though later on. However, no matter
- // what, all block threads must participate in the curand_uniform
- // call to update the generator state.
-
- // The block determines the distribution for which we generate a point
- for (long curDist = blockIdx.x;
- curDist < distributions;
- curDist += gridDim.x) {
- for (int sampleBase = 0;
- sampleBase < totalSamples; sampleBase += blockDim.y) {
- // The warp determines the sample
- int sample = sampleBase + threadIdx.y;
-
- // All threads participate in this
- float r = curand_uniform(&state[blockIdx.x]);
-
- if (threadIdx.x == 0 && sample < totalSamples) {
- // Find the bucket that a uniform sample lies in
- int choice = binarySearchForMultinomial<float>(
- normDistPrefixSum + curDist * categories,
- categories,
- r);
-
- // Torch indices are 1-based
- dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
- }
- }
- }
-}
-
-__global__ void
sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
int totalSamples,
int sample,
diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh
index f7387a4..8ce64fb 100644
--- a/lib/THC/THCTensorRandom.cuh
+++ b/lib/THC/THCTensorRandom.cuh
@@ -5,6 +5,8 @@
#include "THCReduceApplyUtils.cuh"
#include "THCTensorMathReduce.cuh"
+#include <curand_kernel.h>
+
// Normalizes the L1 norm of every row to 1; used by multinomial
template <typename T>
__global__ void renormRowsL1(T* dist, long rows, long cols) {
@@ -152,4 +154,44 @@ __device__ int binarySearchForMultinomial(T* dist,
return start;
}
+template <typename T>
+__global__ void
+sampleMultinomialWithReplacement(curandStateMtgp32* state,
+ int totalSamples,
+ T* dest,
+ long distributions,
+ int categories,
+ T* normDistPrefixSum) {
+ // At the moment, each warp computes one sample value in the binary
+ // search due to divergence. It seems possible to compute multiple
+ // values and limit divergence though later on. However, no matter
+ // what, all block threads must participate in the curand_uniform
+ // call to update the generator state.
+
+ // The block determines the distribution for which we generate a point
+ for (long curDist = blockIdx.x;
+ curDist < distributions;
+ curDist += gridDim.x) {
+ for (int sampleBase = 0;
+ sampleBase < totalSamples; sampleBase += blockDim.y) {
+ // The warp determines the sample
+ int sample = sampleBase + threadIdx.y;
+
+ // All threads participate in this
+ float r = curand_uniform(&state[blockIdx.x]);
+
+ if (threadIdx.x == 0 && sample < totalSamples) {
+ // Find the bucket that a uniform sample lies in
+ int choice = binarySearchForMultinomial<T>(
+ normDistPrefixSum + curDist * categories,
+ categories,
+ r);
+
+ // Torch indices are 1-based
+ dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
+ }
+ }
+ }
+}
+
#endif // THC_TENSOR_RANDOM_CUH