diff options
author | Francisco Massa <fvsmassa@gmail.com> | 2017-09-05 22:53:58 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-09-10 20:50:50 +0300 |
commit | f73bb16bc8e7f2d17d09a149fa00b05eada822d7 (patch) | |
tree | b2d9aa1eeac53d9b1a82c9fba86707b23ca16080 | |
parent | 76a65293437cdac53497180aa8b89a057cc73b6a (diff) |
Optimize pow for different exponents and add tests
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 43cbf83..1ed4ee2 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -2856,13 +2856,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=) TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ } \ -#define LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(NAME, CFUNC) \ - void THTensor_(NAME)(THTensor *r_, THTensor *t, real value) \ - { \ - THTensor_(resizeAs)(r_, t); \ - TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data, value);); \ - } \ - #if defined(TH_REAL_IS_LONG) LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs) LAB_IMPLEMENT_BASIC_FUNCTION(neg,-) @@ -2912,7 +2905,6 @@ LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh)) LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan)) LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan)) LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh)) -LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,TH_MATH_NAME(pow)) LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt)) LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt)) LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil)) @@ -2925,6 +2917,35 @@ LAB_IMPLEMENT_BASIC_FUNCTION(neg,-) LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / ) +void THTensor_(pow)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + if(value == 1){ + THTensor_(copy)(r_, t); + } + else if(value == 2){ + THTensor_(cmul)(r_, t, t); + } + else if(value == 3){ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = *t_data * *t_data * *t_data;); + } + else if(value == 0.5){ + THTensor_(sqrt)(r_, t); + } + else if(value == -0.5){ + THTensor_(rsqrt)(r_, t); + } + else if(value == -1){ + THTensor_(cinv)(r_, t); + } + else if(value == -2){ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data);); + } + else{ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(pow)(*t_data, value);); + } +} + void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty) { THTensor_(resizeAs)(r_, tx); |