github.com/torch/cutorch.git
author    Soumith Chintala <soumith@gmail.com>  2016-11-12 02:15:01 +0300
committer GitHub <noreply@github.com>  2016-11-12 02:15:01 +0300
commit    bdf2b06c9d7e1e53e78f4ac78f6f5dfa4f1b9020 (patch)
tree      3f44c6cfba108883738b632637ec836a751a35ff
parent    71df98b3cbe84ce943e492732f036dd5526ed329 (diff)
Revert "Move random functions to generic" (branch: revert-589-random-refactor)
 TensorMath.lua                     |  43
 lib/THC/CMakeLists.txt             |   3
 lib/THC/THCTensorMath.h            |   2
 lib/THC/THCTensorMath2.cu          |  13
 lib/THC/THCTensorRandom.cu         | 567
 lib/THC/THCTensorRandom.cuh        | 278
 lib/THC/THCTensorRandom.h          |  12
 lib/THC/generic/THCTensorRandom.cu | 328
 lib/THC/generic/THCTensorRandom.h  |  21
 test/test.lua                      | 177
 10 files changed, 634 insertions(+), 810 deletions(-)
diff --git a/TensorMath.lua b/TensorMath.lua
index f94154d..a5d436f 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -894,15 +894,6 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor .. "Array"},
{name="index", default=lastdimarray(2)}})
- for _,f in ipairs({{name='geometric'},
- {name='bernoulli', a=0.5}}) do
-
- wrap(f.name,
- cname(f.name),
- {{name=Tensor, returned=true},
- {name='double', default=f.a}})
- end
-
wrap("nonzero",
cname("nonzero"),
{{name="CudaLongTensor", default=true, returned=true},
@@ -934,40 +925,6 @@ for k, Tensor_ in pairs(handledTypenames) do
{name = real},
{name=Tensor, method={default=1}}})
- wrap("rand",
- cname("rand"),
- {{name=Tensor, default=true, returned=true, method={default='nil'}},
- {name="LongArg"}})
-
- wrap("randn",
- cname("randn"),
- {{name=Tensor, default=true, returned=true, method={default='nil'}},
- {name="LongArg"}})
-
- wrap("multinomial",
- cname("multinomial"),
- {{name=Tensor, default=true, returned=true, method={default='nil'}},
- {name=Tensor},
- {name="int"},
- {name="boolean", default=false}})
-
- for _,f in ipairs({{name='uniform', a=0, b=1},
- {name='cauchy', a=0, b=1},
- {name='normal', a=0, b=1},
- {name='logNormal', a=1, b=2}}) do
-
- wrap(f.name,
- cname(f.name),
- {{name=Tensor, returned=true},
- {name='double', default=f.a},
- {name='double', default=f.b}})
- end
-
- wrap('exponential',
- cname('exponential'),
- {{name=Tensor, returned=true},
- {name='double', default=nil}})
-
wrap("norm",
cname("normall"),
{{name=Tensor},
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 51f568a..e5c202a 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -238,7 +238,6 @@ INSTALL(FILES
THCTensorSort.cuh
THCTensorInfo.cuh
THCTensorTypeUtils.cuh
- THCTensorRandom.cuh
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")
INSTALL(FILES
@@ -279,6 +278,4 @@ INSTALL(FILES
generic/THCTensorSort.h
generic/THCTensorSort.cu
generic/THCDeviceTensorUtils.cu
- generic/THCTensorRandom.h
- generic/THCTensorRandom.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 7cbef32..759c9a3 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -53,6 +53,8 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
+THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
+THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);
THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 7e6af9b..9933b7e 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -28,3 +28,16 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
+void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, r_));
+ THCudaTensor_resize(state, r_, size, NULL);
+ THCudaTensor_uniform(state, r_, 0, 1);
+}
+
+void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, r_));
+ THCudaTensor_resize(state, r_, size, NULL);
+ THCudaTensor_normal(state, r_, 0, 1);
+}
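
Note: the restored rand/randn are thin wrappers: resize to the requested size, then delegate to the uniform/normal fillers re-added in THCTensorRandom.cu below. For comparison only (not part of this patch), the same float fills can be produced with the cuRAND host API; a minimal sketch, assuming a raw device buffer and omitting error checks:

    #include <cuda_runtime.h>
    #include <curand.h>

    void fill_uniform_and_normal(float *d_buf, size_t n) {
      curandGenerator_t gen;
      curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
      curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);
      curandGenerateUniform(gen, d_buf, n);             /* like *_rand: U(0,1) */
      curandGenerateNormal(gen, d_buf, n, 0.0f, 1.0f);  /* like *_randn: N(0,1); n must be even */
      curandDestroyGenerator(gen);
    }
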
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 4493fe8..827692f 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -4,7 +4,6 @@
#include "THCTensorCopy.h"
#include "THCTensorMath.h"
#include "THCReduceApplyUtils.cuh"
-#include "THCTensorRandom.cuh"
#include <thrust/functional.h>
#include <curand.h>
@@ -188,53 +187,567 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
memcpy(&gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
}
-#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \
+#define GENERATE_KERNEL1(NAME, ARG1, CURAND_FUNC, TRANSFORM) \
+__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
- CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
+ float x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
- T y = TRANSFORM; \
- result[i] = y; \
+ x = TRANSFORM; \
+ result[i] = x; \
} \
} \
}
-#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \
+#define GENERATE_KERNEL2(NAME, ARG1, ARG2, CURAND_FUNC, TRANSFORM) \
+__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1, ARG2) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
- CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
+ float x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
- T y = TRANSFORM; \
- result[i] = y; \
+ x = TRANSFORM; \
+ result[i] = x; \
} \
} \
}
-GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, x * (b-a) + a)
-GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
-GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a)))
+GENERATE_KERNEL2(generate_uniform, double a, double b, curand_uniform, x * (b-a) + a)
+GENERATE_KERNEL1(generate_bernoulli, double p, curand_uniform, (float)x <= p)
+GENERATE_KERNEL2(generate_normal, double mean, double stdv, curand_normal, (x * stdv) + mean)
+GENERATE_KERNEL1(generate_geometric, double p, curand_uniform, (log(1-x) / log(p)) + 1)
+GENERATE_KERNEL1(generate_exponential, double lambda, curand_uniform, (float)(-1. / lambda * log(1-x)))
+GENERATE_KERNEL2(generate_cauchy, double median, double sigma, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
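
For readers tracing the macros: each invocation above stamps out a full kernel. As a mechanical expansion (exactly what the preprocessor produces), GENERATE_KERNEL1(generate_bernoulli, double p, curand_uniform, (float)x <= p) becomes:

    __global__ void generate_bernoulli(curandStateMtgp32 *state, int size, float *result, double p)
    {
      int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
      int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
      /* every thread keeps drawing so the shared MTGP32 state advances in
         lockstep; only in-range threads store a value */
      for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
        float x = curand_uniform(&state[blockIdx.x]);
        if (i < size) {
          x = (float)x <= p;   /* the TRANSFORM argument, applied in place */
          result[i] = x;
        }
      }
    }
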
-GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
-GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)
-GENERATE_KERNEL2(generate_normal, half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, half>::to((x * stdv) + mean)))
+#undef GENERATE_KERNEL1
+#undef GENERATE_KERNEL2
-GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(1-x)))
-GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(1-x)))
-GENERATE_KERNEL1(generate_exponential, half, double lambda, float, curand_uniform, (ScalarConvert<float, half>::to((float)(-1. / lambda * log(1-x)))))
+/* Separate kernel because curand_log_normal gets extra parameters. */
+__global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *result, float mean, float stddev)
+{
+ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
+ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
+ float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
+ if (i < size) {
+ result[i] = x;
+ }
+ }
+}
-GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
-GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
-GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
+#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
+THC_API void THCudaTensor_uniform(THCState* state, THCudaTensor *self_, double a, double b)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
-#include "generic/THCTensorRandom.cu"
-#include "THCGenerateAllTypes.h"
+ generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, a, b);
-#undef GENERATE_KERNEL1
-#undef GENERATE_KERNEL2
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_bernoulli(THCState* state, THCudaTensor *self_, double p)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, p);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_normal(THCState* state, THCudaTensor *self_, double mean, double stdv)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, mean, stdv);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_logNormal(THCState* state, THCudaTensor *self_, double mean, double stdv)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_log_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, mean, stdv);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_geometric(THCState* state, THCudaTensor *self_, double p)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, p);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_exponential(THCState* state, THCudaTensor *self_, double lambda)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, lambda);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_cauchy(THCState* state, THCudaTensor *self_, double median, double sigma)
+{
+ THAssert(THCudaTensor_checkGPU(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+ ptrdiff_t size = THCudaTensor_nElement(state, self);
+ float *data = THCudaTensor_data(state, self);
+
+ generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, median, sigma);
+
+ THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+__device__ int binarySearchForMultinomial(float* dist,
+ int size,
+ float val) {
+ int start = 0;
+ int end = size;
+
+ while (end - start > 0) {
+ int mid = start + (end - start) / 2;
+
+ float midVal = dist[mid];
+ if (midVal < val) {
+ start = mid + 1;
+ } else {
+ end = mid;
+ }
+ }
+
+ if (start == size) {
+ // No probability mass or precision problems; just return the
+ // first element
+ start = 0;
+ }
+
+ return start;
+}
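
binarySearchForMultinomial is a lower-bound search over a row of inclusive prefix sums: it returns the first bucket whose cumulative probability is >= val. A host-side check of the same logic (a hypothetical test, not part of the patch):

    #include <assert.h>

    static int lowerBound(const float *cdf, int size, float val) {
      int start = 0, end = size;
      while (end - start > 0) {
        int mid = start + (end - start) / 2;
        if (cdf[mid] < val) start = mid + 1; else end = mid;
      }
      return start == size ? 0 : start;   /* same fallback as the kernel */
    }

    int main(void) {
      /* inclusive prefix sums of {0.1, 0.2, 0.3, 0.4} */
      const float cdf[4] = {0.1f, 0.3f, 0.6f, 1.0f};
      assert(lowerBound(cdf, 4, 0.05f) == 0);
      assert(lowerBound(cdf, 4, 0.30f) == 1);
      assert(lowerBound(cdf, 4, 0.95f) == 3);
      return 0;
    }
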
+
+// Normalizes the L1 norm of every row to 1; used by multinomial
+__global__ void renormRowsL1(float* dist, long rows, long cols) {
+ extern __shared__ float smem[];
+
+ for (long row = blockIdx.x; row < rows; row += gridDim.x) {
+ float sum = 0.0f;
+ for (long col = threadIdx.x; col < cols; col += blockDim.x) {
+ sum += dist[row * cols + col];
+ }
+
+ sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
+ if (threadIdx.x == 0) {
+ smem[0] = sum;
+ }
+ __syncthreads();
+
+ sum = smem[0];
+ if (sum > 0.0f) {
+ for (long col = threadIdx.x; col < cols; col += blockDim.x) {
+ dist[row * cols + col] /= sum;
+ }
+ }
+ }
+}
+
+void THCudaTensor_renormRows(struct THCState* state,
+ THCudaTensor* t) {
+ THAssert(THCudaTensor_nDimension(state, t) == 2);
+ long rows = THCudaTensor_size(state, t, 0);
+ long cols = THCudaTensor_size(state, t, 1);
+
+ cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
+ THAssert(props != NULL);
+
+ int numSM = props->multiProcessorCount;
+ int maxThreads = props->maxThreadsPerBlock;
+
+ dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
+ dim3 block(cols < maxThreads ? cols : maxThreads);
+
+ renormRowsL1
+ <<<grid, block, block.x * sizeof(float),
+ THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, t),
+ rows, cols);
+}
+
+__global__ void
+sampleMultinomialOnce(float* dest,
+ long distributions,
+ int categories,
+ float* dist) {
+ extern __shared__ float smem[];
+
+ for (long curDist = blockIdx.x;
+ curDist < distributions; curDist += gridDim.x) {
+ // Each block handles one distribution
+ // First pass, find the total sum of the distribution
+ float sum = 0.0f;
+ for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
+ sum += dist[curDist * categories + cat];
+ }
+
+ // threadIdx.x == 0 has the sum value from this
+ sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
+
+ // Broadcast sum and sample value
+ if (threadIdx.x == 0) {
+ smem[0] = sum;
+ smem[1] = dest[curDist];
+ }
+ __syncthreads();
+
+ sum = smem[0];
+ float sample = smem[1];
+ __syncthreads();
+
+ if (sum == 0.0f || sample == 0.0f) {
+ // Choose the first element
+ if (threadIdx.x == 0) {
+ dest[curDist] = 1;
+ }
+
+ continue;
+ }
+
+ int chunks = THCCeilDiv(categories, (int) blockDim.x);
+ float prevHighProb = 0.0f;
+
+ for (int chunk = 0; chunk < chunks; ++chunk) {
+ // All threads in bounds load a value
+ int cat = chunk * blockDim.x + threadIdx.x;
+
+ float val =
+ cat < categories ? dist[curDist * categories + cat] / sum : 0.0f;
+ smem[threadIdx.x] = val;
+ __syncthreads();
+
+ // Perform an inclusive prefix sum of the shared memory contents
+ for (int offset = 1; offset < blockDim.x; offset *= 2) {
+ float val = 0.0f;
+
+ if (threadIdx.x >= offset) {
+ val = smem[threadIdx.x - offset] + smem[threadIdx.x];
+ }
+
+ __syncthreads();
+ if (threadIdx.x >= offset) {
+ smem[threadIdx.x] = val;
+ }
+ __syncthreads();
+ }
+
+ // Each thread will check to see if the sample falls in its
+ // bucket
+ float curBucket =
+ smem[threadIdx.x] + prevHighProb;
+ float prevBucket =
+ threadIdx.x == 0 ? prevHighProb : smem[threadIdx.x - 1] + prevHighProb;
+ bool inBucket =
+ (cat < categories) && (sample <= curBucket) && (sample > prevBucket);
+
+ if (inBucket) {
+ // We're done; we have the sample
+ // Torch indices are 1-based
+ // FIXME: broadcast exit flag?
+ dest[curDist] = cat + TH_INDEX_BASE;
+ }
+
+ // Store the previous scan's high value for future use
+ prevHighProb += smem[blockDim.x - 1];
+
+ __syncthreads();
+ }
+ }
+}
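
The inner offset loop above is a Hillis-Steele inclusive scan over one blockDim.x-sized chunk of shared memory; prevHighProb then carries the running total across chunks. A host-side trace of the scan (an illustrative sketch, with a plain array standing in for smem):

    #include <stdio.h>

    int main(void) {
      float smem[8] = {1, 1, 1, 1, 1, 1, 1, 1};
      const int N = 8;                       /* stands in for blockDim.x */
      for (int offset = 1; offset < N; offset *= 2) {
        float next[8];
        for (int t = 0; t < N; ++t)          /* all "threads" step together */
          next[t] = (t >= offset) ? smem[t - offset] + smem[t] : smem[t];
        for (int t = 0; t < N; ++t) smem[t] = next[t];
      }
      for (int t = 0; t < N; ++t) printf("%g ", smem[t]);  /* 1 2 3 4 5 6 7 8 */
      return 0;
    }
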
+
+__global__ void
+sampleMultinomialWithReplacement(curandStateMtgp32* state,
+ int totalSamples,
+ float* dest,
+ long distributions,
+ int categories,
+ float* normDistPrefixSum) {
+ // At the moment, each warp computes one sample value in the binary
+ // search due to divergence. It seems possible to compute multiple
+ // values and limit divergence though later on. However, no matter
+ // what, all block threads must participate in the curand_uniform
+ // call to update the generator state.
+
+ // The block determines the distribution for which we generate a point
+ for (long curDist = blockIdx.x;
+ curDist < distributions;
+ curDist += gridDim.x) {
+ for (int sampleBase = 0;
+ sampleBase < totalSamples; sampleBase += blockDim.y) {
+ // The warp determines the sample
+ int sample = sampleBase + threadIdx.y;
+
+ // All threads participate in this
+ float r = curand_uniform(&state[blockIdx.x]);
+
+ if (threadIdx.x == 0 && sample < totalSamples) {
+ // Find the bucket that a uniform sample lies in
+ int choice = binarySearchForMultinomial(
+ normDistPrefixSum + curDist * categories,
+ categories,
+ r);
+
+ // Torch indices are 1-based
+ dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
+ }
+ }
+ }
+}
+
+__global__ void
+sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
+ int totalSamples,
+ int sample,
+ float* dest,
+ long distributions,
+ int categories,
+ float* origDist,
+ float* normDistPrefixSum) {
+ // At the moment, each warp computes one sample value in the binary
+ // search due to divergence. It seems possible to compute multiple
+ // values and limit divergence though later on. However, no matter
+ // what, all block threads must participate in the curand_uniform
+ // call to update the generator state.
+
+ // The block and warp determines the distribution for which we
+ // generate a point
+ for (long curDistBase = blockIdx.x * blockDim.y;
+ curDistBase < distributions;
+ curDistBase += gridDim.x * blockDim.y) {
+ // The warp determines the distribution
+ long curDist = curDistBase + threadIdx.y;
+
+ // All threads must participate in this
+ float r = curand_uniform(&state[blockIdx.x]);
+
+ if (threadIdx.x == 0 && curDist < distributions) {
+ // Find the bucket that a uniform sample lies in
+ int choice = binarySearchForMultinomial(
+ normDistPrefixSum + curDist * categories,
+ categories,
+ r);
+
+ // Torch indices are 1-based
+ dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
+
+ // Without replacement, so update the original probability so it
+ // is not considered a second time
+ origDist[curDist * categories + choice] = 0.0f;
+ }
+ }
+}
+
+THC_API void THCudaTensor_multinomial(struct THCState *state,
+ THCudaTensor *self,
+ THCudaTensor *prob_dist,
+ int n_sample,
+ int with_replacement)
+{
+ THAssert(THCudaTensor_checkGPU(state, 2, self, prob_dist));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ int inputSize = THCudaTensor_nDimension(state, prob_dist);
+ THArgCheck(inputSize > 0 && inputSize <= 2, 2,
+ "prob_dist must be 1 or 2 dim");
+
+ // Categories are in the innermost dimension
+ long numDist =
+ inputSize == 1 ? 1 : THCudaTensor_size(state, prob_dist, 0);
+ long numCategoriesLong =
+ inputSize == 1 ? THCudaTensor_size(state, prob_dist, 0) :
+ THCudaTensor_size(state, prob_dist, 1);
+
+ // Since the index tensor is float, numCategories cannot exceed max
+ // float integer precision
+ THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2,
+ "number of categories cannot exceed 2^24");
+ int numCategories = (int) numCategoriesLong;
+
+ THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples");
+
+ if (!with_replacement) {
+ THArgCheck(n_sample <= numCategories, 2,
+ "cannot sample n_sample > prob_dist:size(1) samples without "
+ "replacement");
+ }
+
+ // It is possible that prob_dist is non-contiguous
+ THCudaTensor* probDistContig =
+ THCudaTensor_newContiguous(state, prob_dist);
+
+ // Restructure data for 2d
+ if (inputSize == 1) {
+ THCudaTensor_resize2d(state, probDistContig, 1, numCategories);
+ }
+
+ THCudaTensor_resize2d(state, self, numDist, n_sample);
+
+ if (n_sample == 1) {
+ // Optimized allocation-free implementation
+
+ // To exploit greater parallelism for the sampling, generate the
+ // Uniform random samples in a separate kernel launch, into the
+ // result memory. The device RNG is thread-limited
+ THCudaTensor_uniform(state, self, 0.0, 1.0);
+
+ cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
+ THAssert(props != NULL);
+
+ int numSM = props->multiProcessorCount;
+ int maxThreads = props->maxThreadsPerBlock;
+
+ dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
+ dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
+
+ sampleMultinomialOnce
+ <<<grid, block, block.x * sizeof(float),
+ THCState_getCurrentStream(state)>>>(
+ THCudaTensor_data(state, self),
+ numDist,
+ numCategories,
+ THCudaTensor_data(state, probDistContig));
+ } else {
+ // Generic, slow implementation with memory allocations
+
+ // For sampling without replacement, we modify the distribution
+ // for subsequent samples in this space
+ THCudaTensor* origDist = THCudaTensor_new(state);
+ THCudaTensor_resizeAs(state, origDist, probDistContig);
+ THCudaTensor_copy(state, origDist, probDistContig);
+
+ THCudaTensor* normDist = THCudaTensor_new(state);
+ THCudaTensor_resizeAs(state, normDist, probDistContig);
+
+ THCudaTensor* prefixSum = THCudaTensor_new(state);
+
+ // Renorm along rows
+ THCudaTensor_copy(state, normDist, origDist);
+ THCudaTensor_renormRows(state, normDist);
+
+ // Prefix sum along rows
+ THCudaTensor_cumsum(state, prefixSum, normDist, 1);
+
+ if (with_replacement) {
+ // Sample with replacement
+
+ // Binary search is warp divergent (so effectively we're running
+ // with just a single thread), but for better utilization,
+ // we need each block to have at least 4 warps.
+ dim3 block(32, 4);
+
+ // Each warp in a block will generate a sample from one
+ // distribution concurrently.
+ dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
+
+ sampleMultinomialWithReplacement
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states,
+ n_sample,
+ THCudaTensor_data(state, self),
+ numDist, numCategories,
+ THCudaTensor_data(state, prefixSum));
+ } else {
+ // Sample without replacement
+
+ // Binary search is warp divergent (so effectively we're running
+ // with just a single thread), but for better utilization,
+ // we need each block to have at least 4 warps.
+ dim3 block(32, 4);
+
+ // Each warp in a block will generate a sample from a different
+ // distribution concurrently.
+ ptrdiff_t numBlocks = THCCeilDiv(numDist, 4L);
+ dim3 grid(numBlocks < MAX_NUM_BLOCKS ? numBlocks : MAX_NUM_BLOCKS);
+
+ for (int sample = 0; sample < n_sample; ++sample) {
+ if (sample > 0) {
+ // Update probabilities
+ // Renorm along rows
+ THCudaTensor_copy(state, normDist, origDist);
+ THCudaTensor_renormRows(state, normDist);
+
+ // Prefix sum along rows
+ THCudaTensor_cumsum(state, prefixSum, normDist, 1);
+ }
+
+ // The kernel can only draw one sample before we have to
+ // recalculate our distribution
+ sampleMultinomialWithoutReplacement
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states,
+ n_sample,
+ sample,
+ THCudaTensor_data(state, self),
+ numDist, numCategories,
+ THCudaTensor_data(state, origDist),
+ THCudaTensor_data(state, prefixSum));
+ }
+ }
+
+ THCudaTensor_free(state, prefixSum);
+ THCudaTensor_free(state, normDist);
+ THCudaTensor_free(state, origDist);
+ }
+
+ // Revert data restructuring based on input sizes
+ if (inputSize == 1) {
+ THCudaTensor_resize1d(state, self, n_sample);
+
+ // Unfortunately, if prob_dist is contiguous already,
+ // newContiguous is not a private copy, so we have to restructure
+ // this too, so as to not affect prob_dist
+ THCudaTensor_resize1d(state, probDistContig, numCategories);
+ }
+
+ THCudaTensor_free(state, probDistContig);
+}
+#undef NUM_BLOCKS
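
Two implementation notes on the multinomial path above. First, sampling without replacement renormalizes and re-scans the distribution before every draw, so it costs one cumsum per sample; sampling with replacement needs only one. Second, sampled indices are written into a float tensor, hence the FLOAT32_MAX_CONSECUTIVE_INT check: a float's 24-bit significand represents consecutive integers exactly only up to 2^24, so larger category indices would collide. A two-line demonstration:

    #include <stdio.h>

    int main(void) {
      float a = 16777216.0f;         /* 2^24 */
      printf("%.1f\n", a + 1.0f);    /* prints 16777216.0 -- the +1 is lost */
      return 0;
    }
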
diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh
deleted file mode 100644
index 003e960..0000000
--- a/lib/THC/THCTensorRandom.cuh
+++ /dev/null
@@ -1,278 +0,0 @@
-#ifndef THC_TENSOR_RANDOM_CUH
-#define THC_TENSOR_RANDOM_CUH
-
-#include "THCNumerics.cuh"
-#include "THCReduceApplyUtils.cuh"
-#include "THCTensorMathReduce.cuh"
-
-#include <curand_kernel.h>
-
-#define MAX_NUM_BLOCKS 64
-#define BLOCK_SIZE 256
-/* Separate kernel because curand_log_normal gets extra parameters. */
-
-template <typename T>
-__global__ void generateLogNormal(curandStateMtgp32 *state, int size, T *result, double mean, double stddev)
-{
- int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
- int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
- for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
- float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
- if (i < size) {
- result[i] = ScalarConvert<float, T>::to(x);
- }
- }
-}
-
-template <>
-__global__ void generateLogNormal<double>(curandStateMtgp32 *state, int size, double *result, double mean, double stddev)
-{
- int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
- int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
- for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
- double x = curand_log_normal_double(&state[blockIdx.x], mean, stddev);
- if (i < size) {
- result[i] = x;
- }
- }
-}
-
-#undef MAX_NUM_BLOCKS
-#undef BLOCK_SIZE
-
-// Normalizes the L1 norm of every row to 1; used by multinomial
-template <typename T>
-__global__ void renormRowsL1(T* dist, long rows, long cols) {
- extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
- T *smem = reinterpret_cast<T *>(my_smem);
-
- for (long row = blockIdx.x; row < rows; row += gridDim.x) {
- T sum = ScalarConvert<int, T>::to(0);
- for (long col = threadIdx.x; col < cols; col += blockDim.x) {
- sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
- }
-
- sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
- if (threadIdx.x == 0) {
- smem[0] = sum;
- }
- __syncthreads();
-
- sum = smem[0];
- if (THCNumerics<T>::gt(sum, ScalarConvert<int, T>::to(0))) {
- for (long col = threadIdx.x; col < cols; col += blockDim.x) {
- dist[row * cols + col] = THCNumerics<T>::div(dist[row * cols + col], sum);
- }
- }
- }
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialOnce(T* dest,
- long distributions,
- int categories,
- T* dist) {
- extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
- T *smem = reinterpret_cast<T *>(my_smem);
- T zero = ScalarConvert<int, T>::to(0);
-
- for (long curDist = blockIdx.x;
- curDist < distributions; curDist += gridDim.x) {
- // Each block handles one distribution
- // First pass, find the total sum of the distribution
- T sum = zero;
- for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
- sum = THCNumerics<T>::add(sum, dist[curDist * categories + cat]);
- }
-
- // threadIdx.x == 0 has the sum value from this
- sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);
-
- // Broadcast sum and sample value
- if (threadIdx.x == 0) {
- smem[0] = sum;
- smem[1] = dest[curDist];
- }
- __syncthreads();
-
- sum = smem[0];
- T sample = smem[1];
- __syncthreads();
-
- if (THCNumerics<T>::eq(sum, zero) || THCNumerics<T>::eq(sample, zero)) {
- // Choose the first element
- if (threadIdx.x == 0) {
- dest[curDist] = ScalarConvert<int, T>::to(1);
- }
-
- continue;
- }
-
- int chunks = THCCeilDiv(categories, (int) blockDim.x);
- T prevHighProb = zero;
-
- for (int chunk = 0; chunk < chunks; ++chunk) {
- // All threads in bounds load a value
- int cat = chunk * blockDim.x + threadIdx.x;
-
- T val =
- cat < categories ? THCNumerics<T>::div(dist[curDist * categories + cat], sum) :
- zero;
-
- smem[threadIdx.x] = val;
- __syncthreads();
-
- // Perform an inclusive prefix sum of the shared memory contents
- for (int offset = 1; offset < blockDim.x; offset *= 2) {
- T val = zero;
-
- if (threadIdx.x >= offset) {
- val = THCNumerics<T>::add(smem[threadIdx.x - offset], smem[threadIdx.x]);
- }
-
- __syncthreads();
- if (threadIdx.x >= offset) {
- smem[threadIdx.x] = val;
- }
- __syncthreads();
- }
-
- // Each thread will check to see if the sample falls in its
- // bucket
- T curBucket = THCNumerics<T>::add(smem[threadIdx.x], prevHighProb);
- T prevBucket =
- threadIdx.x == 0 ? prevHighProb :
- THCNumerics<T>::add(smem[threadIdx.x - 1], prevHighProb);
- bool inBucket =
- (cat < categories) &&
- (!THCNumerics<T>::gt(sample, curBucket)) &&
- (THCNumerics<T>::gt(sample, prevBucket));
-
- if (inBucket) {
- // We're done; we have the sample
- // Torch indices are 1-based
- // FIXME: broadcast exit flag?
- dest[curDist] = ScalarConvert<int, T>::to(cat + TH_INDEX_BASE);
- }
-
- // Store the previous scan's high value for future use
- prevHighProb = THCNumerics<T>::add(prevHighProb, smem[blockDim.x - 1]);
-
- __syncthreads();
- }
- }
-}
-
-template <typename T>
-__device__ int binarySearchForMultinomial(T* dist,
- int size,
- T val) {
- int start = 0;
- int end = size;
-
- while (end - start > 0) {
- int mid = start + (end - start) / 2;
-
- T midVal = dist[mid];
- if (THCNumerics<T>::lt(midVal, val)) {
- start = mid + 1;
- } else {
- end = mid;
- }
- }
-
- if (start == size) {
- // No probability mass or precision problems; just return the
- // first element
- start = 0;
- }
-
- return start;
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialWithReplacement(curandStateMtgp32* state,
- int totalSamples,
- T* dest,
- long distributions,
- int categories,
- T* normDistPrefixSum) {
- // At the moment, each warp computes one sample value in the binary
- // search due to divergence. It seems possible to compute multiple
- // values and limit divergence though later on. However, no matter
- // what, all block threads must participate in the curand_uniform
- // call to update the generator state.
-
- // The block determines the distribution for which we generate a point
- for (long curDist = blockIdx.x;
- curDist < distributions;
- curDist += gridDim.x) {
- for (int sampleBase = 0;
- sampleBase < totalSamples; sampleBase += blockDim.y) {
- // The warp determines the sample
- int sample = sampleBase + threadIdx.y;
-
- // All threads participate in this
- T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
-
- if (threadIdx.x == 0 && sample < totalSamples) {
- // Find the bucket that a uniform sample lies in
- int choice = binarySearchForMultinomial<T>(
- normDistPrefixSum + curDist * categories,
- categories,
- r);
-
- // Torch indices are 1-based
- dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
- }
- }
- }
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
- int totalSamples,
- int sample,
- T* dest,
- long distributions,
- int categories,
- T* origDist,
- T* normDistPrefixSum) {
- // At the moment, each warp computes one sample value in the binary
- // search due to divergence. It seems possible to compute multiple
- // values and limit divergence though later on. However, no matter
- // what, all block threads must participate in the curand_uniform
- // call to update the generator state.
-
- // The block and warp determines the distribution for which we
- // generate a point
- for (long curDistBase = blockIdx.x * blockDim.y;
- curDistBase < distributions;
- curDistBase += gridDim.x * blockDim.y) {
- // The warp determines the distribution
- long curDist = curDistBase + threadIdx.y;
-
- // All threads must participate in this
- T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
-
- if (threadIdx.x == 0 && curDist < distributions) {
- // Find the bucket that a uniform sample lies in
- int choice = binarySearchForMultinomial<T>(
- normDistPrefixSum + curDist * categories,
- categories,
- r);
-
- // Torch indices are 1-based
- dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
-
- // Without replacement, so update the original probability so it
- // is not considered a second time
- origDist[curDist * categories + choice] = ScalarConvert<int, T>::to(0);
- }
- }
-}
-
-#endif // THC_TENSOR_RANDOM_CUH
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index 12128cd..93c0d77 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -3,9 +3,6 @@
#include "THCTensor.h"
-#include "generic/THCTensorRandom.h"
-#include "THCGenerateAllTypes.h"
-
/* Generator */
typedef struct _Generator {
struct curandStateMtgp32* gen_states;
@@ -31,6 +28,15 @@ THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_s
THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);
+THC_API void THCudaTensor_geometric(struct THCState *state, THCudaTensor *self, double p);
+THC_API void THCudaTensor_bernoulli(struct THCState *state, THCudaTensor *self, double p);
+THC_API void THCudaTensor_uniform(struct THCState *state, THCudaTensor *self, double a, double b);
+THC_API void THCudaTensor_normal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
+THC_API void THCudaTensor_exponential(struct THCState *state, THCudaTensor *self, double lambda);
+THC_API void THCudaTensor_cauchy(struct THCState *state, THCudaTensor *self, double median, double sigma);
+THC_API void THCudaTensor_logNormal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
+
+THC_API void THCudaTensor_multinomial(struct THCState *state, THCudaTensor *self, THCudaTensor *prob_dist, int n_sample, int with_replacement);
THC_API struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state);
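
With the generic versions removed, these float-only declarations are again the public surface of the random module. A minimal usage sketch, assuming an initialized THCState and that THCudaTensor_newWithSize2d/THCudaTensor_free follow the usual THC constructor/destructor pattern (error handling omitted):

    #include "THC.h"

    void fill_examples(THCState *state) {
      THCudaTensor *t = THCudaTensor_newWithSize2d(state, 4, 4);
      THCudaTensor_uniform(state, t, 0.0, 1.0);   /* U(0,1) */
      THCudaTensor_bernoulli(state, t, 0.5);      /* 0/1 coin flips */
      THCudaTensor_normal(state, t, 0.0, 1.0);    /* N(0,1) */
      THCudaTensor_free(state, t);
    }
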
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
deleted file mode 100644
index 50aa2e8..0000000
--- a/lib/THC/generic/THCTensorRandom.cu
+++ /dev/null
@@ -1,328 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/THCTensorRandom.cu"
-#else
-
-#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, double b)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, a, b);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, mean, stdv);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, double stdv)
-{
-
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generateLogNormal<real><<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, mean, stdv);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, lambda);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median, double sigma)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, median, sigma);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-void THCTensor_(renormRows)(struct THCState* state,
- THCTensor* t) {
- THAssert(THCTensor_(nDimension)(state, t) == 2);
- long rows = THCTensor_(size)(state, t, 0);
- long cols = THCTensor_(size)(state, t, 1);
-
- cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
- THAssert(props != NULL);
-
- int numSM = props->multiProcessorCount;
- int maxThreads = props->maxThreadsPerBlock;
-
- dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
- dim3 block(cols < maxThreads ? cols : maxThreads);
-
- renormRowsL1<real>
- <<<grid, block, block.x * sizeof(real),
- THCState_getCurrentStream(state)>>>(THCTensor_(data)(state, t),
- rows, cols);
-}
-
-THC_API void THCTensor_(multinomial)(struct THCState *state,
- THCTensor *self,
- THCTensor *prob_dist,
- int n_sample,
- int with_replacement)
-{
- THAssert(THCTensor_(checkGPU)(state, 2, self, prob_dist));
- Generator* gen = THCRandom_getGenerator(state);
-
- int inputSize = THCTensor_(nDimension)(state, prob_dist);
- THArgCheck(inputSize > 0 && inputSize <= 2, 2,
- "prob_dist must be 1 or 2 dim");
-
- // Categories are in the innermost dimension
- long numDist =
- inputSize == 1 ? 1 : THCTensor_(size)(state, prob_dist, 0);
- long numCategoriesLong =
- inputSize == 1 ? THCTensor_(size)(state, prob_dist, 0) :
- THCTensor_(size)(state, prob_dist, 1);
-
- // Since the index tensor is float, numCategories cannot exceed max
- // float integer precision
- THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2,
- "number of categories cannot exceed 2^24");
- int numCategories = (int) numCategoriesLong;
-
- THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples");
-
- if (!with_replacement) {
- THArgCheck(n_sample <= numCategories, 2,
- "cannot sample n_sample > prob_dist:size(1) samples without "
- "replacement");
- }
-
- // It is possible that prob_dist is non-contiguous
- THCTensor* probDistContig =
- THCTensor_(newContiguous)(state, prob_dist);
-
- // Restructure data for 2d
- if (inputSize == 1) {
- THCTensor_(resize2d)(state, probDistContig, 1, numCategories);
- }
-
- THCTensor_(resize2d)(state, self, numDist, n_sample);
-
- if (n_sample == 1) {
- // Optimized allocation-free implementation
-
- // To exploit greater parallelism for the sampling, generate the
- // Uniform random samples in a separate kernel launch, into the
- // result memory. The device RNG is thread-limited
- THCTensor_(uniform)(state, self, 0.0, 1.0);
-
- cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
- THAssert(props != NULL);
-
- int numSM = props->multiProcessorCount;
- int maxThreads = props->maxThreadsPerBlock;
-
- dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
- dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
-
- sampleMultinomialOnce
- <<<grid, block, block.x * sizeof(real),
- THCState_getCurrentStream(state)>>>(
- THCTensor_(data)(state, self),
- numDist,
- numCategories,
- THCTensor_(data)(state, probDistContig));
- } else {
- // Generic, slow implementation with memory allocations
-
- // For sampling without replacement, we modify the distribution
- // for subsequent samples in this space
- THCTensor* origDist = THCTensor_(new)(state);
- THCTensor_(resizeAs)(state, origDist, probDistContig);
- THCTensor_(copy)(state, origDist, probDistContig);
-
- THCTensor* normDist = THCTensor_(new)(state);
- THCTensor_(resizeAs)(state, normDist, probDistContig);
-
- THCTensor* prefixSum = THCTensor_(new)(state);
-
- // Renorm along rows
- THCTensor_(copy)(state, normDist, origDist);
- THCTensor_(renormRows)(state, normDist);
-
- // Prefix sum along rows
- THCTensor_(cumsum)(state, prefixSum, normDist, 1);
-
- if (with_replacement) {
- // Sample with replacement
-
- // Binary search is warp divergent (so effectively we're running
- // with just a single thread), but for better utilization,
- // we need each block to have at least 4 warps.
- dim3 block(32, 4);
-
- // Each warp in a block will generate a sample from one
- // distribution concurrently.
- dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
-
- sampleMultinomialWithReplacement
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states,
- n_sample,
- THCTensor_(data)(state, self),
- numDist, numCategories,
- THCTensor_(data)(state, prefixSum));
- } else {
- // Sample without replacement
-
- // Binary search is warp divergent (so effectively we're running
- // with just a single thread), but for better utilization,
- // we need each block to have at least 4 warps.
- dim3 block(32, 4);
-
- // Each warp in a block will generate a sample from a different
- // distribution concurrently.
- ptrdiff_t numBlocks = THCCeilDiv(numDist, 4L);
- dim3 grid(numBlocks < MAX_NUM_BLOCKS ? numBlocks : MAX_NUM_BLOCKS);
-
- for (int sample = 0; sample < n_sample; ++sample) {
- if (sample > 0) {
- // Update probabilities
- // Renorm along rows
- THCTensor_(copy)(state, normDist, origDist);
- THCTensor_(renormRows)(state, normDist);
-
- // Prefix sum along rows
- THCTensor_(cumsum)(state, prefixSum, normDist, 1);
- }
-
- // The kernel can only draw one sample before we have to
- // recalculate our distribution
- sampleMultinomialWithoutReplacement
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states,
- n_sample,
- sample,
- THCTensor_(data)(state, self),
- numDist, numCategories,
- THCTensor_(data)(state, origDist),
- THCTensor_(data)(state, prefixSum));
- }
- }
-
- THCTensor_(free)(state, prefixSum);
- THCTensor_(free)(state, normDist);
- THCTensor_(free)(state, origDist);
- }
-
- // Revert data restructuring based on input sizes
- if (inputSize == 1) {
- THCTensor_(resize1d)(state, self, n_sample);
-
- // Unfortunately, if prob_dist is contiguous already,
- // newContiguous is not a private copy, so we have to restructure
- // this too, so as to not affect prob_dist
- THCTensor_(resize1d)(state, probDistContig, numCategories);
- }
-
- THCTensor_(free)(state, probDistContig);
-}
-
-THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, r_));
- THCTensor_(resize)(state, r_, size, NULL);
- THCTensor_(uniform)(state, r_, 0, 1);
-}
-
-void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, r_));
- THCTensor_(resize)(state, r_, size, NULL);
- THCTensor_(normal)(state, r_, 0, 1);
-}
-
-#endif
-
-#if defined(THC_REAL_IS_DOUBLE)
-GENERATE_KERNEL1(generate_bernoulli, double, double p, double, curand_uniform_double, x <= p)
-#else
-GENERATE_KERNEL1(generate_bernoulli, real, double p, float, curand_uniform, (ScalarConvert<bool, real>::to(x <= p)))
-#endif
-
-THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, p);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-#if defined(THC_REAL_IS_DOUBLE)
-
-GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, floor((log(1-x) / log(p)) + 1))
-#else
-GENERATE_KERNEL1(generate_geometric, real, double p, float, curand_uniform, (ScalarConvert<float, real>::to(floorf((log(1-x) / log(p)) + 1))))
-#endif
-
-THC_API void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)
-{
- THAssert(THCTensor_(checkGPU)(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- ptrdiff_t size = THCTensor_(nElement)(state, self);
- real *data = THCTensor_(data)(state, self);
-
- generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, p);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-#undef NUM_BLOCKS
-
-#endif
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
deleted file mode 100644
index a2896c3..0000000
--- a/lib/THC/generic/THCTensorRandom.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/THCTensorRandom.h"
-#else
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
-THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size);
-THC_API void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size);
-THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
-THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
-THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
-THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);
-THC_API void THCTensor_(multinomial)(struct THCState *state, THCTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
-
-#endif
-
-THC_API void THCTensor_(bernoulli)(struct THCState *state, THCTensor *self, double p);
-THC_API void THCTensor_(geometric)(struct THCState *state, THCTensor *self, double p);
-
-#endif
diff --git a/test/test.lua b/test/test.lua
index 724d5ff..00cfa66 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2526,11 +2526,8 @@ function test.uniform()
local max = min + torch.uniform()
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(typename)
- x:uniform(min, max)
- checkIfUniformlyDistributed(x, min, max)
- end
+ t:uniform(min, max)
+ checkIfUniformlyDistributed(t, min, max)
checkMultiDevice(t, 'uniform', min, max)
end
@@ -2540,17 +2537,13 @@ function test.bernoulli()
local p = torch.uniform()
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(typenames) do
- local x = t:type(typename)
- x:bernoulli(p)
- local mean = x:sum() / (sz1 * sz2)
- tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p")
- local f = t:float()
- tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
- torch.FloatTensor(sz1, sz2):fill(1),
- 1e-6,
- "each value must be either 0 or 1")
- end
+ t:bernoulli(p)
+ tester:assertalmosteq(t:mean(), p, 0.1, "mean is not equal to p")
+ local f = t:float()
+ tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
+ torch.FloatTensor(sz1, sz2):fill(1),
+ 1e-6,
+ "each value must be either 0 or 1")
checkMultiDevice(t, 'bernoulli', p)
end
@@ -2561,13 +2554,9 @@ function test.normal()
local tolerance = 0.01
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(t2cpu[typename])
- x:normal(mean, std)
- tester:assertalmosteq(x:mean(), mean, tolerance, "mean is wrong")
- tester:assertalmosteq(x:std(), std, tolerance, "standard deviation is wrong")
- end
-
+ t:normal(mean, std)
+ tester:assertalmosteq(t:mean(), mean, tolerance, "mean is wrong")
+ tester:assertalmosteq(t:std(), std, tolerance, "standard deviation is wrong")
checkMultiDevice(t, 'normal', mean, std)
end
@@ -2578,13 +2567,10 @@ function test.logNormal()
local tolerance = 0.01
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(typename)
- x:logNormal(mean, std)
- local logt = x:log()
- tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
- tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
- end
+ t:logNormal(mean, std)
+ local logt = t:log()
+ tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
+ tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
checkMultiDevice(t, 'logNormal', mean, std)
end
@@ -2594,14 +2580,10 @@ function test.geometric()
local p = torch.uniform()
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(typename)
- x:geometric(p)
-
- local u = torch.FloatTensor(sz1, sz2):fill(1) -
- ((x:float() - 1) * math.log(p)):exp()
- checkIfUniformlyDistributed(u, 0, 1)
- end
+ t:geometric(p)
+ local u = torch.FloatTensor(sz1, sz2):fill(1) -
+ ((t:float() - 1) * math.log(p)):exp()
+ checkIfUniformlyDistributed(u, 0, 1)
checkMultiDevice(t, 'geometric', p)
end
@@ -2611,13 +2593,10 @@ function test.exponential()
local lambda = torch.uniform()
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(t2cpu[typename])
- x:exponential(lambda)
- local u = torch.FloatTensor(sz1, sz2):fill(1) -
- (x:float() * -lambda):exp()
- checkIfUniformlyDistributed(u, 0, 1)
- end
+ t:exponential(lambda)
+ local u = torch.FloatTensor(sz1, sz2):fill(1) -
+ (t:float() * -lambda):exp()
+ checkIfUniformlyDistributed(u, 0, 1)
checkMultiDevice(t, 'exponential', lambda)
end
@@ -2627,12 +2606,9 @@ function test.cauchy()
local median, sigma = torch.uniform(), torch.uniform()
local t = torch.CudaTensor(sz1, sz2)
- for _, typename in ipairs(float_typenames) do
- local x = t:type(typename)
- x:cauchy(median, sigma)
- local u = ((x:float() - median) / sigma):atan() / math.pi + 0.5
- checkIfUniformlyDistributed(u, 0, 1)
- end
+ t:cauchy(median, sigma)
+ local u = ((t:float() - median) / sigma):atan() / math.pi + 0.5
+ checkIfUniformlyDistributed(u, 0, 1)
checkMultiDevice(t, 'cauchy', median, sigma)
end
@@ -2699,19 +2675,16 @@ function test.multinomial_with_replacement()
local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
local n_sample = torch.random(n_col - 1)
- for _, typename in ipairs(float_typenames) do
- local pd = prob_dist:type(typename)
- local sample_indices = torch.multinomial(pd, n_sample, true)
- tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
- tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
- for i = 1, n_row do
- for j = 1, n_sample do
- local val = sample_indices[{i,j}]
- tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
- "sampled an invalid index: " .. val)
- end
- end
+ local sample_indices = torch.multinomial(prob_dist, n_sample, true)
+ tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+ tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+ for i = 1, n_row do
+ for j = 1, n_sample do
+ local val = sample_indices[{i,j}]
+ tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
+ "sampled an invalid index: " .. val)
+ end
end
end
end
@@ -2725,27 +2698,24 @@ function test.multinomial_without_replacement()
local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
local n_sample = torch.random(n_col - 1)
- for _, typename in ipairs(float_typenames) do
- local pd = prob_dist:type(typename)
- local sample_indices = torch.multinomial(pd, n_sample, false)
- tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
- tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
- sample_indices = sample_indices:float()
-
- for i = 1, n_row do
- local row_samples = {}
- for j = 1, n_sample do
- local sample_idx = sample_indices[{i,j}]
- tester:assert(
- sample_idx ~= n_col, "sampled an index with zero probability"
- )
- tester:assert(
- not row_samples[sample_idx], "sampled an index twice"
- )
- row_samples[sample_idx] = true
- end
- end
+ local sample_indices = torch.multinomial(prob_dist, n_sample, false)
+ tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+ tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+ sample_indices = sample_indices:float()
+
+ for i = 1, n_row do
+ local row_samples = {}
+ for j = 1, n_sample do
+ local sample_idx = sample_indices[{i,j}]
+ tester:assert(
+ sample_idx ~= n_col, "sampled an index with zero probability"
+ )
+ tester:assert(
+ not row_samples[sample_idx], "sampled an index twice"
+ )
+ row_samples[sample_idx] = true
+ end
end
end
end
@@ -2761,21 +2731,17 @@ function test.multinomial_without_replacement_gets_all()
t[dist] = linear
end
- local orig = t:clone():long()
-
- for _, typename in ipairs(float_typenames) do
- local x = t:type(typename)
+ local orig = t:clone()
- -- Sample without replacement
- local result = torch.multinomial(x, distSize)
- tester:assert(result:size(1) == distributions)
- tester:assert(result:size(2) == distSize)
+ -- Sample without replacement
+ local result = torch.multinomial(t, distSize)
+ tester:assert(result:size(1) == distributions)
+ tester:assert(result:size(2) == distSize)
- -- Sort, and we should have the original results, since without replacement
- -- sampling everything, we should have chosen every value uniquely
- result = result:sort(2)
- tester:assertTensorEq(orig:type(typename), result, 0, "error in multinomial_without_replacement_gets_all")
- end
+ -- Sort, and we should have the original results, since without replacement
+ -- sampling everything, we should have chosen every value uniquely
+ result = result:sort(2)
+ tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all")
end
end
@@ -2783,15 +2749,12 @@ function test.multinomial_vector()
local n_col = torch.random(100)
local prob_dist = torch.CudaTensor(n_col):uniform()
local n_sample = n_col
- for _, typename in ipairs(float_typenames) do
- local pd = prob_dist:type(typename)
- local sample_indices = torch.multinomial(pd, n_sample, true)
- tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
- -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
- -- was undone
- tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
- tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
- end
+ local sample_indices = torch.multinomial(prob_dist, n_sample, true)
+ tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
+ -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
+ -- was undone
+ tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
+ tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
end
function test.get_device()