diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-08 19:11:39 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-12 00:22:59 +0300 |
commit | cb3eb45848cf445b36f95c13161ff1c7e3e545f5 (patch) | |
tree | 3300c0d047535bddb14392489f4e2e7ede1d78bc | |
parent | 507fbe3d25497909aad66a9d384a69974a9bb041 (diff) |
[cutorch rand2gen] move cauchy to generic
-rw-r--r-- | TensorMath.lua | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cu | 16 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.h | 1 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorRandom.cu | 17 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorRandom.h | 1 | ||||
-rw-r--r-- | test/test.lua | 9 |
6 files changed, 25 insertions, 20 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 478be2a..b6f3c7b 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -939,6 +939,7 @@ for k, Tensor_ in pairs(handledTypenames) do {name="LongArg"}}) for _,f in ipairs({{name='uniform', a=0, b=1}, + {name='cauchy', a=0, b=1}, {name='normal', a=0, b=1}}) do wrap(f.name, diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu index dcce23e..a493ce7 100644 --- a/lib/THC/THCTensorRandom.cu +++ b/lib/THC/THCTensorRandom.cu @@ -221,7 +221,6 @@ GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_un GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a))) GENERATE_KERNEL1(generate_geometric, float, double p, float, curand_uniform, (log(1-x) / log(p)) + 1) -GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5)))) /* Separate kernel because curand_log_normal gets extra parameters. */ __global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *result, float mean, float stddev) @@ -267,21 +266,6 @@ THC_API void THCudaTensor_geometric(THCState* state, THCudaTensor *self_, double THCudaTensor_freeCopyTo(state, self, self_); }; -THC_API void THCudaTensor_cauchy(THCState* state, THCudaTensor *self_, double median, double sigma) -{ - THAssert(THCudaTensor_checkGPU(state, 1, self_)); - Generator* gen = THCRandom_getGenerator(state); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - ptrdiff_t size = THCudaTensor_nElement(state, self); - float *data = THCudaTensor_data(state, self); - - generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>( - gen->gen_states, size, data, median, sigma); - - THCudaTensor_freeCopyTo(state, self, self_); -}; - void THCudaTensor_renormRows(struct THCState* state, THCudaTensor* t) { THAssert(THCudaTensor_nDimension(state, t) == 2); diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h index 1ca1bac..eaec73d 100644 --- a/lib/THC/THCTensorRandom.h +++ b/lib/THC/THCTensorRandom.h @@ -32,7 +32,6 @@ THC_API unsigned long THCRandom_initialSeed(struct THCState *state); THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state); THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state); THC_API void THCudaTensor_geometric(struct THCState *state, THCudaTensor *self, double p); -THC_API void THCudaTensor_cauchy(struct THCState *state, THCudaTensor *self, double median, double sigma); THC_API void THCudaTensor_logNormal(struct THCState *state, THCudaTensor *self, double mean, double stdv); THC_API void THCudaTensor_multinomial(struct THCState *state, THCudaTensor *self, THCudaTensor *prob_dist, int n_sample, int with_replacement); diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index da1a4e1..23dd888 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -52,6 +52,23 @@ THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double l THCTensor_(freeCopyTo)(state, self, self_); }; +GENERATE_KERNEL2(generate_cauchy, real, double median, double sigma, float, curand_uniform, (ScalarConvert<float, real>::to((float)(median + sigma * tan(M_PI*(x-0.5)))))) + +THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median, double sigma) +{ + THAssert(THCTensor_(checkGPU)(state, 1, self_)); + Generator* gen = THCRandom_getGenerator(state); + + THCTensor *self = THCTensor_(newContiguous)(state, self_); + ptrdiff_t size = THCTensor_(nElement)(state, self); + real *data = THCTensor_(data)(state, self); + + generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>( + gen->gen_states, size, data, median, sigma); + + THCTensor_(freeCopyTo)(state, self, self_); +}; + #undef NUM_BLOCKS THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size) diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h index 9af42ce..059a362 100644 --- a/lib/THC/generic/THCTensorRandom.h +++ b/lib/THC/generic/THCTensorRandom.h @@ -8,6 +8,7 @@ THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size); THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv); THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda); +THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma); #endif diff --git a/test/test.lua b/test/test.lua index 8319971..f75c628 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2620,9 +2620,12 @@ function test.cauchy() local median, sigma = torch.uniform(), torch.uniform() local t = torch.CudaTensor(sz1, sz2) - t:cauchy(median, sigma) - local u = ((t:float() - median) / sigma):atan() / math.pi + 0.5 - checkIfUniformlyDistributed(u, 0, 1) + for _, typename in ipairs(float_typenames) do + local x = t:type(t2cpu[typename]) + x:cauchy(median, sigma) + local u = ((x:float() - median) / sigma):atan() / math.pi + 0.5 + checkIfUniformlyDistributed(u, 0, 1) + end checkMultiDevice(t, 'cauchy', median, sigma) end |