From e103982479048babf915b2afcf0012593409b9e5 Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Tue, 27 Sep 2016 10:21:20 -0700 Subject: Make _norm(...)'s ops generic --- lib/THC/generic/THCTensorMath.cu | 8 -------- lib/THC/generic/THCTensorMath.h | 1 - lib/THC/generic/THCTensorMathPointwise.cu | 2 +- lib/THC/generic/THCTensorMathPointwise.h | 2 +- lib/THC/generic/THCTensorMathReduce.cu | 28 ++++++++++++++-------------- lib/THC/generic/THCTensorMathReduce.h | 2 +- 6 files changed, 17 insertions(+), 26 deletions(-) (limited to 'lib/THC/generic') diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 9ffc89b..a0e550a 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -35,14 +35,6 @@ THCTensor_(zero)(THCState *state, THCTensor *self_) THCudaCheck(cudaGetLastError()); } -THC_API void -THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim) -{ - THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - THCudaTensor_sum(state, self, src, dim); - THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim)); -} - THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size) { diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 6b59262..5c9e66d 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -4,7 +4,6 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); -THC_API void THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size); THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size); diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 7cbd00f..b2f8950 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -101,7 +101,7 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { THCudaCheck(cudaGetLastError()); } -void THCTensor_pow(THCState *state, THCTensor *self_, THCTensor *src, real value) { +void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real value) { THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); if (self_ == src) { if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index af1ad1c..af50278 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -18,7 +18,7 @@ THC_API void THCTensor_(tan)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(atan)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(atan2)(THCState *state, THCTensor *r_, THCTensor *tx, THCTensor *ty); THC_API void THCTensor_(tanh)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, float value); +THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, real value); THC_API void THCTensor_(tpow)(THCState *state, THCTensor *self, float value, THCTensor *src); THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(rsqrt)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 90cc3e3..0bdd540 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -203,29 +203,29 @@ THCTensor_(min)(THCState *state, #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) +THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) { THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - if (value == 0.0) { + if (THCNumerics::eq(value, ScalarConvert::to(0.0))) { THC_reduceDim(state, self, src, - TensorNonZeroOp(), thrust::plus(), - 0.0, dimension); - } else if (value == 1.0) { + TensorNonZeroOp(), ReduceAdd(), + ScalarConvert::to(0.0), dimension); + } else if (THCNumerics::eq(value, ScalarConvert::to(1.0))) { THC_reduceDim(state, self, src, - TensorNormOp(), thrust::plus(), - 0.0, dimension); + TensorNormOp(value), ReduceAdd(), + ScalarConvert::to(0.0), dimension); - } else if (value == 2.0) { + } else if (THCNumerics::eq(value, ScalarConvert::to(2.0))) { THC_reduceDim(state, self, src, - TensorNormOp(), thrust::plus(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 0.5); + TensorNormOp(value), ReduceAdd(), + ScalarConvert::to(0.0), dimension); + THCTensor_(pow)(state, self, self, ScalarConvert::to(0.5)); } else { THC_reduceDim(state, self, src, - TensorNormOp(), thrust::plus(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 1.0 / value); + TensorNormOp(value), ReduceAdd(), + ScalarConvert::to(0.0), dimension); + THCTensor_(pow)(state, self, self, THCNumerics::cinv(value)); } THCudaCheck(cudaGetLastError()); diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index d25317e..9e19f52 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -6,7 +6,7 @@ THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm); THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); -THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); +THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); #endif -- cgit v1.2.3