Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-08-09 22:22:52 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-15 09:51:11 +0300
commit0252bcd1b43cc70986a359c5250c80edb6eb29c2 (patch)
tree5cfc0d669eba2c5d8dae01e0926376ef65379ae4
parent7462a22d95d6c306f6b20f50d9986a0893355244 (diff)
Support __neg__, .neg(), and neg_() for Long, Int, Short tensor types.
-rw-r--r--lib/THC/THCNumerics.cuh4
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.cu8
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.h8
3 files changed, 18 insertions, 2 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index ba86e8f..a36ff14 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -44,6 +44,7 @@ struct THCNumerics<char> {
static inline __host__ __device__ bool eq(char a, char b) { return a == b; }
static inline __host__ __device__ bool ne(char a, char b) { return a != b; }
+ static inline __host__ __device__ char neg(char a) { return -a; }
static inline __host__ __device__ char add(char a, char b) { return a + b; }
static inline __host__ __device__ char mul(char a, char b) { return a * b; }
static inline __host__ __device__ char sub(char a, char b) { return a - b; }
@@ -63,6 +64,7 @@ struct THCNumerics<short> {
static inline __host__ __device__ bool eq(short a, short b) { return a == b; }
static inline __host__ __device__ bool ne(short a, short b) { return a != b; }
+ static inline __host__ __device__ short neg(short a) { return -a; }
static inline __host__ __device__ short add(short a, short b) { return a + b; }
static inline __host__ __device__ short mul(short a, short b) { return a * b; }
static inline __host__ __device__ short sub(short a, short b) { return a - b; }
@@ -82,6 +84,7 @@ struct THCNumerics<int> {
static inline __host__ __device__ bool eq(int a, int b) { return a == b; }
static inline __host__ __device__ bool ne(int a, int b) { return a != b; }
+ static inline __host__ __device__ int neg(int a) { return -a; }
static inline __host__ __device__ int add(int a, int b) { return a + b; }
static inline __host__ __device__ int mul(int a, int b) { return a * b; }
static inline __host__ __device__ int sub(int a, int b) { return a - b; }
@@ -101,6 +104,7 @@ struct THCNumerics<long> {
static inline __host__ __device__ bool eq(long a, long b) { return a == b; }
static inline __host__ __device__ bool ne(long a, long b) { return a != b; }
+ static inline __host__ __device__ long neg(long a) { return -a; }
static inline __host__ __device__ long add(long a, long b) { return a + b; }
static inline __host__ __device__ long mul(long a, long b) { return a * b; }
static inline __host__ __device__ long sub(long a, long b) { return a - b; }
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index cdf4b82..c9b4f8c 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -46,7 +46,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(rsqrt, THCNumerics<real>::rsqrt, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( ceil, THCNumerics<real>::ceil, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, THCNumerics<real>::floor, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics<real>::trunc, Real)
-IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( acos, THCNumerics<real>::acos, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cosh, THCNumerics<real>::cosh, Real)
@@ -61,6 +60,13 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics<real>::cinv, Real)
#endif
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
+ defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
+
+IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)
+
+#endif
+
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( abs, THCNumerics<real>::abs, Real)
#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
index 17171c0..cba627c 100644
--- a/lib/THC/generic/THCTensorMathPointwise.h
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -30,11 +30,17 @@ THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src)
THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w);
-THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);
#endif
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
+ defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
+
+THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
+
+#endif
+
THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value);