Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2017-09-05 22:53:58 +0300
committerSoumith Chintala <soumith@gmail.com>2017-09-10 20:50:50 +0300
commitf73bb16bc8e7f2d17d09a149fa00b05eada822d7 (patch)
treeb2d9aa1eeac53d9b1a82c9fba86707b23ca16080
parent76a65293437cdac53497180aa8b89a057cc73b6a (diff)
Optimize pow for different exponents and add tests
-rw-r--r--lib/TH/generic/THTensorMath.c37
1 files changed, 29 insertions, 8 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index 43cbf83..1ed4ee2 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -2856,13 +2856,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \
} \
-#define LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(NAME, CFUNC) \
- void THTensor_(NAME)(THTensor *r_, THTensor *t, real value) \
- { \
- THTensor_(resizeAs)(r_, t); \
- TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data, value);); \
- } \
-
#if defined(TH_REAL_IS_LONG)
LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs)
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
@@ -2912,7 +2905,6 @@ LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh))
LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan))
LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan))
LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh))
-LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,TH_MATH_NAME(pow))
LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt))
LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt))
LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil))
@@ -2925,6 +2917,35 @@ LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
+void THTensor_(pow)(THTensor *r_, THTensor *t, real value)
+{
+ THTensor_(resizeAs)(r_, t);
+ if(value == 1){
+ THTensor_(copy)(r_, t);
+ }
+ else if(value == 2){
+ THTensor_(cmul)(r_, t, t);
+ }
+ else if(value == 3){
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = *t_data * *t_data * *t_data;);
+ }
+ else if(value == 0.5){
+ THTensor_(sqrt)(r_, t);
+ }
+ else if(value == -0.5){
+ THTensor_(rsqrt)(r_, t);
+ }
+ else if(value == -1){
+ THTensor_(cinv)(r_, t);
+ }
+ else if(value == -2){
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data););
+ }
+ else{
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(pow)(*t_data, value););
+ }
+}
+
void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty)
{
THTensor_(resizeAs)(r_, tx);