Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-08 00:26:59 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-08 00:26:59 +0300
commit16358a40172df534e90ca2ea40af929128257a56 (patch)
tree49e69341b9cfc7d85804759c7304287ed9ae3ee6
parentbadb05dfab236c4e7ba884029401c4048e689569 (diff)
[cutorch rand2gen] move normal to generic
-rw-r--r--TensorMath.lua3
-rw-r--r--lib/THC/THCTensorRandom.cu17
-rw-r--r--lib/THC/THCTensorRandom.h1
-rw-r--r--lib/THC/generic/THCTensorRandom.cu17
-rw-r--r--lib/THC/generic/THCTensorRandom.h1
-rw-r--r--test/test.lua10
6 files changed, 28 insertions, 21 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index d432b21..d5910bc 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -938,7 +938,8 @@ for k, Tensor_ in pairs(handledTypenames) do
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name="LongArg"}})
- for _,f in ipairs({{name='uniform', a=0, b=1}}) do
+ for _,f in ipairs({{name='uniform', a=0, b=1},
+ {name='normal', a=0, b=1}}) do
wrap(f.name,
cname(f.name),
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 822a554..9b019a9 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -218,9 +218,8 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, x * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
-GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<double, half>::to(x * (b-a) + a)))
+GENERATE_KERNEL2(generate_uniform, half, double a, double b, float, curand_uniform, (ScalarConvert<float, half>::to(x * (b-a) + a)))
-GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
GENERATE_KERNEL1(generate_geometric, float, double p, float, curand_uniform, (log(1-x) / log(p)) + 1)
GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(1-x)))
GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
@@ -239,20 +238,6 @@ __global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *r
}
#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
-THC_API void THCudaTensor_normal(THCState* state, THCudaTensor *self_, double mean, double stdv)
-{
- THAssert(THCudaTensor_checkGPU(state, 1, self_));
- Generator* gen = THCRandom_getGenerator(state);
- THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
- ptrdiff_t size = THCudaTensor_nElement(state, self);
- float *data = THCudaTensor_data(state, self);
-
- generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->gen_states, size, data, mean, stdv);
-
- THCudaTensor_freeCopyTo(state, self, self_);
-};
-
THC_API void THCudaTensor_logNormal(THCState* state, THCudaTensor *self_, double mean, double stdv)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index b9c349f..2d94a26 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -32,7 +32,6 @@ THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCudaTensor_geometric(struct THCState *state, THCudaTensor *self, double p);
-THC_API void THCudaTensor_normal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
THC_API void THCudaTensor_exponential(struct THCState *state, THCudaTensor *self, double lambda);
THC_API void THCudaTensor_cauchy(struct THCState *state, THCudaTensor *self, double median, double sigma);
THC_API void THCudaTensor_logNormal(struct THCState *state, THCudaTensor *self, double mean, double stdv);
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
index 7e63cff..19f7cee 100644
--- a/lib/THC/generic/THCTensorRandom.cu
+++ b/lib/THC/generic/THCTensorRandom.cu
@@ -18,6 +18,22 @@ THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, do
THCTensor_(freeCopyTo)(state, self, self_);
};
+
+GENERATE_KERNEL2(generate_normal, real, double mean, double stdv, float, curand_normal, (ScalarConvert<float, real>::to((x * stdv) + mean)))
+
+THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, self_));
+ Generator* gen = THCRandom_getGenerator(state);
+ THCTensor *self = THCTensor_(newContiguous)(state, self_);
+ ptrdiff_t size = THCTensor_(nElement)(state, self);
+ real *data = THCTensor_(data)(state, self);
+
+ generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ gen->gen_states, size, data, mean, stdv);
+
+ THCTensor_(freeCopyTo)(state, self, self_);
+};
#undef NUM_BLOCKS
THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size)
@@ -45,6 +61,7 @@ THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
THCTensor_(freeCopyTo)(state, self, self_);
};
+
#undef NUM_BLOCKS
#endif
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
index f25d76c..2b3634d 100644
--- a/lib/THC/generic/THCTensorRandom.h
+++ b/lib/THC/generic/THCTensorRandom.h
@@ -6,6 +6,7 @@
THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
#endif
diff --git a/test/test.lua b/test/test.lua
index 02ddc37..6bf5486 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2561,9 +2561,13 @@ function test.normal()
local tolerance = 0.01
local t = torch.CudaTensor(sz1, sz2)
- t:normal(mean, std)
- tester:assertalmosteq(t:mean(), mean, tolerance, "mean is wrong")
- tester:assertalmosteq(t:std(), std, tolerance, "standard deviation is wrong")
+ for _, typename in ipairs(float_typenames) do
+ local x = t:type(t2cpu[typename])
+ x:normal(mean, std)
+ tester:assertalmosteq(x:mean(), mean, tolerance, "mean is wrong")
+ tester:assertalmosteq(x:std(), std, tolerance, "standard deviation is wrong")
+ end
+
checkMultiDevice(t, 'normal', mean, std)
end