diff options
author    Francisco Massa <fvsmassa@gmail.com>   2017-09-05 22:53:58 +0300
committer Soumith Chintala <soumith@gmail.com>   2017-09-10 20:50:57 +0300
commit    bc651e42cad1c9b9bd6a43af9d4513b29a74f8ca (patch)
tree      a85c67e617f4468563686196ecdd1340b4e212bc
parent    326b5a5224bfef5a8f585e5340054e55d44cb453 (diff)
Optimize pow for different exponents and add tests
 lib/THC/THCTensorMathPointwise.cuh         | 73
 lib/THC/generic/THCTensorMathPointwise.cu  | 54
 2 files changed, 80 insertions(+), 47 deletions(-)
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index 6ab010a..a37645d 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -264,60 +264,47 @@ struct TensorMulOp<half> { }; #endif // CUDA_HALF_TENSOR -template<typename T> +template<typename T, int StaticExp> struct TensorPowOp { TensorPowOp(T v) : val(v) {} __device__ __forceinline__ void operator()(T* out, T* in) { - *out = powf((float) *in, (float) val); + if (StaticExp == 1) { + *out = *in; + } else if (StaticExp == 2) { + *out = THCNumerics<T>::mul(*in, *in); + } else if (StaticExp == 3) { + *out = THCNumerics<T>::mul(*in, *in); + *out = THCNumerics<T>::mul(*out, *in); + } else if (StaticExp == -1) { + *out = THCNumerics<T>::cinv(*in); + } else if (StaticExp == -2) { + *out = THCNumerics<T>::mul(*in, *in); + *out = THCNumerics<T>::cinv(*out); + } else { + *out = THCNumerics<T>::pow(*in, val); + } } __device__ __forceinline__ void operator()(T* v) { - *v = powf((float) *v, (float) val); + if (StaticExp == 1) { + *v = *v; + } else if (StaticExp == 2) { + *v = THCNumerics<T>::mul(*v, *v); + } else if (StaticExp == 3) { + *v = THCNumerics<T>::mul(THCNumerics<T>::mul(*v, *v), *v); + } else if (StaticExp == -1) { + *v = THCNumerics<T>::cinv(*v); + } else if (StaticExp == -2) { + *v = THCNumerics<T>::mul(*v, *v); + *v = THCNumerics<T>::cinv(*v); + } else { + *v = THCNumerics<T>::pow(*v, val); + } } const T val; }; -template <> -struct TensorPowOp<double> { - TensorPowOp(double v) : val(v) {} - - __device__ __forceinline__ void operator()(double* out, double* in) { - *out = pow(*in, val); - } - - __device__ __forceinline__ void operator()(double* v) { - *v = pow(*v, val); - } - - const double val; -}; - -#ifdef CUDA_HALF_TENSOR -template <> -struct TensorPowOp<half> { - TensorPowOp(half v) : val(v) {} - - __device__ __forceinline__ void operator()(half* out, half* in) { - // No fp16 pow function yet - float fin = __half2float(*in); - float fval 
= __half2float(val); - float fout = powf(fin, fval); - *out = __float2half(fout); - } - - __device__ __forceinline__ void operator()(half* v) { - // No fp16 pow function yet - float fv = __half2float(*v); - float fval = __half2float(val); - float fout = powf(fv, fval); - *v = __float2half(fout); - } - - const half val; -}; -#endif // CUDA_HALF_TENSOR - template<typename T> struct TensorTPowOp { TensorTPowOp(T v) : val(v) {} diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index c9b4f8c..a2c714a 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -166,14 +166,60 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); if (self_ == src) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp<real>(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); + if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(1))) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 1>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(2))) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 2>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(3))) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 3>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-1))) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -1>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-2))) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -2>(value))) { + 
THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + // fallback implementation using pow + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -3>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } } } else { THCTensor_(resizeAs)(state, self_, src); - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real>(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); + if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(1))) { + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 1>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(2))) { + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 2>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(3))) { + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 3>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-1))) { + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -1>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-2))) { + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -2>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + // fallback implementation using pow + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -3>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } } } |