Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Soumith Chintala <soumith@gmail.com>, 2016-12-14 01:54:25 +0300
committer: GitHub <noreply@github.com>, 2016-12-14 01:54:25 +0300
commit: 3956f022d92f3e20ef41589b1c7e2415ffffcabd (patch)
tree: 387107a879505ad0a115a462569a2e2002c5c77f
parent: e00f7d4c0f70e3583ff0a5359095ad7afcaa7009 (diff)
parent: 6fdd58c414dea8afeb97b736ed5f4f1a86906df1 (diff)
Merge pull request #630 from apaszke/bernoulli
Implement bernoulli with element-wise probabilities for all types
-rw-r--r--  TensorMath.lua  44
-rw-r--r--  lib/THC/THCTensorRandom.cu  45
-rw-r--r--  lib/THC/generic/THCTensorRandom.cu  25
-rw-r--r--  lib/THC/generic/THCTensorRandom.h  2
-rw-r--r--  test/test.lua  26
5 files changed, 109 insertions, 33 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 2eb6ddf..91d97dd 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -912,14 +912,21 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor .. "Array"},
{name="index", default=lastdimarray(2)}})
- for _,f in ipairs({{name='geometric'},
- {name='bernoulli', a=0.5}}) do
+ wrap("geometric",
+ cname("geometric"),
+ {{name=Tensor, returned=true},
+ {name='double'}})
- wrap(f.name,
- cname(f.name),
- {{name=Tensor, returned=true},
- {name='double', default=f.a}})
- end
+ wrap("bernoulli",
+ cname("bernoulli"),
+ {{name=Tensor, returned=true},
+ {name='double', default=0.5}},
+ cname("bernoulli_FloatTensor"),
+ {{name=Tensor, returned=true},
+ {name="CudaTensor"}},
+ cname("bernoulli_DoubleTensor"),
+ {{name=Tensor, returned=true},
+ {name="CudaDoubleTensor"}})
wrap("nonzero",
cname("nonzero"),
@@ -1868,14 +1875,21 @@ wrap("nonzero",
{{name="CudaLongTensor", default=true, returned=true},
{name=Tensor}})
-for _,f in ipairs({{name='geometric'},
- {name='bernoulli', a=0.5}}) do
-
- wrap(f.name,
- cname(f.name),
- {{name=Tensor, returned=true},
- {name=real, default=f.a}})
-end
+wrap("geometric",
+ cname("geometric"),
+ {{name=Tensor, returned=true},
+ {name='double'}})
+
+wrap("bernoulli",
+ cname("bernoulli"),
+ {{name=Tensor, returned=true},
+ {name='double', default=0.5}},
+ cname("bernoulli_FloatTensor"),
+ {{name=Tensor, returned=true},
+ {name="CudaTensor"}},
+ cname("bernoulli_DoubleTensor"),
+ {{name=Tensor, returned=true},
+ {name="CudaDoubleTensor"}})
for _,f in ipairs({{name='uniform', a=0, b=1},
{name='normal', a=0, b=1},
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index e05cf82..78f4dc9 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -77,34 +77,59 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
memcpy(&gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
}
-#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \
+#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \
+__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
- int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
- CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
+ CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
- T y = TRANSFORM; \
+ T y = TRANSFORM; \
result[i] = y; \
} \
} \
}
-#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \
+#define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \
+__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
- int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
- CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
+ CURAND_T x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
- T y = TRANSFORM; \
+ T y = TRANSFORM; \
result[i] = y; \
} \
} \
}
+template<typename T, typename U>
+struct is_same { static const bool value = false; };
+
+template<typename T>
+struct is_same<T, T> { static const bool value = true; };
+
+template<typename real, typename prob_type>
+__global__ void generate_bernoulli_tensor(curandStateMtgp32 *state, int size,
+ real *result, prob_type *probs)
+{
+ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
+ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
+ for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
+ if (is_same<prob_type, double>::value) {
+ double x = curand_uniform_double(&state[blockIdx.x]);
+ if (i < size)
+ result[i] = ScalarConvert<bool, real>::to(x <= probs[i]);
+ } else {
+ float x = curand_uniform(&state[blockIdx.x]);
+ if (i < size)
+ result[i] = ScalarConvert<bool, real>::to(x <= probs[i]);
+ }
+ }
+}
+
GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, x * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, x * (b-a) + a)
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
index f04e647..8be2a3b 100644
--- a/lib/THC/generic/THCTensorRandom.cu
+++ b/lib/THC/generic/THCTensorRandom.cu
@@ -302,6 +302,31 @@ THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
THCTensor_(freeCopyTo)(state, self, self_);
};
+#define DEFINE_BERNOULLI_TENSOR(NAME, PROB_TYPE, PROB_DATA_TYPE) \
+THC_API void THCTensor_(NAME)(THCState* state, \
+ THCTensor *self_, PROB_TYPE *probs_) \
+{ \
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, probs_)); \
+ Generator* gen = THCRandom_getGenerator(state); \
+ THCTensor *self = THCTensor_(newContiguous)(state, self_); \
+ PROB_TYPE *probs = PROB_TYPE##_newContiguous(state, probs_); \
+ ptrdiff_t size = THCTensor_(nElement)(state, self); \
+ ptrdiff_t prob_size = PROB_TYPE##_nElement(state, probs); \
+ real *result_data = THCTensor_(data)(state, self); \
+ PROB_DATA_TYPE *probs_data = PROB_TYPE##_data(state, probs); \
+ \
+ THArgCheck(size == prob_size, 3, "inconsistent tensor size"); \
+ \
+ generate_bernoulli_tensor<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>( \
+ gen->gen_states, size, result_data, probs_data); \
+ \
+ PROB_TYPE##_free(state, probs); \
+ THCTensor_(freeCopyTo)(state, self, self_); \
+}
+
+DEFINE_BERNOULLI_TENSOR(bernoulli_FloatTensor, THCudaTensor, float)
+DEFINE_BERNOULLI_TENSOR(bernoulli_DoubleTensor, THCudaDoubleTensor, double)
+
#if defined(THC_REAL_IS_DOUBLE)
GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, ceil(log(x) / log(1-p)))
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
index a2896c3..0e37491 100644
--- a/lib/THC/generic/THCTensorRandom.h
+++ b/lib/THC/generic/THCTensorRandom.h
@@ -16,6 +16,8 @@ THC_API void THCTensor_(multinomial)(struct THCState *state, THCTensor *self, TH
#endif
THC_API void THCTensor_(bernoulli)(struct THCState *state, THCTensor *self, double p);
+THC_API void THCTensor_(bernoulli_FloatTensor)(struct THCState *state, THCTensor *self, THCudaTensor *p);
+THC_API void THCTensor_(bernoulli_DoubleTensor)(struct THCState *state, THCTensor *self, THCudaDoubleTensor *p);
THC_API void THCTensor_(geometric)(struct THCState *state, THCTensor *self, double p);
#endif
diff --git a/test/test.lua b/test/test.lua
index bce8109..53c9563 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2652,18 +2652,28 @@ function test.bernoulli()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
local p = torch.uniform()
+ local p_fl = torch.rand(sz1, sz2):cuda()
+ local p_dbl = torch.rand(sz1, sz2):cudaDouble()
local t = torch.CudaTensor(sz1, sz2)
for _, typename in ipairs(typenames) do
local x = t:type(typename)
- x:bernoulli(p)
- local mean = x:sum() / (sz1 * sz2)
- tester:assertalmosteq(mean, p, 0.1, "mean is not equal to p")
- local f = x:float()
- tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
- torch.FloatTensor(sz1, sz2):fill(1),
- 1e-6,
- "each value must be either 0 or 1")
+ local expected_mean
+ for i, p in ipairs({p, p_fl, p_dbl}) do
+ x:bernoulli(p)
+ local mean = x:sum() / (sz1 * sz2)
+ if torch.type(p) == 'number' then
+ expected_mean = p
+ else
+ expected_mean = p:mean()
+ end
+ tester:assertalmosteq(mean, expected_mean, 0.1, "mean is not equal to the expected value")
+ local f = x:float()
+ tester:assertTensorEq(f:eq(1):add(f:eq(0)):float(),
+ torch.FloatTensor(sz1, sz2):fill(1),
+ 1e-6,
+ "each value must be either 0 or 1")
+ end
end
checkMultiDevice(t, 'bernoulli', p)
end