diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-09-27 20:21:20 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
commit | d385fb68798d4f5258980a9b3c4ccbc75faeacfe (patch) | |
tree | 04db6e5da5f265a7e383c9714028c002890fba8b /lib/THC | |
parent | 104ef5a9cdfa15f28643d367dfb3201188f67a23 (diff) |
Make _norm(...)'s ops generic
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 40 |
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 42 |
2 files changed, 50 insertions, 32 deletions
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 1462378..04ce041 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -216,30 +216,6 @@ void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, l THCudaTensor_freeCopyTo(state, self, self_); } -template <int StaticExp> -struct TensorNormOp -{ - TensorNormOp(float exp) : exponent(exp) {} - - __host__ __device__ float operator()(float x) const { - if (StaticExp == 1) { - return fabsf(x); - } else if (StaticExp == 2) { - return x * x; - } else { - return powf(fabsf(x), exponent); - } - } - - const float exponent; -}; - -struct TensorNonZeroOp -{ - TensorNonZeroOp() {} - __host__ __device__ bool operator()(float lhs) const { return lhs != 0.0f; } -}; - float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) { THAssert(THCudaTensor_checkGPU(state, 1, self)); @@ -254,14 +230,14 @@ float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) #if CUDA_VERSION >= 7000 thrust::cuda::par.on(THCState_getCurrentStream(state)), #endif - self_data, self_data+size, TensorNonZeroOp(), + self_data, self_data+size, TensorNonZeroOp<float>(), 0.0f, thrust::plus<float>()); } else if (value == 1.0f) { result = thrust::transform_reduce( #if CUDA_VERSION >= 7000 thrust::cuda::par.on(THCState_getCurrentStream(state)), #endif - self_data, self_data+size, TensorNormOp<1>(value), + self_data, self_data+size, TensorNormOp<float, 1>(value), 0.0f, thrust::plus<float>()); } else if (value == 2.0f) { @@ -269,7 +245,7 @@ float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) #if CUDA_VERSION >= 7000 thrust::cuda::par.on(THCState_getCurrentStream(state)), #endif - self_data, self_data+size, TensorNormOp<2>(value), + self_data, self_data+size, TensorNormOp<float, 2>(value), 0.0f, thrust::plus<float>()); result = powf(result, 0.5f); @@ -278,7 +254,7 @@ float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) 
#if CUDA_VERSION >= 7000 thrust::cuda::par.on(THCState_getCurrentStream(state)), #endif - self_data, self_data+size, TensorNormOp<-1>(value), + self_data, self_data+size, TensorNormOp<float, -1>(value), 0.0f, thrust::plus<float>()); result = powf(result, 1.0f / value); } @@ -292,22 +268,22 @@ void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, f THAssert(THCudaTensor_checkGPU(state, 2, self, src)); if (value == 0.0f) { THC_reduceDim(state, self, src, - TensorNonZeroOp(), thrust::plus<float>(), + TensorNonZeroOp<float>(), thrust::plus<float>(), 0.0f, dimension); } else if (value == 1.0f) { THC_reduceDim(state, self, src, - TensorNormOp<1>(value), thrust::plus<float>(), + TensorNormOp<float, 1>(value), thrust::plus<float>(), 0.0f, dimension); } else if (value == 2.0f) { THC_reduceDim(state, self, src, - TensorNormOp<2>(value), thrust::plus<float>(), + TensorNormOp<float, 2>(value), thrust::plus<float>(), 0.0f, dimension); THCudaTensor_pow(state, self, self, 0.5f); } else { THC_reduceDim(state, self, src, - TensorNormOp<-1>(value), thrust::plus<float>(), + TensorNormOp<float, -1>(value), thrust::plus<float>(), 0.0f, dimension); THCudaTensor_pow(state, self, self, 1.0f / value); } diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index 15cb314..c5509c1 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -143,6 +143,48 @@ __global__ void THCTensor_kernel_renorm(Real *data, const Real value, const long } } +template <typename T> +struct TensorNonZeroOp +{ + TensorNonZeroOp() {} + __host__ __device__ bool operator()(T lhs) const { return lhs != 0.0; } +}; + +template <typename T, int StaticExp> +struct TensorNormOp +{ + TensorNormOp(T exp) : exponent(exp) {} + + __host__ __device__ float operator()(T x) const { + if (StaticExp == 1) { + return (T) fabsf((float) x); + } else if (StaticExp == 2) { + return x * x; + } else { + return (T) powf(fabsf((float) x), (float) exponent); + } + } 
+ + const T exponent; +}; + +template <int StaticExp> +struct TensorNormOp<double, StaticExp> +{ + TensorNormOp(double exp) : exponent(exp) {} + + __host__ __device__ float operator()(double x) const { + if (StaticExp == 1) { + return fabs(x); + } else if (StaticExp == 2) { + return x * x; + } else { + return pow(fabs(x), exponent); + } + } + + const double exponent; +}; #include <thrust/functional.h> |