Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Trevor Killeen <killeentm@gmail.com> 2016-11-09 21:55:54 +0300
committer Trevor Killeen <killeentm@gmail.com> 2016-11-12 00:23:01 +0300
commit b33d92cb7d28b0764f8278dde2e160097f858457 (patch)
tree d91fdbc4aa68ccb70a6ccdd45e9e8b36b7ad19ab
parent 6a2df68d1693446ec3dda34a271d17cc7d6b959d (diff)
[cutorch rand2gen] partial move of logNormal to generic, needs further debugging
-rw-r--r-- TensorMath.lua 4
-rw-r--r-- lib/THC/THCTensorRandom.cu 31
-rw-r--r-- lib/THC/THCTensorRandom.cuh 33
-rw-r--r-- lib/THC/THCTensorRandom.h 1
-rw-r--r-- lib/THC/generic/THCTensorRandom.cu 23
-rw-r--r-- lib/THC/generic/THCTensorRandom.h 1
-rw-r--r-- test/test.lua 11
7 files changed, 61 insertions, 43 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index e5bdf2c..c6dada5 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -948,7 +948,8 @@ for k, Tensor_ in pairs(handledTypenames) do
for _,f in ipairs({{name='uniform', a=0, b=1},
{name='cauchy', a=0, b=1},
- {name='normal', a=0, b=1}}) do
+ {name='normal', a=0, b=1},
+ {name='logNormal', a=1, b=2}}) do
wrap(f.name,
cname(f.name),
@@ -957,7 +958,6 @@ for k, Tensor_ in pairs(handledTypenames) do
{name='double', default=f.b}})
end
-
wrap('exponential',
cname('exponential'),
{{name=Tensor, returned=true},
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 357fbe1..416691c 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -220,37 +220,6 @@ GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_unif
GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a)))
-/* Separate kernel because curand_log_normal gets extra parameters. */
-__global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *result, float mean, float stddev)
-{
- int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
- int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
- for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
- float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
- if (i < size) {
- result[i] = x;
- }
- }
-}
-
-#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
-THC_API void THCudaTensor_logNormal(THCState* state, THCudaTensor *self_, double mean, double stdv)
-{
- THAssert(THCudaTensor_checkGPU(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
-
- THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
- ptrdiff_t size = THCudaTensor_nElement(state, self);
- float *data = THCudaTensor_data(state, self);
-
- generate_log_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, mean, stdv);
-
- THCudaTensor_freeCopyTo(state, self, self_);
-};
-
-#undef NUM_BLOCKS
-
#include "generic/THCTensorRandom.cu"
#include "THCGenerateAllTypes.h"
diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh
index c503e96..003e960 100644
--- a/lib/THC/THCTensorRandom.cuh
+++ b/lib/THC/THCTensorRandom.cuh
@@ -7,6 +7,39 @@
#include <curand_kernel.h>
+#define MAX_NUM_BLOCKS 64
+#define BLOCK_SIZE 256
+/* Separate kernel because curand_log_normal gets extra parameters. */
+
+template <typename T>
+__global__ void generateLogNormal(curandStateMtgp32 *state, int size, T *result, double mean, double stddev)
+{
+ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
+ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
+ float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
+ if (i < size) {
+ result[i] = ScalarConvert<float, T>::to(x);
+ }
+ }
+}
+
+template <>
+__global__ void generateLogNormal<double>(curandStateMtgp32 *state, int size, double *result, double mean, double stddev)
+{
+ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
+ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
+ double x = curand_log_normal_double(&state[blockIdx.x], mean, stddev);
+ if (i < size) {
+ result[i] = x;
+ }
+ }
+}
+
+#undef MAX_NUM_BLOCKS
+#undef BLOCK_SIZE
+
// Normalizes the L1 norm of every row to 1; used by multinomial
template <typename T>
__global__ void renormRowsL1(T* dist, long rows, long cols) {
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index be510de..12128cd 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -31,7 +31,6 @@ THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_s
THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);
-THC_API void THCudaTensor_logNormal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
THC_API struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state);
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
index f0d3567..c410956 100644
--- a/lib/THC/generic/THCTensorRandom.cu
+++ b/lib/THC/generic/THCTensorRandom.cu
@@ -2,9 +2,10 @@
#define THC_GENERIC_FILE "generic/THCTensorRandom.cu"
#else
+#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, double b)
{
THAssert(THCTensor_(checkGPU)(state, 1, self_));
@@ -35,6 +36,22 @@ THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean,
THCTensor_(freeCopyTo)(state, self, self_);
};
+THC_API void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, double stdv)
+{
+
+ THAssert(THCTensor_(checkGPU)(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+
+ THCTensor *self = THCTensor_(newContiguous)(state, self_);
+ ptrdiff_t size = THCTensor_(nElement)(state, self);
+ real *data = THCTensor_(data)(state, self);
+
+ generateLogNormal<real><<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, mean, stdv);
+
+ THCTensor_(freeCopyTo)(state, self, self_);
+};
+
GENERATE_KERNEL1(generate_exponential, real, double lambda, float, curand_uniform, (ScalarConvert<float, real>::to((float)(-1. / lambda * log(1-x)))))
THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
@@ -255,9 +272,6 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
THCTensor_(free)(state, probDistContig);
}
-
-#undef NUM_BLOCKS
-
THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size)
{
THAssert(THCTensor_(checkGPU)(state, 1, r_));
@@ -269,7 +283,6 @@ THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *siz
GENERATE_KERNEL1(generate_bernoulli, real, double p, float, curand_uniform, (ScalarConvert<bool, real>::to(x <= p)))
-#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
{
THAssert(THCTensor_(checkGPU)(state, 1, self_));
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
index 852b00e..ebb3852 100644
--- a/lib/THC/generic/THCTensorRandom.h
+++ b/lib/THC/generic/THCTensorRandom.h
@@ -7,6 +7,7 @@
THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size);
THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
+THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);
THC_API void THCTensor_(multinomial)(struct THCState *state, THCTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
diff --git a/test/test.lua b/test/test.lua
index 7bc0b38..fcba6cc 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2578,10 +2578,13 @@ function test.logNormal()
local tolerance = 0.01
local t = torch.CudaTensor(sz1, sz2)
- t:logNormal(mean, std)
- local logt = t:log()
- tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
- tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
+ for _, typename in ipairs(float_typenames) do
+ local x = t:type(t2cpu[typename])
+ x:logNormal(mean, std)
+ local logt = x:log()
+ tester:assertalmosteq(logt:mean(), mean, tolerance, "mean is wrong")
+ tester:assertalmosteq(logt:std(), std, tolerance, "standard deviation is wrong")
+ end
checkMultiDevice(t, 'logNormal', mean, std)
end