Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-09 22:55:51 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-09 22:55:51 +0300
commit0172650d2efa1bb2da90b8db076fb7ebe61ae1ca (patch)
treeb76ca0e1d1dd1ae4f44eec5ce010c16c2abc75b2
parentef80644772cd00fb8b75fc1aa734a16a7f2cb85e (diff)
[cutorch rand2gen] extend functions to use _double methods
-rw-r--r--lib/THC/THCTensorRandom.cu12
-rw-r--r--lib/THC/generic/THCTensorRandom.cu15
2 files changed, 21 insertions, 6 deletions
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 416691c..4493fe8 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -220,6 +220,18 @@ GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_unif
GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a)))
+GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
+GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)
+GENERATE_KERNEL2(generate_normal, half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, half>::to((x * stdv) + mean)))
+
+GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(1-x)))
+GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(1-x)))
+GENERATE_KERNEL1(generate_exponential, half, double lambda, float, curand_uniform, (ScalarConvert<float, half>::to((float)(-1. / lambda * log(1-x)))))
+
+GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
+GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
+GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
+
#include "generic/THCTensorRandom.cu"
#include "THCGenerateAllTypes.h"
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
index 9af79c1..16d1edd 100644
--- a/lib/THC/generic/THCTensorRandom.cu
+++ b/lib/THC/generic/THCTensorRandom.cu
@@ -20,8 +20,6 @@ THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, do
THCTensor_(freeCopyTo)(state, self, self_);
};
-GENERATE_KERNEL2(generate_normal, real, double mean, double stdv, float, curand_normal, (ScalarConvert<float, real>::to((x * stdv) + mean)))
-
THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
{
THAssert(THCTensor_(checkGPU)(state, 1, self_));
@@ -52,8 +50,6 @@ THC_API void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mea
THCTensor_(freeCopyTo)(state, self, self_);
};
-GENERATE_KERNEL1(generate_exponential, real, double lambda, float, curand_uniform, (ScalarConvert<float, real>::to((float)(-1. / lambda * log(1-x)))))
-
THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
{
THAssert(THCTensor_(checkGPU)(state, 1, self_));
@@ -69,8 +65,6 @@ THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double l
THCTensor_(freeCopyTo)(state, self, self_);
};
-GENERATE_KERNEL2(generate_cauchy, real, double median, double sigma, float, curand_uniform, (ScalarConvert<float, real>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
-
THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median, double sigma)
{
THAssert(THCTensor_(checkGPU)(state, 1, self_));
@@ -288,7 +282,11 @@ void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size)
#endif
+#if defined(THC_REAL_IS_DOUBLE)
+GENERATE_KERNEL1(generate_bernoulli, double, double p, double, curand_uniform_double, x <= p)
+#else
GENERATE_KERNEL1(generate_bernoulli, real, double p, float, curand_uniform, (ScalarConvert<bool, real>::to(x <= p)))
+#endif
THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
{
@@ -304,7 +302,12 @@ THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
THCTensor_(freeCopyTo)(state, self, self_);
};
+#if defined(THC_REAL_IS_DOUBLE)
+
+GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, (log(1-x) / log(p)) + 1)
+#else
GENERATE_KERNEL1(generate_geometric, real, double p, float, curand_uniform, (ScalarConvert<float, real>::to((log(1-x) / log(p)) + 1)))
+#endif
THC_API void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)
{