From bdf2b06c9d7e1e53e78f4ac78f6f5dfa4f1b9020 Mon Sep 17 00:00:00 2001
From: Soumith Chintala
Date: Fri, 11 Nov 2016 18:15:01 -0500
Subject: Revert "Move random functions to generic"

---
 TensorMath.lua                     |  43 ---
 lib/THC/CMakeLists.txt             |   3 -
 lib/THC/THCTensorMath.h            |   2 +
 lib/THC/THCTensorMath2.cu          |  13 +
 lib/THC/THCTensorRandom.cu         | 567 +++++++++++++++++++++++++++++++++++--
 lib/THC/THCTensorRandom.cuh        | 278 ------------------
 lib/THC/THCTensorRandom.h          |  12 +-
 lib/THC/generic/THCTensorRandom.cu | 328 ---------------------
 lib/THC/generic/THCTensorRandom.h  |  21 --
 test/test.lua                      | 177 +++++-------
 10 files changed, 634 insertions(+), 810 deletions(-)
 delete mode 100644 lib/THC/THCTensorRandom.cuh
 delete mode 100644 lib/THC/generic/THCTensorRandom.cu
 delete mode 100644 lib/THC/generic/THCTensorRandom.h

diff --git a/TensorMath.lua b/TensorMath.lua
index f94154d..a5d436f 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -894,15 +894,6 @@ for k, Tensor_ in pairs(handledTypenames) do
         {name=Tensor .. "Array"},
         {name="index", default=lastdimarray(2)}})
 
-   for _,f in ipairs({{name='geometric'},
-                      {name='bernoulli', a=0.5}}) do
-
-      wrap(f.name,
-           cname(f.name),
-           {{name=Tensor, returned=true},
-            {name='double', default=f.a}})
-   end
-
    wrap("nonzero",
         cname("nonzero"),
         {{name="CudaLongTensor", default=true, returned=true},
@@ -934,40 +925,6 @@ for k, Tensor_ in pairs(handledTypenames) do
         {name = real},
         {name=Tensor, method={default=1}}})
 
-   wrap("rand",
-        cname("rand"),
-        {{name=Tensor, default=true, returned=true, method={default='nil'}},
-         {name="LongArg"}})
-
-   wrap("randn",
-        cname("randn"),
-        {{name=Tensor, default=true, returned=true, method={default='nil'}},
-         {name="LongArg"}})
-
-   wrap("multinomial",
-        cname("multinomial"),
-        {{name=Tensor, default=true, returned=true, method={default='nil'}},
-         {name=Tensor},
-         {name="int"},
-         {name="boolean", default=false}})
-
-   for _,f in ipairs({{name='uniform', a=0, b=1},
-                      {name='cauchy', a=0, b=1},
-                      {name='normal', a=0, b=1},
-                      {name='logNormal', a=1, b=2}}) do
-
-      wrap(f.name,
-           cname(f.name),
-           {{name=Tensor, returned=true},
-            {name='double', default=f.a},
-            {name='double', default=f.b}})
-   end
-
-   wrap('exponential',
-        cname('exponential'),
-        {{name=Tensor, returned=true},
-         {name='double', default=nil}})
-
    wrap("norm",
         cname("normall"),
         {{name=Tensor},
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 51f568a..e5c202a 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -238,7 +238,6 @@ INSTALL(FILES
           THCTensorSort.cuh
          THCTensorInfo.cuh
          THCTensorTypeUtils.cuh
-          THCTensorRandom.cuh
          DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")
 
 INSTALL(FILES
@@ -279,6 +278,4 @@ INSTALL(FILES
          generic/THCTensorSort.h
          generic/THCTensorSort.cu
          generic/THCDeviceTensorUtils.cu
-          generic/THCTensorRandom.h
-          generic/THCTensorRandom.cu
          DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 7cbef32..759c9a3 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -53,6 +53,8 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
 THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
 THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
 
+THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
+THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);
 THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
 THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
 
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 7e6af9b..9933b7e 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -28,3 +28,16 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
   THCudaCheck(cudaGetLastError());
 }
 
+void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, r_));
+  THCudaTensor_resize(state, r_, size, NULL);
+  THCudaTensor_uniform(state, r_, 0, 1);
+}
+
+void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, r_));
+  THCudaTensor_resize(state, r_, size, NULL);
+  THCudaTensor_normal(state, r_, 0, 1);
+}
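The re-added rand/randn are thin compositions: resize the result tensor to the requested shape, then fill it from the matching distribution. A minimal caller sketch, assuming an already-initialized THCState (device setup and error handling omitted); the storage and tensor constructors are the standard TH/THC API:

    /* Sketch: filling a fresh 2-D CUDA tensor via the re-added
       THCudaTensor_rand/_randn.  Assumes `state` was initialized
       elsewhere (e.g. by the Lua bindings); not a standalone program. */
    #include "THC/THC.h"

    void fill_rand_example(THCState *state)
    {
      THLongStorage *size = THLongStorage_newWithSize2(4, 5); /* 4x5 */
      THCudaTensor *t = THCudaTensor_new(state);

      THCudaTensor_rand(state, t, size);   /* resize + uniform(0, 1) */
      THCudaTensor_randn(state, t, size);  /* resize + normal(0, 1)  */

      THCudaTensor_free(state, t);
      THLongStorage_free(size);
    }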
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 4493fe8..827692f 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -4,7 +4,6 @@
 #include "THCTensorCopy.h"
 #include "THCTensorMath.h"
 #include "THCReduceApplyUtils.cuh"
-#include "THCTensorRandom.cuh"
 
 #include <thrust/functional.h>
 #include <curand.h>
@@ -188,53 +187,567 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
   memcpy(&gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
 }
 
-#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM)      \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1)      \
+#define GENERATE_KERNEL1(NAME, ARG1, CURAND_FUNC, TRANSFORM)                   \
+__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1)  \
 {                                                                              \
   int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;                             \
   int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;                \
   for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {      \
-    CURAND_T x = CURAND_FUNC(&state[blockIdx.x]);                              \
+    float x = CURAND_FUNC(&state[blockIdx.x]);                                 \
     if (i < size) {                                                            \
-      T y = TRANSFORM;                                                         \
-      result[i] = y;                                                           \
+      x = TRANSFORM;                                                           \
+      result[i] = x;                                                           \
     }                                                                          \
   }                                                                            \
 }
 
-#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM)      \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)      \
+#define GENERATE_KERNEL2(NAME, ARG1, ARG2, CURAND_FUNC, TRANSFORM)                   \
+__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1, ARG2)  \
 {                                                                                    \
   int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;                                   \
   int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;                      \
   for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {            \
-    CURAND_T x = CURAND_FUNC(&state[blockIdx.x]);                                    \
+    float x = CURAND_FUNC(&state[blockIdx.x]);                                       \
     if (i < size) {                                                                  \
-      T y = TRANSFORM;                                                               \
-      result[i] = y;                                                                 \
+      x = TRANSFORM;                                                                 \
+      result[i] = x;                                                                 \
     }                                                                                \
   }                                                                                  \
 }
 
-GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, x * (b-a) + a)
-GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
-GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a)))
+GENERATE_KERNEL2(generate_uniform, double a, double b, curand_uniform, x * (b-a) + a)
+GENERATE_KERNEL1(generate_bernoulli, double p, curand_uniform, (float)x <= p)
+GENERATE_KERNEL2(generate_normal, double mean, double stdv, curand_normal, (x * stdv) + mean)
+GENERATE_KERNEL1(generate_geometric, double p, curand_uniform, (log(1-x) / log(p)) + 1)
+GENERATE_KERNEL1(generate_exponential, double lambda, curand_uniform, (float)(-1. / lambda * log(1-x)))
+GENERATE_KERNEL2(generate_cauchy, double median, double sigma, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
 
-GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
-GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)
-GENERATE_KERNEL2(generate_normal, half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, half>::to((x * stdv) + mean)))
+#undef GENERATE_KERNEL1
+#undef GENERATE_KERNEL2
 
-GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(1-x)))
-GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(1-x)))
-GENERATE_KERNEL1(generate_exponential, half, double lambda, float, curand_uniform, (ScalarConvert<float, half>::to((float)(-1. / lambda * log(1-x)))))
+/* Separate kernel because curand_log_normal gets extra parameters. */
+__global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *result, float mean, float stddev)
+{
+  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
+  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
+  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
+    float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
+    if (i < size) {
+      result[i] = x;
+    }
+  }
+}
 
-GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
-GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
-GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
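Each transform above is inverse-CDF sampling: a uniform draw x is pushed through the quantile function of the target distribution. A CPU sketch checking two of the transforms against their analytic means, with rand() standing in for curand_uniform (illustrative only, not part of the patch):

    /* Inverse-CDF transforms as used by the kernels above, run on the
       CPU.  rand() stands in for curand_uniform; tolerances are loose. */
    #include <stdio.h>
    #include <stdlib.h>
    #include <math.h>

    int main(void)
    {
      const int n = 1000000;
      const double lambda = 2.0, p = 0.3;
      double exp_sum = 0.0, geom_sum = 0.0;

      srand(42);
      for (int i = 0; i < n; ++i) {
        double x = (rand() + 1.0) / ((double)RAND_MAX + 2.0); /* x in (0,1) */
        exp_sum  += -1. / lambda * log(1 - x);          /* exponential(lambda) */
        geom_sum += floor(log(1 - x) / log(p)) + 1;     /* geometric; p is the
                                                           failure probability,
                                                           matching log(p) above */
      }
      printf("exponential mean %.3f (expect %.3f)\n", exp_sum / n, 1 / lambda);
      printf("geometric   mean %.3f (expect %.3f)\n", geom_sum / n, 1 / (1 - p));
      return 0;
    }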
+#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
+THC_API void THCudaTensor_uniform(THCState* state, THCudaTensor *self_, double a, double b)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
 
-#include "generic/THCTensorRandom.cu"
-#include "THCGenerateAllTypes.h"
+  generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, a, b);
 
-#undef GENERATE_KERNEL1
-#undef GENERATE_KERNEL2
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_bernoulli(THCState* state, THCudaTensor *self_, double p)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, p);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_normal(THCState* state, THCudaTensor *self_, double mean, double stdv)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, mean, stdv);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_logNormal(THCState* state, THCudaTensor *self_, double mean, double stdv)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_log_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, mean, stdv);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_geometric(THCState* state, THCudaTensor *self_, double p)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, p);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_exponential(THCState* state, THCudaTensor *self_, double lambda)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, lambda);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
+
+THC_API void THCudaTensor_cauchy(THCState* state, THCudaTensor *self_, double median, double sigma)
+{
+  THAssert(THCudaTensor_checkGPU(state, 1, self_));
+  Generator* gen = THCRandom_getGenerator(state);
+
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  ptrdiff_t size = THCudaTensor_nElement(state, self);
+  float *data = THCudaTensor_data(state, self);
+
+  generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+      gen->gen_states, size, data, median, sigma);
+
+  THCudaTensor_freeCopyTo(state, self, self_);
+};
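Every generator kernel iterates to rounded_size rather than size: each thread of a block then executes the same trip count, so all of them make the same number of curand calls (the per-block MTGP32 state must be advanced by the whole block in lockstep), while the `i < size` guard keeps out-of-bounds threads from writing. A CPU check of that trip-count invariant, mirroring the kernel's loop bounds (illustrative; BLOCK_SIZE and MAX_NUM_BLOCKS as in this file):

    /* The grid-stride loops above run to rounded_size, not size, so all
       threads of a block make the same number of RNG calls.  CPU check. */
    #include <assert.h>
    #include <stdio.h>

    #define MAX_NUM_BLOCKS 64
    #define BLOCK_SIZE 256

    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    static int trips(int tid, int size) {    /* iterations for one thread */
      int rounded = ceil_div(size, BLOCK_SIZE) * BLOCK_SIZE;
      int n = 0;
      for (int i = tid; i < rounded; i += BLOCK_SIZE * MAX_NUM_BLOCKS)
        ++n;
      return n;
    }

    int main(void)
    {
      int size = 100000;
      for (int block = 0; block < MAX_NUM_BLOCKS; ++block) {
        int first = trips(block * BLOCK_SIZE, size);
        for (int t = 1; t < BLOCK_SIZE; ++t)  /* same count per block */
          assert(trips(block * BLOCK_SIZE + t, size) == first);
      }
      printf("every thread in a block makes the same number of RNG calls\n");
      return 0;
    }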
+
+__device__ int binarySearchForMultinomial(float* dist,
+                                          int size,
+                                          float val) {
+  int start = 0;
+  int end = size;
+
+  while (end - start > 0) {
+    int mid = start + (end - start) / 2;
+
+    float midVal = dist[mid];
+    if (midVal < val) {
+      start = mid + 1;
+    } else {
+      end = mid;
+    }
+  }
+
+  if (start == size) {
+    // No probability mass or precision problems; just return the
+    // first element
+    start = 0;
+  }
+
+  return start;
+}
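The search is a lower bound over the row's prefix sums: it returns the first CDF entry that is >= the sampled value. A CPU sketch with a worked example (0-based here; the kernels add TH_INDEX_BASE for Torch's 1-based indices):

    /* Lower-bound binary search over a CDF, mirroring
       binarySearchForMultinomial above (CPU version). */
    #include <stdio.h>

    static int lower_bound_cdf(const float *dist, int size, float val)
    {
      int start = 0, end = size;
      while (end - start > 0) {
        int mid = start + (end - start) / 2;
        if (dist[mid] < val) start = mid + 1;
        else                 end = mid;
      }
      return start == size ? 0 : start;  /* guard against precision fallout */
    }

    int main(void)
    {
      /* prefix sums of the normalized distribution {0.2, 0.5, 0.3} */
      const float cdf[] = { 0.2f, 0.7f, 1.0f };
      printf("%d\n", lower_bound_cdf(cdf, 3, 0.10f)); /* 0 -> category 1 */
      printf("%d\n", lower_bound_cdf(cdf, 3, 0.60f)); /* 1 -> category 2 */
      printf("%d\n", lower_bound_cdf(cdf, 3, 0.95f)); /* 2 -> category 3 */
      return 0;
    }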
+
+// Normalizes the L1 norm of every row to 1; used by multinomial
+__global__ void renormRowsL1(float* dist, long rows, long cols) {
+  extern __shared__ float smem[];
+
+  for (long row = blockIdx.x; row < rows; row += gridDim.x) {
+    float sum = 0.0f;
+    for (long col = threadIdx.x; col < cols; col += blockDim.x) {
+      sum += dist[row * cols + col];
+    }
+
+    sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
+    if (threadIdx.x == 0) {
+      smem[0] = sum;
+    }
+    __syncthreads();
+
+    sum = smem[0];
+    if (sum > 0.0f) {
+      for (long col = threadIdx.x; col < cols; col += blockDim.x) {
+        dist[row * cols + col] /= sum;
+      }
+    }
+  }
+}
+
+void THCudaTensor_renormRows(struct THCState* state,
+                             THCudaTensor* t) {
+  THAssert(THCudaTensor_nDimension(state, t) == 2);
+  long rows = THCudaTensor_size(state, t, 0);
+  long cols = THCudaTensor_size(state, t, 1);
+
+  cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
+  THAssert(props != NULL);
+
+  int numSM = props->multiProcessorCount;
+  int maxThreads = props->maxThreadsPerBlock;
+
+  dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
+  dim3 block(cols < maxThreads ? cols : maxThreads);
+
+  renormRowsL1
+    <<<grid, block, block.x * sizeof(float),
+       THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, t),
+                                           rows, cols);
+}
+
+__global__ void
+sampleMultinomialOnce(float* dest,
+                      long distributions,
+                      int categories,
+                      float* dist) {
+  extern __shared__ float smem[];
+
+  for (long curDist = blockIdx.x;
+       curDist < distributions; curDist += gridDim.x) {
+    // Each block handles one distribution
+    // First pass, find the total sum of the distribution
+    float sum = 0.0f;
+    for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
+      sum += dist[curDist * categories + cat];
+    }
+
+    // threadIdx.x == 0 has the sum value from this
+    sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
+
+    // Broadcast sum and sample value
+    if (threadIdx.x == 0) {
+      smem[0] = sum;
+      smem[1] = dest[curDist];
+    }
+    __syncthreads();
+
+    sum = smem[0];
+    float sample = smem[1];
+    __syncthreads();
+
+    if (sum == 0.0f || sample == 0.0f) {
+      // Choose the first element
+      if (threadIdx.x == 0) {
+        dest[curDist] = 1;
+      }
+
+      continue;
+    }
+
+    int chunks = THCCeilDiv(categories, (int) blockDim.x);
+    float prevHighProb = 0.0f;
+
+    for (int chunk = 0; chunk < chunks; ++chunk) {
+      // All threads in bounds load a value
+      int cat = chunk * blockDim.x + threadIdx.x;
+
+      float val =
+        cat < categories ? dist[curDist * categories + cat] / sum : 0.0f;
+      smem[threadIdx.x] = val;
+      __syncthreads();
+
+      // Perform an inclusive prefix sum of the shared memory contents
+      for (int offset = 1; offset < blockDim.x; offset *= 2) {
+        float val = 0.0f;
+
+        if (threadIdx.x >= offset) {
+          val = smem[threadIdx.x - offset] + smem[threadIdx.x];
+        }
+
+        __syncthreads();
+        if (threadIdx.x >= offset) {
+          smem[threadIdx.x] = val;
+        }
+        __syncthreads();
+      }
+
+      // Each thread will check to see if the sample falls in its
+      // bucket
+      float curBucket =
+        smem[threadIdx.x] + prevHighProb;
+      float prevBucket =
+        threadIdx.x == 0 ? prevHighProb : smem[threadIdx.x - 1] + prevHighProb;
+      bool inBucket =
+        (cat < categories) && (sample <= curBucket) && (sample > prevBucket);
+
+      if (inBucket) {
+        // We're done; we have the sample
+        // Torch indices are 1-based
+        // FIXME: broadcast exit flag?
+        dest[curDist] = cat + TH_INDEX_BASE;
+      }
+
+      // Store the previous scan's high value for future use
+      prevHighProb += smem[blockDim.x - 1];
+
+      __syncthreads();
+    }
+  }
+}
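Per distribution, the chunked block-wide scan above computes what this sequential reference does: normalize on the fly by the total sum, build an inclusive running prefix, and pick the first bucket whose prefix reaches the sample. A CPU reference (illustrative):

    /* CPU reference for sampleMultinomialOnce: inclusive running sum over
       dist/sum; return the 1-based index of the bucket containing `sample`. */
    #include <stdio.h>

    static int sample_once(const float *dist, int categories, float sample)
    {
      float sum = 0.0f;
      for (int cat = 0; cat < categories; ++cat) sum += dist[cat];
      if (sum == 0.0f || sample == 0.0f) return 1;  /* degenerate: first bucket */

      float prev = 0.0f;
      for (int cat = 0; cat < categories; ++cat) {
        float cur = prev + dist[cat] / sum;         /* inclusive prefix sum */
        if (sample <= cur && sample > prev) return cat + 1; /* 1-based */
        prev = cur;
      }
      return categories;  /* sample == 1.0 edge case */
    }

    int main(void)
    {
      const float dist[] = { 2.0f, 5.0f, 3.0f };    /* renormalizes to 0.2/0.5/0.3 */
      printf("%d %d %d\n",
             sample_once(dist, 3, 0.10f),           /* 1 */
             sample_once(dist, 3, 0.60f),           /* 2 */
             sample_once(dist, 3, 0.95f));          /* 3 */
      return 0;
    }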
+
+__global__ void
+sampleMultinomialWithReplacement(curandStateMtgp32* state,
+                                 int totalSamples,
+                                 float* dest,
+                                 long distributions,
+                                 int categories,
+                                 float* normDistPrefixSum) {
+  // At the moment, each warp computes one sample value in the binary
+  // search due to divergence. It seems possible to compute multiple
+  // values and limit divergence though later on. However, no matter
+  // what, all block threads must participate in the curand_uniform
+  // call to update the generator state.
+
+  // The block determines the distribution for which we generate a point
+  for (long curDist = blockIdx.x;
+       curDist < distributions;
+       curDist += gridDim.x) {
+    for (int sampleBase = 0;
+         sampleBase < totalSamples; sampleBase += blockDim.y) {
+      // The warp determines the sample
+      int sample = sampleBase + threadIdx.y;
+
+      // All threads participate in this
+      float r = curand_uniform(&state[blockIdx.x]);
+
+      if (threadIdx.x == 0 && sample < totalSamples) {
+        // Find the bucket that a uniform sample lies in
+        int choice = binarySearchForMultinomial(
+          normDistPrefixSum + curDist * categories,
+          categories,
+          r);
+
+        // Torch indices are 1-based
+        dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
+      }
+    }
+  }
+}
+
+__global__ void
+sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
+                                    int totalSamples,
+                                    int sample,
+                                    float* dest,
+                                    long distributions,
+                                    int categories,
+                                    float* origDist,
+                                    float* normDistPrefixSum) {
+  // At the moment, each warp computes one sample value in the binary
+  // search due to divergence. It seems possible to compute multiple
+  // values and limit divergence though later on. However, no matter
+  // what, all block threads must participate in the curand_uniform
+  // call to update the generator state.
+
+  // The block and warp determines the distribution for which we
+  // generate a point
+  for (long curDistBase = blockIdx.x * blockDim.y;
+       curDistBase < distributions;
+       curDistBase += gridDim.x * blockDim.y) {
+    // The warp determines the distribution
+    long curDist = curDistBase + threadIdx.y;
+
+    // All threads must participate in this
+    float r = curand_uniform(&state[blockIdx.x]);
+
+    if (threadIdx.x == 0 && curDist < distributions) {
+      // Find the bucket that a uniform sample lies in
+      int choice = binarySearchForMultinomial(
+        normDistPrefixSum + curDist * categories,
+        categories,
+        r);
+
+      // Torch indices are 1-based
+      dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
+
+      // Without replacement, so update the original probability so it
+      // is not considered a second time
+      origDist[curDist * categories + choice] = 0.0f;
+    }
+  }
+}
+
+THC_API void THCudaTensor_multinomial(struct THCState *state,
+                                      THCudaTensor *self,
+                                      THCudaTensor *prob_dist,
+                                      int n_sample,
+                                      int with_replacement)
+{
+  THAssert(THCudaTensor_checkGPU(state, 2, self, prob_dist));
+  Generator* gen = THCRandom_getGenerator(state);
+
+  int inputSize = THCudaTensor_nDimension(state, prob_dist);
+  THArgCheck(inputSize > 0 && inputSize <= 2, 2,
+             "prob_dist must be 1 or 2 dim");
+
+  // Categories are in the innermost dimension
+  long numDist =
+    inputSize == 1 ? 1 : THCudaTensor_size(state, prob_dist, 0);
+  long numCategoriesLong =
+    inputSize == 1 ? THCudaTensor_size(state, prob_dist, 0) :
+    THCudaTensor_size(state, prob_dist, 1);
+
+  // Since the index tensor is float, numCategories cannot exceed max
+  // float integer precision
+  THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2,
+             "number of categories cannot exceed 2^24");
+  int numCategories = (int) numCategoriesLong;
+
+  THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples");
+
+  if (!with_replacement) {
+    THArgCheck(n_sample <= numCategories, 2,
+               "cannot sample n_sample > prob_dist:size(1) samples without "
+               "replacement");
+  }
+
+  // It is possible that prob_dist is non-contiguous
+  THCudaTensor* probDistContig =
+    THCudaTensor_newContiguous(state, prob_dist);
+
+  // Restructure data for 2d
+  if (inputSize == 1) {
+    THCudaTensor_resize2d(state, probDistContig, 1, numCategories);
+  }
+
+  THCudaTensor_resize2d(state, self, numDist, n_sample);
+
+  if (n_sample == 1) {
+    // Optimized allocation-free implementation
+
+    // To exploit greater parallelism for the sampling, generate the
+    // Uniform random samples in a separate kernel launch, into the
+    // result memory. The device RNG is thread-limited
+    THCudaTensor_uniform(state, self, 0.0, 1.0);
+
+    cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
+    THAssert(props != NULL);
+
+    int numSM = props->multiProcessorCount;
+    int maxThreads = props->maxThreadsPerBlock;
+
+    dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
+    dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
+
+    sampleMultinomialOnce
+      <<<grid, block, block.x * sizeof(float),
+         THCState_getCurrentStream(state)>>>(
+      THCudaTensor_data(state, self),
+      numDist,
+      numCategories,
+      THCudaTensor_data(state, probDistContig));
+  } else {
+    // Generic, slow implementation with memory allocations
+
+    // For sampling without replacement, we modify the distribution
+    // for subsequent samples in this space
+    THCudaTensor* origDist = THCudaTensor_new(state);
+    THCudaTensor_resizeAs(state, origDist, probDistContig);
+    THCudaTensor_copy(state, origDist, probDistContig);
+
+    THCudaTensor* normDist = THCudaTensor_new(state);
+    THCudaTensor_resizeAs(state, normDist, probDistContig);
+
+    THCudaTensor* prefixSum = THCudaTensor_new(state);
+
+    // Renorm along rows
+    THCudaTensor_copy(state, normDist, origDist);
+    THCudaTensor_renormRows(state, normDist);
+
+    // Prefix sum along rows
+    THCudaTensor_cumsum(state, prefixSum, normDist, 1);
+
+    if (with_replacement) {
+      // Sample with replacement
+
+      // Binary search is warp divergent (so effectively we're running
+      // with just a single thread), but for better utilization,
+      // we need each block to have at least 4 warps.
+      dim3 block(32, 4);
+
+      // Each warp in a block will generate a sample from one
+      // distribution concurrently.
+      dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
+
+      sampleMultinomialWithReplacement
+        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+        gen->gen_states,
+        n_sample,
+        THCudaTensor_data(state, self),
+        numDist, numCategories,
+        THCudaTensor_data(state, prefixSum));
+    } else {
+      // Sample without replacement
+
+      // Binary search is warp divergent (so effectively we're running
+      // with just a single thread), but for better utilization,
+      // we need each block to have at least 4 warps.
+      dim3 block(32, 4);
+
+      // Each warp in a block will generate a sample from a different
+      // distribution concurrently.
+      ptrdiff_t numBlocks = THCCeilDiv(numDist, 4L);
+      dim3 grid(numBlocks < MAX_NUM_BLOCKS ? numBlocks : MAX_NUM_BLOCKS);
+
+      for (int sample = 0; sample < n_sample; ++sample) {
+        if (sample > 0) {
+          // Update probabilities
+          // Renorm along rows
+          THCudaTensor_copy(state, normDist, origDist);
+          THCudaTensor_renormRows(state, normDist);
+
+          // Prefix sum along rows
+          THCudaTensor_cumsum(state, prefixSum, normDist, 1);
+        }
+
+        // The kernel can only draw one sample before we have to
+        // recalculate our distribution
+        sampleMultinomialWithoutReplacement
+          <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+          gen->gen_states,
+          n_sample,
+          sample,
+          THCudaTensor_data(state, self),
+          numDist, numCategories,
+          THCudaTensor_data(state, origDist),
+          THCudaTensor_data(state, prefixSum));
+      }
+    }
+
+    THCudaTensor_free(state, prefixSum);
+    THCudaTensor_free(state, normDist);
+    THCudaTensor_free(state, origDist);
+  }
+
+  // Revert data restructuring based on input sizes
+  if (inputSize == 1) {
+    THCudaTensor_resize1d(state, self, n_sample);
+
+    // Unfortunately, if prob_dist is contiguous already,
+    // newContiguous is not a private copy, so we have to restructure
+    // this too, so as to not affect prob_dist
+    THCudaTensor_resize1d(state, probDistContig, numCategories);
+  }
+
+  THCudaTensor_free(state, probDistContig);
+}
+#undef NUM_BLOCKS
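A caller's-eye view of the function above, assuming an initialized THCState and a probability tensor already on the device (setup, data upload and error checks omitted; this is a sketch, not part of the patch):

    /* Sketch: draw 3 samples per row, with replacement, from a 2x4
       matrix of non-negative row weights. */
    #include "THC/THC.h"

    void multinomial_example(THCState *state, THCudaTensor *prob /* 2x4 */)
    {
      THCudaTensor *idx = THCudaTensor_new(state);
      /* idx is resized to 2x3; entries are 1-based category indices
         stored as floats, hence the FLOAT32_MAX_CONSECUTIVE_INT check. */
      THCudaTensor_multinomial(state, idx, prob, 3, 1 /* with replacement */);
      THCudaTensor_free(state, idx);
    }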
diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh
deleted file mode 100644
index 003e960..0000000
--- a/lib/THC/THCTensorRandom.cuh
+++ /dev/null
@@ -1,278 +0,0 @@
-#ifndef THC_TENSOR_RANDOM_CUH
-#define THC_TENSOR_RANDOM_CUH
-
-#include "THCNumerics.cuh"
-#include "THCReduceApplyUtils.cuh"
-#include "THCTensorMathReduce.cuh"
-
-#include <curand_kernel.h>
-
-#define MAX_NUM_BLOCKS 64
-#define BLOCK_SIZE 256
-/* Separate kernel because curand_log_normal gets extra parameters. */
-
-template <typename T>
-__global__ void generateLogNormal(curandStateMtgp32 *state, int size, T *result, double mean, double stddev)
-{
-  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
-  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
-  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
-    float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
-    if (i < size) {
-      result[i] = ScalarConvert<float, T>::to(x);
-    }
-  }
-}
-
-template <>
-__global__ void generateLogNormal(curandStateMtgp32 *state, int size, double *result, double mean, double stddev)
-{
-  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
-  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
-  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
-    double x = curand_log_normal_double(&state[blockIdx.x], mean, stddev);
-    if (i < size) {
-      result[i] = x;
-    }
-  }
-}
-
-#undef MAX_NUM_BLOCKS
-#undef BLOCK_SIZE
-
-// Normalizes the L1 norm of every row to 1; used by multinomial
-template <typename T>
-__global__ void renormRowsL1(T* dist, long rows, long cols) {
-  extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
-  T *smem = reinterpret_cast<T *>(my_smem);
-
-  for (long row = blockIdx.x; row < rows; row += gridDim.x) {
-    T sum = ScalarConvert<int, T>::to(0);
-    for (long col = threadIdx.x; col < cols; col += blockDim.x) {
-      sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
-    }
-
-    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
-    if (threadIdx.x == 0) {
-      smem[0] = sum;
-    }
-    __syncthreads();
-
-    sum = smem[0];
-    if (THCNumerics<T>::gt(sum, ScalarConvert<int, T>::to(0))) {
-      for (long col = threadIdx.x; col < cols; col += blockDim.x) {
-        dist[row * cols + col] = THCNumerics<T>::div(dist[row * cols + col], sum);
-      }
-    }
-  }
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialOnce(T* dest,
-                      long distributions,
-                      int categories,
-                      T* dist) {
-  extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
-  T *smem = reinterpret_cast<T *>(my_smem);
-  T zero = ScalarConvert<int, T>::to(0);
-
-  for (long curDist = blockIdx.x;
-       curDist < distributions; curDist += gridDim.x) {
-    // Each block handles one distribution
-    // First pass, find the total sum of the distribution
-    T sum = zero;
-    for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
-      sum = THCNumerics<T>::add(sum, dist[curDist * categories + cat]);
-    }
-
-    // threadIdx.x == 0 has the sum value from this
-    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);
-
-    // Broadcast sum and sample value
-    if (threadIdx.x == 0) {
-      smem[0] = sum;
-      smem[1] = dest[curDist];
-    }
-    __syncthreads();
-
-    sum = smem[0];
-    T sample = smem[1];
-    __syncthreads();
-
-    if (THCNumerics<T>::eq(sum, zero) || THCNumerics<T>::eq(sample, zero)) {
-      // Choose the first element
-      if (threadIdx.x == 0) {
-        dest[curDist] = ScalarConvert<int, T>::to(1);
-      }
-
-      continue;
-    }
-
-    int chunks = THCCeilDiv(categories, (int) blockDim.x);
-    T prevHighProb = zero;
-
-    for (int chunk = 0; chunk < chunks; ++chunk) {
-      // All threads in bounds load a value
-      int cat = chunk * blockDim.x + threadIdx.x;
-
-      T val =
-        cat < categories ? THCNumerics<T>::div(dist[curDist * categories + cat], sum) :
-        zero;
-
-      smem[threadIdx.x] = val;
-      __syncthreads();
-
-      // Perform an inclusive prefix sum of the shared memory contents
-      for (int offset = 1; offset < blockDim.x; offset *= 2) {
-        T val = zero;
-
-        if (threadIdx.x >= offset) {
-          val = THCNumerics<T>::add(smem[threadIdx.x - offset], smem[threadIdx.x]);
-        }
-
-        __syncthreads();
-        if (threadIdx.x >= offset) {
-          smem[threadIdx.x] = val;
-        }
-        __syncthreads();
-      }
-
-      // Each thread will check to see if the sample falls in its
-      // bucket
-      T curBucket = THCNumerics<T>::add(smem[threadIdx.x], prevHighProb);
-      T prevBucket =
-        threadIdx.x == 0 ? prevHighProb :
-        THCNumerics<T>::add(smem[threadIdx.x - 1], prevHighProb);
-      bool inBucket =
-        (cat < categories) &&
-        (!THCNumerics<T>::gt(sample, curBucket)) &&
-        (THCNumerics<T>::gt(sample, prevBucket));
-
-      if (inBucket) {
-        // We're done; we have the sample
-        // Torch indices are 1-based
-        // FIXME: broadcast exit flag?
-        dest[curDist] = ScalarConvert<int, T>::to(cat + TH_INDEX_BASE);
-      }
-
-      // Store the previous scan's high value for future use
-      prevHighProb = THCNumerics<T>::add(prevHighProb, smem[blockDim.x - 1]);
-
-      __syncthreads();
-    }
-  }
-}
-
-template <typename T>
-__device__ int binarySearchForMultinomial(T* dist,
-                                          int size,
-                                          T val) {
-  int start = 0;
-  int end = size;
-
-  while (end - start > 0) {
-    int mid = start + (end - start) / 2;
-
-    T midVal = dist[mid];
-    if (THCNumerics<T>::lt(midVal, val)) {
-      start = mid + 1;
-    } else {
-      end = mid;
-    }
-  }
-
-  if (start == size) {
-    // No probability mass or precision problems; just return the
-    // first element
-    start = 0;
-  }
-
-  return start;
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialWithReplacement(curandStateMtgp32* state,
-                                 int totalSamples,
-                                 T* dest,
-                                 long distributions,
-                                 int categories,
-                                 T* normDistPrefixSum) {
-  // At the moment, each warp computes one sample value in the binary
-  // search due to divergence. It seems possible to compute multiple
-  // values and limit divergence though later on. However, no matter
-  // what, all block threads must participate in the curand_uniform
-  // call to update the generator state.
-
-  // The block determines the distribution for which we generate a point
-  for (long curDist = blockIdx.x;
-       curDist < distributions;
-       curDist += gridDim.x) {
-    for (int sampleBase = 0;
-         sampleBase < totalSamples; sampleBase += blockDim.y) {
-      // The warp determines the sample
-      int sample = sampleBase + threadIdx.y;
-
-      // All threads participate in this
-      T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
-
-      if (threadIdx.x == 0 && sample < totalSamples) {
-        // Find the bucket that a uniform sample lies in
-        int choice = binarySearchForMultinomial(
-          normDistPrefixSum + curDist * categories,
-          categories,
-          r);
-
-        // Torch indices are 1-based
-        dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
-      }
-    }
-  }
-}
-
-template <typename T>
-__global__ void
-sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
-                                    int totalSamples,
-                                    int sample,
-                                    T* dest,
-                                    long distributions,
-                                    int categories,
-                                    T* origDist,
-                                    T* normDistPrefixSum) {
-  // At the moment, each warp computes one sample value in the binary
-  // search due to divergence. It seems possible to compute multiple
-  // values and limit divergence though later on. However, no matter
-  // what, all block threads must participate in the curand_uniform
-  // call to update the generator state.
-
-  // The block and warp determines the distribution for which we
-  // generate a point
-  for (long curDistBase = blockIdx.x * blockDim.y;
-       curDistBase < distributions;
-       curDistBase += gridDim.x * blockDim.y) {
-    // The warp determines the distribution
-    long curDist = curDistBase + threadIdx.y;
-
-    // All threads must participate in this
-    T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
-
-    if (threadIdx.x == 0 && curDist < distributions) {
-      // Find the bucket that a uniform sample lies in
-      int choice = binarySearchForMultinomial(
-        normDistPrefixSum + curDist * categories,
-        categories,
-        r);
-
-      // Torch indices are 1-based
-      dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
-
-      // Without replacement, so update the original probability so it
-      // is not considered a second time
-      origDist[curDist * categories + choice] = ScalarConvert<int, T>::to(0);
-    }
-  }
-}
-
-#endif // THC_TENSOR_RANDOM_CUH
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index 12128cd..93c0d77 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -3,9 +3,6 @@
 
 #include "THCTensor.h"
 
-#include "generic/THCTensorRandom.h"
-#include "THCGenerateAllTypes.h"
-
 /* Generator */
 typedef struct _Generator {
   struct curandStateMtgp32* gen_states;
@@ -31,6 +28,15 @@ THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_s
 THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
 THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
 THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);
+THC_API void THCudaTensor_geometric(struct THCState *state, THCudaTensor *self, double p);
+THC_API void THCudaTensor_bernoulli(struct THCState *state, THCudaTensor *self, double p);
+THC_API void THCudaTensor_uniform(struct THCState *state, THCudaTensor *self, double a, double b);
+THC_API void THCudaTensor_normal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
+THC_API void THCudaTensor_exponential(struct THCState *state, THCudaTensor *self, double lambda);
+THC_API void THCudaTensor_cauchy(struct THCState *state, THCudaTensor *self, double median, double sigma);
+THC_API void THCudaTensor_logNormal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
+
+THC_API void THCudaTensor_multinomial(struct THCState *state, THCudaTensor *self, THCudaTensor *prob_dist, int n_sample, int with_replacement);
 
 THC_API struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state);
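The header now declares the float samplers next to the seeding entry points, which is what ties reproducible draws together: reseeding the per-device MTGP32 generator replays the same sequence. A sketch, assuming an initialized THCState (both functions are declared in this header):

    /* Sketch: reseed the per-device generator, then sample twice.
       The two uniform fills produce identical draws. */
    #include "THC/THC.h"

    void reproducible_uniform(THCState *state, THCudaTensor *t)
    {
      THCRandom_manualSeed(state, 123);          /* current device only */
      THCudaTensor_uniform(state, t, 0.0, 1.0);

      THCRandom_manualSeed(state, 123);
      THCudaTensor_uniform(state, t, 0.0, 1.0);  /* same values as above */
    }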
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
deleted file mode 100644
index 50aa2e8..0000000
--- a/lib/THC/generic/THCTensorRandom.cu
+++ /dev/null
@@ -1,328 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/THCTensorRandom.cu"
-#else
-
-#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, double b)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, a, b);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, mean, stdv);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, double stdv)
-{
-
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generateLogNormal<real><<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, mean, stdv);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, lambda);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median, double sigma)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, median, sigma);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-void THCTensor_(renormRows)(struct THCState* state,
-                            THCTensor* t) {
-  THAssert(THCTensor_(nDimension)(state, t) == 2);
-  long rows = THCTensor_(size)(state, t, 0);
-  long cols = THCTensor_(size)(state, t, 1);
-
-  cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
-  THAssert(props != NULL);
-
-  int numSM = props->multiProcessorCount;
-  int maxThreads = props->maxThreadsPerBlock;
-
-  dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
-  dim3 block(cols < maxThreads ? cols : maxThreads);
-
-  renormRowsL1<real>
-    <<<grid, block, block.x * sizeof(real),
-       THCState_getCurrentStream(state)>>>(THCTensor_(data)(state, t),
-                                           rows, cols);
-}
-
-THC_API void THCTensor_(multinomial)(struct THCState *state,
-                                     THCTensor *self,
-                                     THCTensor *prob_dist,
-                                     int n_sample,
-                                     int with_replacement)
-{
-  THAssert(THCTensor_(checkGPU)(state, 2, self, prob_dist));
-  Generator* gen = THCRandom_getGenerator(state);
-
-  int inputSize = THCTensor_(nDimension)(state, prob_dist);
-  THArgCheck(inputSize > 0 && inputSize <= 2, 2,
-             "prob_dist must be 1 or 2 dim");
-
-  // Categories are in the innermost dimension
-  long numDist =
-    inputSize == 1 ? 1 : THCTensor_(size)(state, prob_dist, 0);
-  long numCategoriesLong =
-    inputSize == 1 ? THCTensor_(size)(state, prob_dist, 0) :
-    THCTensor_(size)(state, prob_dist, 1);
-
-  // Since the index tensor is float, numCategories cannot exceed max
-  // float integer precision
-  THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2,
-             "number of categories cannot exceed 2^24");
-  int numCategories = (int) numCategoriesLong;
-
-  THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples");
-
-  if (!with_replacement) {
-    THArgCheck(n_sample <= numCategories, 2,
-               "cannot sample n_sample > prob_dist:size(1) samples without "
-               "replacement");
-  }
-
-  // It is possible that prob_dist is non-contiguous
-  THCTensor* probDistContig =
-    THCTensor_(newContiguous)(state, prob_dist);
-
-  // Restructure data for 2d
-  if (inputSize == 1) {
-    THCTensor_(resize2d)(state, probDistContig, 1, numCategories);
-  }
-
-  THCTensor_(resize2d)(state, self, numDist, n_sample);
-
-  if (n_sample == 1) {
-    // Optimized allocation-free implementation
-
-    // To exploit greater parallelism for the sampling, generate the
-    // Uniform random samples in a separate kernel launch, into the
-    // result memory. The device RNG is thread-limited
-    THCTensor_(uniform)(state, self, 0.0, 1.0);
-
-    cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
-    THAssert(props != NULL);
-
-    int numSM = props->multiProcessorCount;
-    int maxThreads = props->maxThreadsPerBlock;
-
-    dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
-    dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
-
-    sampleMultinomialOnce<real>
-      <<<grid, block, block.x * sizeof(real),
-         THCState_getCurrentStream(state)>>>(
-      THCTensor_(data)(state, self),
-      numDist,
-      numCategories,
-      THCTensor_(data)(state, probDistContig));
-  } else {
-    // Generic, slow implementation with memory allocations
-
-    // For sampling without replacement, we modify the distribution
-    // for subsequent samples in this space
-    THCTensor* origDist = THCTensor_(new)(state);
-    THCTensor_(resizeAs)(state, origDist, probDistContig);
-    THCTensor_(copy)(state, origDist, probDistContig);
-
-    THCTensor* normDist = THCTensor_(new)(state);
-    THCTensor_(resizeAs)(state, normDist, probDistContig);
-
-    THCTensor* prefixSum = THCTensor_(new)(state);
-
-    // Renorm along rows
-    THCTensor_(copy)(state, normDist, origDist);
-    THCTensor_(renormRows)(state, normDist);
-
-    // Prefix sum along rows
-    THCTensor_(cumsum)(state, prefixSum, normDist, 1);
-
-    if (with_replacement) {
-      // Sample with replacement
-
-      // Binary search is warp divergent (so effectively we're running
-      // with just a single thread), but for better utilization,
-      // we need each block to have at least 4 warps.
-      dim3 block(32, 4);
-
-      // Each warp in a block will generate a sample from one
-      // distribution concurrently.
-      dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
-
-      sampleMultinomialWithReplacement
-        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-        gen->gen_states,
-        n_sample,
-        THCTensor_(data)(state, self),
-        numDist, numCategories,
-        THCTensor_(data)(state, prefixSum));
-    } else {
-      // Sample without replacement
-
-      // Binary search is warp divergent (so effectively we're running
-      // with just a single thread), but for better utilization,
-      // we need each block to have at least 4 warps.
-      dim3 block(32, 4);
-
-      // Each warp in a block will generate a sample from a different
-      // distribution concurrently.
-      ptrdiff_t numBlocks = THCCeilDiv(numDist, 4L);
-      dim3 grid(numBlocks < MAX_NUM_BLOCKS ? numBlocks : MAX_NUM_BLOCKS);
-
-      for (int sample = 0; sample < n_sample; ++sample) {
-        if (sample > 0) {
-          // Update probabilities
-          // Renorm along rows
-          THCTensor_(copy)(state, normDist, origDist);
-          THCTensor_(renormRows)(state, normDist);
-
-          // Prefix sum along rows
-          THCTensor_(cumsum)(state, prefixSum, normDist, 1);
-        }
-
-        // The kernel can only draw one sample before we have to
-        // recalculate our distribution
-        sampleMultinomialWithoutReplacement
-          <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-          gen->gen_states,
-          n_sample,
-          sample,
-          THCTensor_(data)(state, self),
-          numDist, numCategories,
-          THCTensor_(data)(state, origDist),
-          THCTensor_(data)(state, prefixSum));
-      }
-    }
-
-    THCTensor_(free)(state, prefixSum);
-    THCTensor_(free)(state, normDist);
-    THCTensor_(free)(state, origDist);
-  }
-
-  // Revert data restructuring based on input sizes
-  if (inputSize == 1) {
-    THCTensor_(resize1d)(state, self, n_sample);
-
-    // Unfortunately, if prob_dist is contiguous already,
-    // newContiguous is not a private copy, so we have to restructure
-    // this too, so as to not affect prob_dist
-    THCTensor_(resize1d)(state, probDistContig, numCategories);
-  }
-
-  THCTensor_(free)(state, probDistContig);
-}
-
-THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, r_));
-  THCTensor_(resize)(state, r_, size, NULL);
-  THCTensor_(uniform)(state, r_, 0, 1);
-}
-
-void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, r_));
-  THCTensor_(resize)(state, r_, size, NULL);
-  THCTensor_(normal)(state, r_, 0, 1);
-}
-
-#endif
-
-#if defined(THC_REAL_IS_DOUBLE)
-GENERATE_KERNEL1(generate_bernoulli, double, double p, double, curand_uniform_double, x <= p)
-#else
-GENERATE_KERNEL1(generate_bernoulli, real, double p, float, curand_uniform, (ScalarConvert<bool, real>::to(x <= p)))
-#endif
-
-THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-    gen->gen_states, size, data, p);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
-#if defined(THC_REAL_IS_DOUBLE)
-
-GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, floor((log(1-x) / log(p)) + 1))
-#else
-GENERATE_KERNEL1(generate_geometric, real, double p, float, curand_uniform, (ScalarConvert<float, real>::to(floorf((log(1-x) / log(p)) + 1))))
-#endif
-
-THC_API void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)
-{
-  THAssert(THCTensor_(checkGPU)(state, 1, self_));
-  Generator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  ptrdiff_t size = THCTensor_(nElement)(state, self);
-  real *data = THCTensor_(data)(state, self);
-
-  generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-    gen->gen_states, size, data, p);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-#undef NUM_BLOCKS
-
-#endif
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
deleted file mode 100644
index a2896c3..0000000
--- a/lib/THC/generic/THCTensorRandom.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/THCTensorRandom.h"
-#else
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
-THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size);
-THC_API void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size);
-THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
-THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
-THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
-THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);
-THC_API void THCTensor_(multinomial)(struct THCState *state, THCTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
-
-#endif
-
-THC_API void THCTensor_(bernoulli)(struct THCState *state, THCTensor *self, double p);
-THC_API void THCTensor_(geometric)(struct THCState *state, THCTensor *self, double p);
-
-#endif
diff --git a/test/test.lua b/test/test.lua
index 724d5ff..00cfa66 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2526,11 +2526,8 @@ function test.uniform()
    local max = min + torch.uniform()
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(typename)
-      x:uniform(min, max)
-      checkIfUniformlyDistributed(x, min, max)
-   end
+   t:uniform(min, max)
+   checkIfUniformlyDistributed(t, min, max)
    checkMultiDevice(t, 'uniform', min, max)
 end
 
@@ -2540,17 +2537,13 @@ function test.bernoulli()
    local p = torch.uniform()
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(typenames) do
-      local x = t:type(typename)
-      x:bernoulli(p)
-      local mean = x:sum() / (sz1 * sz2)
-      tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p")
-      local f = t:float()
-      tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
-                            torch.FloatTensor(sz1, sz2):fill(1),
-                            1e-6,
-                            "each value must be either 0 or 1")
-   end
+   t:bernoulli(p)
+   tester:assertalmosteq(t:mean(), p, 0.1, "mean is not equal to p")
+   local f = t:float()
+   tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
+                         torch.FloatTensor(sz1, sz2):fill(1),
+                         1e-6,
+                         "each value must be either 0 or 1")
    checkMultiDevice(t, 'bernoulli', p)
 end
 
@@ -2561,13 +2554,9 @@ function test.normal()
    local tolerance = 0.01
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(t2cpu[typename])
-      x:normal(mean, std)
-      tester:assertalmosteq(x:mean(), mean, tolerance, "mean is wrong")
-      tester:assertalmosteq(x:std(), std, tolerance, "standard deviation is wrong")
-   end
-
+   t:normal(mean, std)
+   tester:assertalmosteq(t:mean(), mean, tolerance, "mean is wrong")
+   tester:assertalmosteq(t:std(), std, tolerance, "standard deviation is wrong")
    checkMultiDevice(t, 'normal', mean, std)
 end
 
@@ -2578,13 +2567,10 @@ function test.logNormal()
    local tolerance = 0.01
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(typename)
-      x:logNormal(mean, std)
-      local logt = x:log()
-      tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
-      tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
-   end
+   t:logNormal(mean, std)
+   local logt = t:log()
+   tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
+   tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
    checkMultiDevice(t, 'logNormal', mean, std)
 end
 
@@ -2594,14 +2580,10 @@ function test.geometric()
    local p = torch.uniform()
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(typename)
-      x:geometric(p)
-
-      local u = torch.FloatTensor(sz1, sz2):fill(1) -
-                ((x:float() - 1) * math.log(p)):exp()
-      checkIfUniformlyDistributed(u, 0, 1)
-   end
+   t:geometric(p)
+   local u = torch.FloatTensor(sz1, sz2):fill(1) -
+             ((t:float() - 1) * math.log(p)):exp()
+   checkIfUniformlyDistributed(u, 0, 1)
    checkMultiDevice(t, 'geometric', p)
 end
 
@@ -2611,13 +2593,10 @@ function test.exponential()
    local lambda = torch.uniform()
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(t2cpu[typename])
-      x:exponential(lambda)
-      local u = torch.FloatTensor(sz1, sz2):fill(1) -
-                (x:float() * -lambda):exp()
-      checkIfUniformlyDistributed(u, 0, 1)
-   end
+   t:exponential(lambda)
+   local u = torch.FloatTensor(sz1, sz2):fill(1) -
+             (t:float() * -lambda):exp()
+   checkIfUniformlyDistributed(u, 0, 1)
    checkMultiDevice(t, 'exponential', lambda)
 end
 
@@ -2627,12 +2606,9 @@ function test.cauchy()
    local median, sigma = torch.uniform(), torch.uniform()
    local t = torch.CudaTensor(sz1, sz2)
 
-   for _, typename in ipairs(float_typenames) do
-      local x = t:type(typename)
-      x:cauchy(median, sigma)
-      local u = ((x:float() - median) / sigma):atan() / math.pi + 0.5
-      checkIfUniformlyDistributed(u, 0, 1)
-   end
+   t:cauchy(median, sigma)
+   local u = ((t:float() - median) / sigma):atan() / math.pi + 0.5
+   checkIfUniformlyDistributed(u, 0, 1)
    checkMultiDevice(t, 'cauchy', median, sigma)
 end
 
@@ -2699,19 +2675,16 @@ function test.multinomial_with_replacement()
       local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
       prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
      local n_sample = torch.random(n_col - 1)
-      for _, typename in ipairs(float_typenames) do
-         local pd = prob_dist:type(typename)
-         local sample_indices = torch.multinomial(pd, n_sample, true)
-         tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
-         tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
-         for i = 1, n_row do
-            for j = 1, n_sample do
-               local val = sample_indices[{i,j}]
-               tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
-                             "sampled an invalid index: " .. val)
-            end
-         end
+      local sample_indices = torch.multinomial(prob_dist, n_sample, true)
+      tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+      tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+      for i = 1, n_row do
+         for j = 1, n_sample do
+            local val = sample_indices[{i,j}]
+            tester:assert(val == math.floor(val) and val >= 1 and val < n_col,
+                          "sampled an invalid index: " .. val)
+         end
       end
    end
 end
@@ -2725,27 +2698,24 @@ function test.multinomial_without_replacement()
       local prob_dist = torch.CudaTensor(n_row, n_col):uniform()
      prob_dist:select(2, n_col):fill(0) --index n_col shouldn't be sampled
      local n_sample = torch.random(n_col - 1)
-      for _, typename in ipairs(float_typenames) do
-         local pd = prob_dist:type(typename)
-         local sample_indices = torch.multinomial(pd, n_sample, false)
-         tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
-         tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
-
-         sample_indices = sample_indices:float()
-
-         for i = 1, n_row do
-            local row_samples = {}
-            for j = 1, n_sample do
-               local sample_idx = sample_indices[{i,j}]
-               tester:assert(
-                  sample_idx ~= n_col, "sampled an index with zero probability"
-               )
-               tester:assert(
-                  not row_samples[sample_idx], "sampled an index twice"
-               )
-               row_samples[sample_idx] = true
-            end
-         end
+      local sample_indices = torch.multinomial(prob_dist, n_sample, false)
+      tester:assert(sample_indices:dim() == 2, "wrong sample_indices dim")
+      tester:assert(sample_indices:size(2) == n_sample, "wrong number of samples")
+
+      sample_indices = sample_indices:float()
+
+      for i = 1, n_row do
+         local row_samples = {}
+         for j = 1, n_sample do
+            local sample_idx = sample_indices[{i,j}]
+            tester:assert(
+               sample_idx ~= n_col, "sampled an index with zero probability"
+            )
+            tester:assert(
+               not row_samples[sample_idx], "sampled an index twice"
+            )
+            row_samples[sample_idx] = true
+         end
       end
    end
 end
@@ -2761,21 +2731,17 @@ function test.multinomial_without_replacement_gets_all()
         t[dist] = linear
      end
 
-      local orig = t:clone():long()
-
-      for _, typename in ipairs(float_typenames) do
-         local x = t:type(typename)
+      local orig = t:clone()
 
-         -- Sample without replacement
-         local result = torch.multinomial(x, distSize)
-         tester:assert(result:size(1) == distributions)
-         tester:assert(result:size(2) == distSize)
+      -- Sample without replacement
+      local result = torch.multinomial(t, distSize)
+      tester:assert(result:size(1) == distributions)
+      tester:assert(result:size(2) == distSize)
 
-         -- Sort, and we should have the original results, since without replacement
-         -- sampling everything, we should have chosen every value uniquely
-         result = result:sort(2)
-         tester:assertTensorEq(orig:type(typename), result, 0, "error in multinomial_without_replacement_gets_all")
-      end
+      -- Sort, and we should have the original results, since without replacement
+      -- sampling everything, we should have chosen every value uniquely
+      result = result:sort(2)
+      tester:assertTensorEq(orig, result, 0, "error in multinomial_without_replacement_gets_all")
    end
 end
 
@@ -2783,15 +2749,12 @@ function test.multinomial_vector()
    local n_col = torch.random(100)
   local prob_dist = torch.CudaTensor(n_col):uniform()
   local n_sample = n_col
-   for _, typename in ipairs(float_typenames) do
-      local pd = prob_dist:type(typename)
-      local sample_indices = torch.multinomial(pd, n_sample, true)
-      tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
-      -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
-      -- was undone
-      tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
-      tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
-   end
+   local sample_indices = torch.multinomial(prob_dist, n_sample, true)
+   tester:assert(sample_indices:dim() == 1, "wrong sample_indices dim")
+   -- Multinomial resizes prob_dist to be 2d (1xn), check that the resize
+   -- was undone
+   tester:assert(prob_dist:dim() == 1, "wrong number of prob_dist dimensions")
+   tester:assert(sample_indices:size(1) == n_sample, "wrong number of samples")
 end
 
 function test.get_device()