diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-14 01:54:25 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-14 01:54:25 +0300 |
commit | 3956f022d92f3e20ef41589b1c7e2415ffffcabd (patch) | |
tree | 387107a879505ad0a115a462569a2e2002c5c77f | |
parent | e00f7d4c0f70e3583ff0a5359095ad7afcaa7009 (diff) | |
parent | 6fdd58c414dea8afeb97b736ed5f4f1a86906df1 (diff) |
Merge pull request #630 from apaszke/bernoulli
Implement bernoulli with element-wise probabilities for all types
-rw-r--r-- | TensorMath.lua | 44 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cu | 45 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorRandom.cu | 25 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorRandom.h | 2 | ||||
-rw-r--r-- | test/test.lua | 26 |
5 files changed, 109 insertions, 33 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 2eb6ddf..91d97dd 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -912,14 +912,21 @@ for k, Tensor_ in pairs(handledTypenames) do {name=Tensor .. "Array"}, {name="index", default=lastdimarray(2)}}) - for _,f in ipairs({{name='geometric'}, - {name='bernoulli', a=0.5}}) do + wrap("geometric", + cname("geometric"), + {{name=Tensor, returned=true}, + {name='double'}}) - wrap(f.name, - cname(f.name), - {{name=Tensor, returned=true}, - {name='double', default=f.a}}) - end + wrap("bernoulli", + cname("bernoulli"), + {{name=Tensor, returned=true}, + {name='double', default=0.5}}, + cname("bernoulli_FloatTensor"), + {{name=Tensor, returned=true}, + {name="CudaTensor"}}, + cname("bernoulli_DoubleTensor"), + {{name=Tensor, returned=true}, + {name="CudaDoubleTensor"}}) wrap("nonzero", cname("nonzero"), @@ -1868,14 +1875,21 @@ wrap("nonzero", {{name="CudaLongTensor", default=true, returned=true}, {name=Tensor}}) -for _,f in ipairs({{name='geometric'}, - {name='bernoulli', a=0.5}}) do - - wrap(f.name, - cname(f.name), - {{name=Tensor, returned=true}, - {name=real, default=f.a}}) -end +wrap("geometric", + cname("geometric"), + {{name=Tensor, returned=true}, + {name='double'}}) + +wrap("bernoulli", + cname("bernoulli"), + {{name=Tensor, returned=true}, + {name='double', default=0.5}}, + cname("bernoulli_FloatTensor"), + {{name=Tensor, returned=true}, + {name="CudaTensor"}}, + cname("bernoulli_DoubleTensor"), + {{name=Tensor, returned=true}, + {name="CudaDoubleTensor"}}) for _,f in ipairs({{name='uniform', a=0, b=1}, {name='normal', a=0, b=1}, diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu index e05cf82..78f4dc9 100644 --- a/lib/THC/THCTensorRandom.cu +++ b/lib/THC/THCTensorRandom.cu @@ -77,34 +77,59 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state) memcpy(&gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size); } -#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \ -__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \ +#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \ +__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \ { \ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \ - int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ + int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \ - CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \ + CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \ if (i < size) { \ - T y = TRANSFORM; \ + T y = TRANSFORM; \ result[i] = y; \ } \ } \ } -#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \ -__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \ +#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \ +__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \ { \ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \ - int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ + int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \ - CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \ + CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \ if (i < size) { \ - T y = TRANSFORM; \ + T y = TRANSFORM; \ result[i] = y; \ } \ } \ } +template<typename T, typename U> +struct is_same { static const bool value = false; }; + +template<typename T> +struct is_same<T, T> { static const bool value = true; }; + +template<typename real, typename prob_type> +__global__ void generate_bernoulli_tensor(curandStateMtgp32 *state, int size, + real *result, prob_type *probs) +{ + int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; + int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; + for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { + if (is_same<prob_type, double>::value) { + double x = curand_uniform_double(&state[blockIdx.x]); + if (i < size) + result[i] = ScalarConvert<bool, real>::to(x <= probs[i]); + } else { + float x = curand_uniform(&state[blockIdx.x]); + if (i < size) + result[i] = ScalarConvert<bool, real>::to(x <= probs[i]); + } + } +} + GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, x * (b-a) + a) GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a) diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu index f04e647..8be2a3b 100644 --- a/lib/THC/generic/THCTensorRandom.cu +++ b/lib/THC/generic/THCTensorRandom.cu @@ -302,6 +302,31 @@ THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p) THCTensor_(freeCopyTo)(state, self, self_); }; +#define DEFINE_BERNOULLI_TENSOR(NAME, PROB_TYPE, PROB_DATA_TYPE) \ +THC_API void THCTensor_(NAME)(THCState* state, \ + THCTensor *self_, PROB_TYPE *probs_) \ +{ \ + THAssert(THCTensor_(checkGPU)(state, 2, self_, probs_)); \ + Generator* gen = THCRandom_getGenerator(state); \ + THCTensor *self = THCTensor_(newContiguous)(state, self_); \ + PROB_TYPE *probs = PROB_TYPE##_newContiguous(state, probs_); \ + ptrdiff_t size = THCTensor_(nElement)(state, self); \ + ptrdiff_t prob_size = PROB_TYPE##_nElement(state, probs); \ + real *result_data = THCTensor_(data)(state, self); \ + PROB_DATA_TYPE *probs_data = PROB_TYPE##_data(state, probs); \ + \ + THArgCheck(size == prob_size, 3, "inconsistent tensor size"); \ + \ + generate_bernoulli_tensor<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>( \ + gen->gen_states, size, result_data, probs_data); \ + \ + PROB_TYPE##_free(state, probs); \ + THCTensor_(freeCopyTo)(state, self, self_); \ +} + +DEFINE_BERNOULLI_TENSOR(bernoulli_FloatTensor, THCudaTensor, float) +DEFINE_BERNOULLI_TENSOR(bernoulli_DoubleTensor, THCudaDoubleTensor, double) + #if defined(THC_REAL_IS_DOUBLE) GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, ceil(log(x) / log(1-p))) diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h index a2896c3..0e37491 100644 --- a/lib/THC/generic/THCTensorRandom.h +++ b/lib/THC/generic/THCTensorRandom.h @@ -16,6 +16,8 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, THCTensor *self, TH #endif THC_API void THCTensor_(bernoulli)(struct THCState *state, THCTensor *self, double p); +THC_API void THCTensor_(bernoulli_FloatTensor)(struct THCState *state, THCTensor *self, THCudaTensor *p); +THC_API void THCTensor_(bernoulli_DoubleTensor)(struct THCState *state, THCTensor *self, THCudaDoubleTensor *p); THC_API void THCTensor_(geometric)(struct THCState *state, THCTensor *self, double p); #endif diff --git a/test/test.lua b/test/test.lua index bce8109..53c9563 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2652,18 +2652,28 @@ function test.bernoulli() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) local p = torch.uniform() + local p_fl = torch.rand(sz1, sz2):cuda() + local p_dbl = torch.rand(sz1, sz2):cudaDouble() local t = torch.CudaTensor(sz1, sz2) for _, typename in ipairs(typenames) do local x = t:type(typename) - x:bernoulli(p) - local mean = x:sum() / (sz1 * sz2) - tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p") - local f = x:float() - tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(), - torch.FloatTensor(sz1, sz2):fill(1), - 1e-6, - "each value must be either 0 or 1") + local expected_mean + for i, p in ipairs({p, p_fl, p_dbl}) do + x:bernoulli(p) + local mean = x:sum() / (sz1 * sz2) + if torch.type(p) == 'number' then + expected_mean = p + else + expected_mean = p:mean() + end + tester:assertalmosteq(mean, expected_mean, 0.1, "mean is not equal to the expected value") + local f = x:float() + tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(), + torch.FloatTensor(sz1, sz2):fill(1), + 1e-6, + "each value must be either 0 or 1") + end end checkMultiDevice(t, 'bernoulli', p) end |