Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2017-09-05 22:53:58 +0300
committerSoumith Chintala <soumith@gmail.com>2017-09-10 20:50:57 +0300
commitbc651e42cad1c9b9bd6a43af9d4513b29a74f8ca (patch)
treea85c67e617f4468563686196ecdd1340b4e212bc
parent326b5a5224bfef5a8f585e5340054e55d44cb453 (diff)
Optimize pow for different exponents and add tests
-rw-r--r--lib/THC/THCTensorMathPointwise.cuh73
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.cu54
2 files changed, 80 insertions, 47 deletions
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh
index 6ab010a..a37645d 100644
--- a/lib/THC/THCTensorMathPointwise.cuh
+++ b/lib/THC/THCTensorMathPointwise.cuh
@@ -264,60 +264,47 @@ struct TensorMulOp<half> {
};
#endif // CUDA_HALF_TENSOR
-template<typename T>
+template<typename T, int StaticExp>
struct TensorPowOp {
TensorPowOp(T v) : val(v) {}
__device__ __forceinline__ void operator()(T* out, T* in) {
- *out = powf((float) *in, (float) val);
+ if (StaticExp == 1) {
+ *out = *in;
+ } else if (StaticExp == 2) {
+ *out = THCNumerics<T>::mul(*in, *in);
+ } else if (StaticExp == 3) {
+ *out = THCNumerics<T>::mul(*in, *in);
+ *out = THCNumerics<T>::mul(*out, *in);
+ } else if (StaticExp == -1) {
+ *out = THCNumerics<T>::cinv(*in);
+ } else if (StaticExp == -2) {
+ *out = THCNumerics<T>::mul(*in, *in);
+ *out = THCNumerics<T>::cinv(*out);
+ } else {
+ *out = THCNumerics<T>::pow(*in, val);
+ }
}
__device__ __forceinline__ void operator()(T* v) {
- *v = powf((float) *v, (float) val);
+ if (StaticExp == 1) {
+ *v = *v;
+ } else if (StaticExp == 2) {
+ *v = THCNumerics<T>::mul(*v, *v);
+ } else if (StaticExp == 3) {
+ *v = THCNumerics<T>::mul(THCNumerics<T>::mul(*v, *v), *v);
+ } else if (StaticExp == -1) {
+ *v = THCNumerics<T>::cinv(*v);
+ } else if (StaticExp == -2) {
+ *v = THCNumerics<T>::mul(*v, *v);
+ *v = THCNumerics<T>::cinv(*v);
+ } else {
+ *v = THCNumerics<T>::pow(*v, val);
+ }
}
const T val;
};
-template <>
-struct TensorPowOp<double> {
- TensorPowOp(double v) : val(v) {}
-
- __device__ __forceinline__ void operator()(double* out, double* in) {
- *out = pow(*in, val);
- }
-
- __device__ __forceinline__ void operator()(double* v) {
- *v = pow(*v, val);
- }
-
- const double val;
-};
-
-#ifdef CUDA_HALF_TENSOR
-template <>
-struct TensorPowOp<half> {
- TensorPowOp(half v) : val(v) {}
-
- __device__ __forceinline__ void operator()(half* out, half* in) {
- // No fp16 pow function yet
- float fin = __half2float(*in);
- float fval = __half2float(val);
- float fout = powf(fin, fval);
- *out = __float2half(fout);
- }
-
- __device__ __forceinline__ void operator()(half* v) {
- // No fp16 pow function yet
- float fv = __half2float(*v);
- float fval = __half2float(val);
- float fout = powf(fv, fval);
- *v = __float2half(fout);
- }
-
- const half val;
-};
-#endif // CUDA_HALF_TENSOR
-
template<typename T>
struct TensorTPowOp {
TensorTPowOp(T v) : val(v) {}
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index c9b4f8c..a2c714a 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -166,14 +166,60 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) {
void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real value) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
if (self_ == src) {
- if (!THC_pointwiseApply1(state, self_, TensorPowOp<real>(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(1))) {
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 1>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(2))) {
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 2>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(3))) {
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, 3>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-1))) {
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -1>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-2))) {
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -2>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ // fallback implementation using pow
+ if (!THC_pointwiseApply1(state, self_, TensorPowOp<real, -3>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
}
} else {
THCTensor_(resizeAs)(state, self_, src);
- if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real>(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(1))) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 1>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(2))) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 2>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(3))) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, 3>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-1))) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -1>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else if (THCNumerics<real>::eq(value, ScalarConvert<int, real>::to(-2))) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -2>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ // fallback implementation using pow
+ if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real, -3>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
}
}