diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-08-09 22:22:52 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-15 09:51:11 +0300 |
commit | 0252bcd1b43cc70986a359c5250c80edb6eb29c2 (patch) | |
tree | 5cfc0d669eba2c5d8dae01e0926376ef65379ae4 | |
parent | 7462a22d95d6c306f6b20f50d9986a0893355244 (diff) |
Support __neg__, .neg(), and neg_() for Long, Int, Short tensor types.
-rw-r--r-- | lib/THC/THCNumerics.cuh | 4 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 8 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 8 |
3 files changed, 18 insertions, 2 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index ba86e8f..a36ff14 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -44,6 +44,7 @@ struct THCNumerics<char> { static inline __host__ __device__ bool eq(char a, char b) { return a == b; } static inline __host__ __device__ bool ne(char a, char b) { return a != b; } + static inline __host__ __device__ char neg(char a) { return -a; } static inline __host__ __device__ char add(char a, char b) { return a + b; } static inline __host__ __device__ char mul(char a, char b) { return a * b; } static inline __host__ __device__ char sub(char a, char b) { return a - b; } @@ -63,6 +64,7 @@ struct THCNumerics<short> { static inline __host__ __device__ bool eq(short a, short b) { return a == b; } static inline __host__ __device__ bool ne(short a, short b) { return a != b; } + static inline __host__ __device__ short neg(short a) { return -a; } static inline __host__ __device__ short add(short a, short b) { return a + b; } static inline __host__ __device__ short mul(short a, short b) { return a * b; } static inline __host__ __device__ short sub(short a, short b) { return a - b; } @@ -82,6 +84,7 @@ struct THCNumerics<int> { static inline __host__ __device__ bool eq(int a, int b) { return a == b; } static inline __host__ __device__ bool ne(int a, int b) { return a != b; } + static inline __host__ __device__ int neg(int a) { return -a; } static inline __host__ __device__ int add(int a, int b) { return a + b; } static inline __host__ __device__ int mul(int a, int b) { return a * b; } static inline __host__ __device__ int sub(int a, int b) { return a - b; } @@ -101,6 +104,7 @@ struct THCNumerics<long> { static inline __host__ __device__ bool eq(long a, long b) { return a == b; } static inline __host__ __device__ bool ne(long a, long b) { return a != b; } + static inline __host__ __device__ long neg(long a) { return -a; } static inline __host__ __device__ long add(long a, long b) { return a + b; } static inline __host__ __device__ long mul(long a, long b) { return a * b; } static inline __host__ __device__ long sub(long a, long b) { return a - b; } diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index cdf4b82..c9b4f8c 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -46,7 +46,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(rsqrt, THCNumerics<real>::rsqrt, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( ceil, THCNumerics<real>::ceil, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, THCNumerics<real>::floor, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics<real>::trunc, Real) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( acos, THCNumerics<real>::acos, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cosh, THCNumerics<real>::cosh, Real) @@ -61,6 +60,13 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics<real>::cinv, Real) #endif +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \ + defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG) + +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real) + +#endif + IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( abs, THCNumerics<real>::abs, Real) #undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_ diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index 17171c0..cba627c 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -30,11 +30,17 @@ THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src) THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w); -THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src); #endif +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \ + defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG) + +THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src); + +#endif + THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value); |