diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-09-28 17:58:56 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
commit | 2d31dbf0074fe16c6612a7a1ee144096f48a3917 (patch) | |
tree | d76d2a6bf325f4fc21bef41d5974f9a9083b02ac | |
parent | f8cb97a490d605ed19d15edafe23782ccf2ef96c (diff) |
[cutorch refactor] move std function into generic
-rw-r--r-- | lib/THC/THCNumerics.cuh | 45 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 25 | ||||
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 61 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 25 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 6 |
6 files changed, 114 insertions, 49 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index a9d7897..36ed0c8 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -435,6 +435,45 @@ struct THCNumerics<half> { return THC_float2half(THC_half2float(a) + THC_half2float(b)); #endif } + + static inline __host__ __device__ half div(half a, half b) { +#ifdef __CUDA_ARCH__ + float fa = __half2float(a); + float fb = __half2float(b); + return __float2half( fa / fb ); +#else // __CUDA_ARCH__ + return THC_float2half(THC_half2float(a) / THC_half2float(b)); +#endif + } + + static inline __host__ __device__ half mul(half a, half b) { +#ifdef __CUDA_ARCH__ +#ifdef CUDA_HALF_INSTRUCTIONS + return __hmul(a, b); +#else + float fa = __half2float(a); + float fb = __half2float(b); + return __float2half( fa * fb ); +#endif +#else // __CUDA_ARCH__ + return THC_float2half(THC_half2float(a) * THC_half2float(b)); +#endif + } + + static inline __host__ __device__ half sub(half a, half b) { +#ifdef __CUDA_ARCH__ +#ifdef CUDA_HALF_INSTRUCTIONS + return __hsub(a, b); +#else + float fa = __half2float(a); + float fb = __half2float(b); + return __float2half( fa - fb ); +#endif +#else // __CUDA_ARCH__ + return THC_float2half(THC_half2float(a) - THC_half2float(b)); +#endif + } + }; #endif @@ -475,6 +514,9 @@ struct THCNumerics<float> { static inline __host__ __device__ float frac (float a) { return a - truncf(a); } static inline __host__ __device__ float cinv (float a) { return 1.0f / a; } static inline __host__ __device__ float add (float a, float b) { return a + b; } + static inline __host__ __device__ float div (float a, float b) { return a / b; } + static inline __host__ __device__ float mul (float a, float b) { return a * b; } + static inline __host__ __device__ float sub (float a, float b) { return a - b; } }; template <> @@ -514,6 +556,9 @@ struct THCNumerics<double> { static inline __host__ __device__ double frac (double a) { return a - ::trunc(a); } static inline __host__ __device__ double cinv (double a) { return 1.0 / a; } static inline __host__ __device__ double add (double a, double b) { return a + b; } + static inline __host__ __device__ double div (double a, double b) { return a / b; } + static inline __host__ __device__ double mul (double a, double b) { return a * b; } + static inline __host__ __device__ double sub (double a, double b) { return a - b; } }; /// `half` has some type conversion issues associated with it, since it diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 439e2e1..9b70b01 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -75,7 +75,6 @@ THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCuda THC_API float THCudaTensor_varall(THCState *state, THCudaTensor *self); THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag); THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self); -THC_API void THCudaTensor_std(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag); THC_API float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value); THC_API void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension); THC_API void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float max_norm); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index ad1244f..7fd11ff 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -207,30 +207,9 @@ void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, l src = THCudaTensor_newContiguous(state, src); if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCTensor_varInnermostDim<THCudaTensor, false>(state, self, src, flag); + THCTensor_varInnermostDim<THCudaTensor, float, false>(state, self, src, flag); } else { - THCTensor_varOuterDim<THCudaTensor, false>(state, self, src, dimension, flag); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, self_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCTensor_varInnermostDim<THCudaTensor, true>(state, self, src, flag); - } else { - THCTensor_varOuterDim<THCudaTensor, true>(state, self, src, dimension, flag); + THCTensor_varOuterDim<THCudaTensor, float, false>(state, self, src, dimension, flag); } THCudaTensor_free(state, src); diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index a7ab306..3bc0837 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -101,20 +101,26 @@ struct LogicalAny { // Given the sum of values and the sum of squares, compute the variance or standard deviation. template<typename Real, bool flag, bool apply_sqrt> __forceinline__ __device__ Real THCTensor_computeVar(Real sum, Real sum2, unsigned row_size) { + Real rs2 = ScalarConvert<unsigned, Real>::to(row_size); + Real rs2m = ScalarConvert<unsigned, Real>::to(row_size - 1); + Real zero = ScalarConvert<int, Real>::to(0); if (flag) { - sum /= row_size; - sum2 /= row_size; - sum2 -= sum * sum; - sum2 = (sum2 < 0 ? 0 : sum2); + sum = THCNumerics<Real>::div(sum, rs2); + sum2 = THCNumerics<Real>::div(sum2, rs2); + sum2 = THCNumerics<Real>::sub(sum2, THCNumerics<Real>::mul(sum, sum)); + sum2 = (THCNumerics<Real>::lt(sum2, zero) ? zero : sum2); } else { - sum /= row_size; - sum2 /= row_size - 1; - sum2 -= ((Real)row_size) / ((Real)(row_size - 1)) * sum * sum; - sum2 = (sum2 < 0 ? 0 : sum2); + sum = THCNumerics<Real>::div(sum, rs2); + sum2 = THCNumerics<Real>::div(sum2, rs2m); + sum2 = THCNumerics<Real>::sub(sum2, + THCNumerics<Real>::mul( + THCNumerics<Real>::div(rs2 ,rs2m), + THCNumerics<Real>::mul(sum, sum))); + sum2 = (THCNumerics<Real>::lt(sum2, zero) ? zero : sum2); } if (apply_sqrt) - return sqrt(sum2); + return THCNumerics<Real>::sqrt(sum2); else return sum2; } @@ -138,12 +144,15 @@ __global__ void THCTensor_kernel_varOuterDim(Real *tgt, Real *src_, unsigned num for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { Real *src = src_ + orow * row_size * num_irows + irow; - Real sum = 0, sum2 = 0; + Real sum = ScalarConvert<int, Real>::to(0), sum2 = ScalarConvert<int, Real>::to(0); for (unsigned col = 0; col < row_size; ++col) { Real val = *src; - sum += val; - sum2 += val * val; + sum = THCNumerics<Real>::add(sum, val); + sum2 = THCNumerics<Real>::add( + sum2, + THCNumerics<Real>::mul(val, val) + ); src += num_irows; } @@ -153,7 +162,7 @@ __global__ void THCTensor_kernel_varOuterDim(Real *tgt, Real *src_, unsigned num } } -template<typename TensorTypeK, bool apply_sqrt> +template<typename TensorTypeK, typename Real, bool apply_sqrt> __host__ void THCTensor_varOuterDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, long dimension, int flag) { unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src); @@ -174,10 +183,10 @@ __host__ void THCTensor_varOuterDim(THCState *state, TensorTypeK *tgt, TensorTyp dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x))); if (flag) { - THCTensor_kernel_varOuterDim<float, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( + THCTensor_kernel_varOuterDim<Real, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size); } else { - THCTensor_kernel_varOuterDim<float, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( + THCTensor_kernel_varOuterDim<Real, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size); } cudaError errcode = cudaGetLastError(); @@ -206,14 +215,14 @@ __global__ void THCTensor_kernel_varInnermostDim(Real *tgt, Real *src_, unsigned for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { unsigned row = block_row + threadIdx.y; - Real sum = 0, sum2 = 0; + Real sum = ScalarConvert<int, Real>::to(0), sum2 = ScalarConvert<int, Real>::to(0); if (row < num_rows) { Real *src = src_ + row * row_size; // Sequential reduction within a thread. for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) { Real val = src[col]; - sum += val; - sum2 += val * val; + sum = THCNumerics<Real>::add(sum, val); + sum2 = THCNumerics<Real>::add(sum2, THCNumerics<Real>::mul(val, val)); } } ssum[threadIdx.y][threadIdx.x] = sum; @@ -223,22 +232,24 @@ __global__ void THCTensor_kernel_varInnermostDim(Real *tgt, Real *src_, unsigned // Reduce intermediate values to single value. for (unsigned s = 8; s > 1; s >>= 1) { if (row < num_rows && threadIdx.x < s) { - ssum[threadIdx.y][threadIdx.x] += ssum[threadIdx.y][threadIdx.x + s]; - ssum2[threadIdx.y][threadIdx.x] += ssum2[threadIdx.y][threadIdx.x + s]; + ssum[threadIdx.y][threadIdx.x] = + THCNumerics<Real>::add(ssum[threadIdx.y][threadIdx.x], ssum[threadIdx.y][threadIdx.x + s]); + ssum2[threadIdx.y][threadIdx.x] = + THCNumerics<Real>::add(ssum2[threadIdx.y][threadIdx.x], ssum2[threadIdx.y][threadIdx.x + s]); } __syncthreads(); } if (row < num_rows && threadIdx.x == 0) { - sum = ssum[threadIdx.y][0] + ssum[threadIdx.y][1]; - sum2 = ssum2[threadIdx.y][0] + ssum2[threadIdx.y][1]; + sum = THCNumerics<Real>::add(ssum[threadIdx.y][0], ssum[threadIdx.y][1]); + sum2 = THCNumerics<Real>::add(ssum2[threadIdx.y][0], ssum2[threadIdx.y][1]); tgt[row] = THCTensor_computeVar<Real, flag, apply_sqrt>(sum, sum2, row_size); } __syncthreads(); } } -template<typename TensorTypeK, bool apply_sqrt> +template<typename TensorTypeK, typename Real, bool apply_sqrt> __host__ void THCTensor_varInnermostDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, int flag) { unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src); @@ -254,10 +265,10 @@ __host__ void THCTensor_varInnermostDim(THCState *state, TensorTypeK *tgt, Tenso dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y))); if (flag) { - THCTensor_kernel_varInnermostDim<float, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( + THCTensor_kernel_varInnermostDim<Real, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_rows, row_size); } else { - THCTensor_kernel_varInnermostDim<float, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( + THCTensor_kernel_varInnermostDim<Real, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>( TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_rows, row_size); } cudaError errcode = cudaGetLastError(); diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 4ae3c0c..8f6b9c4 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -38,6 +38,31 @@ THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim) THCTensor_(div)(state, self, self, ScalarConvert<long, real>::to(THCTensor_(size)(state, src, dim))); } +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) + +void THCTensor_std(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); + THLongStorage *dim = THCTensor_(newSizeOf)(state, src); + THLongStorage_set(dim, dimension, 1); + THCTensor_(resize)(state, self_, dim, NULL); + THLongStorage_free(dim); + + THCTensor *self = THCTensor_(newContiguous)(state, self_); + src = THCTensor_(newContiguous)(state, src); + + if (dimension == THCTensor_(nDimension)(state, src) - 1) { + THCTensor_varInnermostDim<THCTensor, real, true>(state, self, src, flag); + } else { + THCTensor_varOuterDim<THCTensor, real, true>(state, self, src, dimension, flag); + } + + THCTensor_(free)(state, src); + THCTensor_(freeCopyTo)(state, self, self_); +} + +#endif + THC_API accreal THCTensor_(sumall)(THCState *state, THCTensor *self) { THAssert(THCTensor_(checkGPU)(state, 1, self)); diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index 500003f..bc37f85 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -2,6 +2,12 @@ #define THC_GENERIC_FILE "generic/THCTensorMathReduce.h" #else +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) + +THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); + +#endif + THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim); THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim); THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim); |