diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 23:09:05 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 23:09:05 +0300 |
commit | 7860a76e1cc50e5c679a965c95cdca2501cac9bc (patch) | |
tree | 822e601f8204150cf840f19ac7a76e0acbaa5f4e /lib/THC | |
parent | 63df041cae36863deaf9282de3228e6377f3bcba (diff) |
[cutorch refactor] addcmul/addcdiv to generic
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCNumerics.cuh | 4 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cu | 69 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 3 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPointwise.cuh | 34 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 49 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 3 |
6 files changed, 90 insertions, 72 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index af51809..0944360 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -28,6 +28,7 @@ struct THCNumerics<unsigned char> { static inline __host__ __device__ unsigned char add(unsigned char a, unsigned char b) { return a + b; } static inline __host__ __device__ unsigned char mul(unsigned char a, unsigned char b) { return a * b; } static inline __host__ __device__ unsigned char sub(unsigned char a, unsigned char b) { return a - b; } + static inline __host__ __device__ unsigned char div(unsigned char a, unsigned char b) { return a / b; } static inline __host__ __device__ unsigned char abs(unsigned char a) { return abs(a); } }; @@ -46,6 +47,7 @@ struct THCNumerics<char> { static inline __host__ __device__ char add(char a, char b) { return a + b; } static inline __host__ __device__ char mul(char a, char b) { return a * b; } static inline __host__ __device__ char sub(char a, char b) { return a - b; } + static inline __host__ __device__ char div(char a, char b) { return a / b; } static inline __host__ __device__ char abs(char a) { return abs(a); } }; @@ -64,6 +66,7 @@ struct THCNumerics<short> { static inline __host__ __device__ short add(short a, short b) { return a + b; } static inline __host__ __device__ short mul(short a, short b) { return a * b; } static inline __host__ __device__ short sub(short a, short b) { return a - b; } + static inline __host__ __device__ short div(short a, short b) { return a / b; } static inline __host__ __device__ short abs(short a) { return abs(a); } }; @@ -82,6 +85,7 @@ struct THCNumerics<int> { static inline __host__ __device__ int add(int a, int b) { return a + b; } static inline __host__ __device__ int mul(int a, int b) { return a * b; } static inline __host__ __device__ int sub(int a, int b) { return a - b; } + static inline __host__ __device__ int div(int a, int b) { return a / b; } static inline __host__ __device__ int abs(int a) { return ::abs(a); } }; diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index f2a7607..bf8b399 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -75,75 +75,6 @@ void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor * } } -struct TensorAddCMulOp { - TensorAddCMulOp(float v) : val(v) {} - - __device__ __forceinline__ void - operator()(float* out, float* in1, float* in2) { - *out += val * *in1 * *in2; - } - - float val; -}; - -void THCudaTensor_addcmul(THCState *state, THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2) -{ - THAssert(THCudaTensor_checkGPU(state, 4, self_, t, src1, src2)); - if(self_ != t) - { - THCudaTensor_resizeAs(state, self_, t); - THCudaTensor_copy(state, self_, t); - } - else - { - THArgCheck(THCudaTensor_nElement(state, self_) == THCudaTensor_nElement(state, src1), - 1, "sizes do not match"); - } - - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), - 3, "sizes do not match"); - - if (!THC_pointwiseApply3(state, self_, src1, src2, TensorAddCMulOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - - THCudaCheck(cudaGetLastError()); -} - -struct TensorAddCDivOp { - TensorAddCDivOp(float v) : val(v) {} - - __device__ __forceinline__ void - operator()(float* out, float* in1, float* in2) { - *out += val * *in1 / *in2; - } - - float val; -}; - -void THCudaTensor_addcdiv(THCState *state, THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2) -{ - THAssert(THCudaTensor_checkGPU(state, 4, self_, t, src1, src2)); - if(self_ != t) - { - THCudaTensor_resizeAs(state, self_, t); - THCudaTensor_copy(state, self_, t); - } - else - { - THArgCheck(THCudaTensor_nElement(state, self_) == THCudaTensor_nElement(state, src1), - 1, "sizes do not match"); - } - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), - 3, "sizes do not match"); - - if (!THC_pointwiseApply3(state, self_, src1, src2, TensorAddCDivOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - - THCudaCheck(cudaGetLastError()); -} - template <typename T> struct TensorFillOp { TensorFillOp(T v) : val(v) {} diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 7010ee3..3bafd59 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -42,9 +42,6 @@ THC_API void THCudaTensor_triu(THCState *state, THCudaTensor *self, THCudaTensor THC_API void THCudaTensor_diag(THCState *state, THCudaTensor *self, THCudaTensor *src, long k); THC_API float THCudaTensor_trace(THCState *state, THCudaTensor *self); -THC_API void THCudaTensor_addcmul(THCState *state, THCudaTensor *self, THCudaTensor* t, float value, THCudaTensor *src1, THCudaTensor *src2); -THC_API void THCudaTensor_addcdiv(THCState *state, THCudaTensor *self, THCudaTensor* t, float value, THCudaTensor *src1, THCudaTensor *src2); - THC_API void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); THC_API void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index 5a6de80..c52e082 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -507,4 +507,38 @@ struct TensorMinValueOp { T val; }; +template <typename T> +struct TensorAddCMulOp { + TensorAddCMulOp(T v) : val(v) {} + + __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) { + *out = THCNumerics<T>::add( + *out, + THCNumerics<T>::mul( + val, + THCNumerics<T>::mul(*in1, *in2) + ) + ); + } + + T val; +}; + +template <typename T> +struct TensorAddCDivOp { + TensorAddCDivOp(T v) : val(v) {} + + __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) { + *out = THCNumerics<T>::add( + *out, + THCNumerics<T>::mul( + val, + THCNumerics<T>::div(*in1, *in2) + ) + ); + } + + T val; +}; + #endif // THC_TENSORMATH_POINTWISE_CUH diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index acb2d4c..90ddfbc 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -414,4 +414,53 @@ THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, real val } } +THC_API void +THCTensor_(addcmul)(THCState *state, THCTensor *self_, THCTensor *t, real value, THCTensor *src1, THCTensor *src2) +{ + THAssert(THCTensor_(checkGPU)(state, 4, self_, t, src1, src2)); + if(self_ != t) + { + THCTensor_(resizeAs)(state, self_, t); + THCTensor_(copy)(state, self_, t); + } + else + { + THArgCheck(THCTensor_(nElement)(state, self_) == THCTensor_(nElement)(state, src1), + 1, "sizes do not match"); + } + + THArgCheck(THCTensor_(nElement)(state, src1) == THCTensor_(nElement)(state, src2), + 3, "sizes do not match"); + + if (!THC_pointwiseApply3(state, self_, src1, src2, TensorAddCMulOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +THC_API void +THCTensor_(addcdiv)(THCState *state, THCTensor *self_, THCTensor *t, real value, THCTensor *src1, THCTensor *src2) +{ + THAssert(THCTensor_(checkGPU)(state, 4, self_, t, src1, src2)); + if(self_ != t) + { + THCTensor_(resizeAs)(state, self_, t); + THCTensor_(copy)(state, self_, t); + } + else + { + THArgCheck(THCTensor_(nElement)(state, self_) == THCTensor_(nElement)(state, src1), + 1, "sizes do not match"); + } + THArgCheck(THCTensor_(nElement)(state, src1) == THCTensor_(nElement)(state, src2), + 3, "sizes do not match"); + + if (!THC_pointwiseApply3(state, self_, src1, src2, TensorAddCDivOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + #endif diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index efbe76c..7a9d128 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -49,4 +49,7 @@ THC_API void THCTensor_(cmin)(THCState *state, THCTensor *self, THCTensor *src1, THC_API void THCTensor_(cmaxValue)(THCState *state, THCTensor *self, THCTensor *src, real value); THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, real value); +THC_API void THCTensor_(addcmul)(THCState *state, THCTensor *self, THCTensor* t, real value, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(addcdiv)(THCState *state, THCTensor *self, THCTensor* t, real value, THCTensor *src1, THCTensor *src2); + #endif |