diff options
| Field | Value | Date |
|---|---|---|
| author | Trevor Killeen <killeentm@gmail.com> | 2016-09-27 20:21:20 +0300 |
| committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
| commit | e103982479048babf915b2afcf0012593409b9e5 (patch) | |
| tree | 3eec281a83f0ac0c8558eebebdc7f372f49791bf /lib/THC | |
| parent | 72123175666a3860a093fc7fa5ea26628b89d379 (diff) | |
Make _norm(...)'s ops generic
Diffstat (limited to 'lib/THC')

| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | lib/THC/THCTensorMath.h | 1 |
| -rw-r--r-- | lib/THC/THCTensorMath2.cu | 29 |
| -rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 33 |
| -rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 8 |
| -rw-r--r-- | lib/THC/generic/THCTensorMath.h | 1 |
| -rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 2 |
| -rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 2 |
| -rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 28 |
| -rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 2 |

9 files changed, 48 insertions, 58 deletions
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index c1e5728..a4b9ed7 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -76,7 +76,6 @@ THC_API float THCudaTensor_varall(THCState *state, THCudaTensor *self); THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag); THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self); THC_API float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value); -THC_API void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension); THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 55c225f..778a0f9 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -6,6 +6,7 @@ #include "THCApply.cuh" #include "THCReduce.cuh" #include "THCTensorMathReduce.cuh" +#include "THCTensorMathPointwise.cuh" #include <thrust/device_ptr.h> #include <thrust/transform_reduce.h> @@ -232,34 +233,6 @@ float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) return result; } -void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self, src)); - if (value == 0.0f) { - THC_reduceDim(state, self, src, - TensorNonZeroOp<float>(), thrust::plus<float>(), - 0.0f, dimension); - } else if (value == 1.0f) { - THC_reduceDim(state, self, src, - TensorNormOp<float, 1>(value), thrust::plus<float>(), - 0.0f, dimension); - - } else if (value == 2.0f) { - THC_reduceDim(state, self, src, - TensorNormOp<float, 2>(value), thrust::plus<float>(), - 0.0f, dimension); - THCudaTensor_pow(state, self, self, 0.5f); - - } else { - THC_reduceDim(state, self, src, - TensorNormOp<float, 
-1>(value), thrust::plus<float>(), - 0.0f, dimension); - THCudaTensor_pow(state, self, self, 1.0f / value); - } - - THCudaCheck(cudaGetLastError()); -} - struct dist_functor { const float exponent; diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index c5509c1..549909c 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -17,6 +17,7 @@ struct ReduceAdd { }; #ifdef CUDA_HALF_TENSOR + template <> struct ReduceAdd<half, half> { inline __device__ half operator()(half a, half b) const { @@ -147,7 +148,13 @@ template <typename T> struct TensorNonZeroOp { TensorNonZeroOp() {} - __host__ __device__ bool operator()(T lhs) const { return lhs != 0.0; } + __host__ __device__ T operator()(T lhs) const { + if (THCNumerics<T>::eq(lhs, ScalarConvert<float, T>::to(0.0))) { + return ScalarConvert<int, T>::to(0); + } else { + return ScalarConvert<int, T>::to(1); + } + } }; template <typename T, int StaticExp> @@ -155,7 +162,7 @@ struct TensorNormOp { TensorNormOp(T exp) : exponent(exp) {} - __host__ __device__ float operator()(T x) const { + __host__ __device__ T operator()(T x) const { if (StaticExp == 1) { return (T) fabsf((float) x); } else if (StaticExp == 2) { @@ -173,7 +180,7 @@ struct TensorNormOp<double, StaticExp> { TensorNormOp(double exp) : exponent(exp) {} - __host__ __device__ float operator()(double x) const { + __host__ __device__ double operator()(double x) const { if (StaticExp == 1) { return fabs(x); } else if (StaticExp == 2) { @@ -186,6 +193,26 @@ struct TensorNormOp<double, StaticExp> const double exponent; }; +#ifdef CUDA_HALF_TENSOR +template <int StaticExp> +struct TensorNormOp<half, StaticExp> +{ + TensorNormOp(half exp) : exponent(exp) {} + + __host__ __device__ half operator()(half x) const { + if (StaticExp == 1) { + return THCNumerics<half>::abs(x); + } else if (StaticExp == 2) { + return THCNumerics<half>::mul(x, x); + } else { + return THCNumerics<half>::pow(THCNumerics<half>::abs(x), 
exponent); + } + } + + const half exponent; +}; +#endif + #include <thrust/functional.h> // Given the sum of values and the sum of squares, compute the variance or standard deviation. diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 9ffc89b..a0e550a 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -36,14 +36,6 @@ THCTensor_(zero)(THCState *state, THCTensor *self_) } THC_API void -THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim) -{ - THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - THCudaTensor_sum(state, self, src, dim); - THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim)); -} - -THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size) { THAssert(THCTensor_(checkGPU)(state, 1, r_)); diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 6b59262..5c9e66d 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -4,7 +4,6 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); -THC_API void THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size); THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size); diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 7cbd00f..b2f8950 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -101,7 +101,7 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { THCudaCheck(cudaGetLastError()); } -void THCTensor_pow(THCState *state, THCTensor *self_, THCTensor *src, real value) { +void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real value) { 
THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); if (self_ == src) { if (!THC_pointwiseApply1(state, self_, TensorPowOp<real>(value))) { diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index af1ad1c..af50278 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -18,7 +18,7 @@ THC_API void THCTensor_(tan)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(atan)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(atan2)(THCState *state, THCTensor *r_, THCTensor *tx, THCTensor *ty); THC_API void THCTensor_(tanh)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, float value); +THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, real value); THC_API void THCTensor_(tpow)(THCState *state, THCTensor *self, float value, THCTensor *src); THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(rsqrt)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 90cc3e3..0bdd540 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -203,29 +203,29 @@ THCTensor_(min)(THCState *state, #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) +THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) { THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - if (value == 0.0) { + if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { THC_reduceDim(state, self, src, - TensorNonZeroOp<real>(), thrust::plus<real>(), - 0.0, dimension); - } else 
if (value == 1.0) { + TensorNonZeroOp<real>(), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { THC_reduceDim(state, self, src, - TensorNormOp<real, 1>(), thrust::plus<real>(), - 0.0, dimension); + TensorNormOp<real, 1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); - } else if (value == 2.0) { + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { THC_reduceDim(state, self, src, - TensorNormOp<real, 2>(), thrust::plus<real>(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 0.5); + TensorNormOp<real, 2>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5)); } else { THC_reduceDim(state, self, src, - TensorNormOp<real, -1>(), thrust::plus<real>(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 1.0 / value); + TensorNormOp<real, -1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value)); } THCudaCheck(cudaGetLastError()); diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index d25317e..9e19f52 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -6,7 +6,7 @@ THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm); THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); -THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); +THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); #endif |