diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-09-30 00:15:08 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
commit | aa39a6cd8aa0f2078f85618727443e5456815900 (patch) | |
tree | 3c65afa4897b52d47a64558a5f87c1c8f4c0de51 /lib/THC | |
parent | 498d644156131b1ea60294175381d48c7a904c65 (diff) |
[cutorch refactor] move varall into generic
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCNumerics.cuh | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 34 | ||||
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 27 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 30 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 2 |
6 files changed, 60 insertions, 35 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index 543a544..c718ba5 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -91,6 +91,7 @@ struct THCNumerics<long> { static inline __host__ __device__ long add(long a, long b) { return a + b; } static inline __host__ __device__ long abs(long a) { return labs(a); } + static inline __host__ __device__ long div(long a, long b) { return a / b; }; }; #ifdef CUDA_HALF_TENSOR diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 099db80..8ee93d9 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -72,7 +72,6 @@ THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *r THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension); THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension); -THC_API float THCudaTensor_varall(THCState *state, THCudaTensor *self); THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag); THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self); THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 491bf6a..a9dc277 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -125,40 +125,6 @@ void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, T THCudaCheck(cudaGetLastError()); } -struct square_functor -{ - const float mean; - - square_functor(float mean_) : mean(mean_) {} - - __host__ __device__ float operator()(const float& x) const - { - return (x-mean)*(x-mean); - } -}; - -float THCudaTensor_varall(THCState *state, THCudaTensor *self) -{ - THAssert(THCudaTensor_checkGPU(state, 1, self)); - self = THCudaTensor_newContiguous(state, self); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr<float> self_data(THCudaTensor_data(state, self)); - - float mean = THCudaTensor_meanall(state, self); - float result = - thrust::transform_reduce( -#if CUDA_VERSION >= 7000 - thrust::cuda::par.on(THCState_getCurrentStream(state)), -#endif - self_data, self_data+size, square_functor(mean), - (float)0, thrust::plus<float>()); - - result = result/(THCudaTensor_nElement(state, self)-1); - - THCudaTensor_free(state, self); - return result; -} - float THCudaTensor_stdall(THCState *state, THCudaTensor *self) { THAssert(THCudaTensor_checkGPU(state, 1, self)); diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index 549909c..df8e290 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -68,6 +68,33 @@ struct ReduceMultiply<half, float> { }; #endif // CUDA_HALF_TENSOR +template <typename ResT, typename ArgT> +struct SquareFunctor { + SquareFunctor(ResT mean): mean_(mean) {} + + inline __device__ ResT operator()(ArgT x) const { + return (((ResT) x) - mean_) * (((ResT) x) - mean_); + } + + const ResT mean_; +}; + +#ifdef CUDA_HALF_TENSOR +template <typename ResT> +struct SquareFunctor<ResT, half> { + SquareFunctor(ResT mean): mean_(mean) {} + + inline __device__ ResT operator()(half x) const { + return THCNumerics<ResT>::mul( + THCNumerics<ResT>::sub(mean_, ScalarConvert<half, ResT>::to(x)), + THCNumerics<ResT>::sub(mean_, ScalarConvert<half, ResT>::to(x)) + ); + } + + const ResT mean_; +}; +#endif // CUDA_HALF_TENSOR + template <typename T> struct ReduceMin { inline __device__ T operator()(T a, T b) const { diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index bb34971..4103d48 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -89,6 +89,31 @@ void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dim THCTensor_(freeCopyTo)(state, self, self_); } +THC_API accreal +THCTensor_(varall)(THCState *state, THCTensor *self) +{ + THAssert(THCTensor_(checkGPU)(state, 1, self)); + accreal mean = THCTensor_(meanall)(state, self); + + accreal val; + if (!THC_reduceAll(state, self, + SquareFunctor<accreal, real>(mean), + ReduceAdd<accreal, accreal>(), + ReduceAdd<accreal, accreal>(), + ScalarConvert<int, accreal>::to(0), + &val, 0)) { + THArgCheck(false, 1, CUTORCH_DIM_WARNING); + } + + val = THCNumerics<accreal>::div( + val, + ScalarConvert<int, accreal>::to(THCTensor_(nElement)(state, self) - 1) + ); + + THCudaCheck(cudaGetLastError()); + return val; +} + #endif THC_API accreal @@ -121,6 +146,11 @@ THCTensor_(prodall)(THCState *state, THCTensor *self) { THArgCheck(false, 1, CUTORCH_DIM_WARNING); } + val = THCNumerics<accreal>::div( + val, + ScalarConvert<long, accreal>::to(THCTensor_(nElement)(state, self)) - 1 + ); + THCudaCheck(cudaGetLastError()); return val; } diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index e106dd6..3aefb17 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -10,6 +10,8 @@ THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value); +THC_API accreal THCTensor_(varall)(THCState *state, THCTensor *self); + #endif THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim); |