diff options
author | ethanluoyc <ethanluoyc@gmail.com> | 2017-04-21 02:24:14 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-09 02:20:52 +0300 |
commit | 3502535d13adf117fa90ac3c0827c44d0946d158 (patch) | |
tree | a0bb8f099fce60a320d76038bf4d0e1691beb213 /lib | |
parent | fdc73fa7bd34ebe3e3eeb6a0b7202fe9bdd93be8 (diff) |
Implement lgamma function.
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THC/THCNumerics.cuh | 11 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 1 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 1 |
3 files changed, 13 insertions, 0 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index f1576b6..b6d1dac 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -246,6 +246,15 @@ struct THCNumerics<half> { #endif } +static inline __host__ __device__ half lgamma(half a) { +#ifdef __CUDA_ARCH__ + float fa = __half2float(a); + return __float2half(lgammaf(fa)); +#else // __CUDA_ARCH__ + return THC_float2half(lgammaf(THC_half2float(a))); +#endif + } + static inline __host__ __device__ half cos(half a) { #ifdef __CUDA_ARCH__ #ifdef CUDA_HALF_INSTRUCTIONS @@ -527,6 +536,7 @@ struct THCNumerics<float> { static inline __host__ __device__ bool eq(float a, float b) { return a == b; } static inline __host__ __device__ bool ne(float a, float b) { return a != b; } + static inline __host__ __device__ float lgamma(float a) { return lgammaf(a);} static inline __host__ __device__ float exp (float a) { return expf(a); } static inline __host__ __device__ float exp10(float a) { return exp10f(a); } static inline __host__ __device__ float log (float a) { return logf(a); } @@ -571,6 +581,7 @@ struct THCNumerics<double> { static inline __host__ __device__ bool eq(double a, double b) { return a == b; } static inline __host__ __device__ bool ne(double a, double b) { return a != b; } + static inline __host__ __device__ double lgamma(double a) { return ::lgamma(a);} static inline __host__ __device__ double exp (double a) { return ::exp(a); } static inline __host__ __device__ double exp10(double a) { return ::exp10(a); } static inline __host__ __device__ double log (double a) { return ::log(a); } diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 174e8ec..cdf4b82 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -36,6 +36,7 @@ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( log, THCNumerics<real>::log, Real) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(lgamma, THCNumerics<real>::lgamma, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, THCNumerics<real>::log1p, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( exp, THCNumerics<real>::exp, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cos, THCNumerics<real>::cos, Real) diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index 69cfdb3..17171c0 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -6,6 +6,7 @@ THC_API void THCTensor_(sigmoid)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log)(THCState *state, THCTensor *self, THCTensor *src); +THC_API void THCTensor_(lgamma)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log1p)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(exp)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(cos)(THCState *state, THCTensor *self, THCTensor *src); |