diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-09-29 19:22:46 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
commit | c1dcfb6a54febeb6a45b9066553b621070b55c85 (patch) | |
tree | 52dd4de5c788cef12f21d902e2737f2d74104051 | |
parent | 2d31dbf0074fe16c6612a7a1ee144096f48a3917 (diff) |
[cutorch refactor] make _renorm(...)'s ops generic
-rw-r--r-- | lib/THC/THCNumerics.cuh | 12 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 59 | ||||
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 48 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 25 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 6 |
6 files changed, 83 insertions, 68 deletions
Reviewer notes (text before the first "diff --git" header is not part of any hunk):
  * Line structure restored: one diff line per line. Indentation and blank
    context lines were rebuilt from the hunk line counts and may need to be
    re-synced against the tree before applying.
  * THCNumerics<double>::pow now calls ::pow. Inside the struct an unqualified
    pow(a, b) resolves to the member being defined and would recurse forever.
  * The "index" hashes are kept as published; the one for THCNumerics.cuh no
    longer matches the corrected content.

diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index 36ed0c8..543a544 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -474,6 +474,16 @@ struct THCNumerics<half> {
 #endif
   }
 
+  static inline __host__ __device__ half pow(half a, half b) {
+#ifdef __CUDA_ARCH__
+    float fa = __half2float(a);
+    float fb = __half2float(b);
+    return __float2half(powf(fa, fb));
+#else // __CUDA_ARCH__
+    return THC_float2half(powf(THC_half2float(a), THC_half2float(b)));
+#endif
+  }
+
 };
 
 #endif
@@ -517,6 +527,7 @@ struct THCNumerics<float> {
   static inline __host__ __device__ float div (float a, float b) { return a / b; }
   static inline __host__ __device__ float mul (float a, float b) { return a * b; }
   static inline __host__ __device__ float sub (float a, float b) { return a - b; }
+  static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); }
 };
 
 template <>
@@ -559,6 +570,7 @@ struct THCNumerics<double> {
   static inline __host__ __device__ double div (double a, double b) { return a / b; }
   static inline __host__ __device__ double mul (double a, double b) { return a * b; }
   static inline __host__ __device__ double sub (double a, double b) { return a - b; }
+  static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); } // '::' picks the global pow; unqualified pow would be this member (infinite recursion)
 };
 
 /// `half` has some type conversion issues associated with it, since it
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 9b70b01..439e2e1 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -75,6 +75,7 @@ THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCuda
 THC_API float THCudaTensor_varall(THCState *state, THCudaTensor *self);
 THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
 THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self);
+THC_API void THCudaTensor_std(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
 THC_API float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value);
 THC_API void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension);
 THC_API void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float max_norm);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 7fd11ff..d0d47ad 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -216,6 +216,27 @@ void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, l
   THCudaTensor_freeCopyTo(state, self, self_);
 }
 
+void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
+{
+  THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
+  THLongStorage *dim = THCudaTensor_newSizeOf(state, src);
+  THLongStorage_set(dim, dimension, 1);
+  THCudaTensor_resize(state, self_, dim, NULL);
+  THLongStorage_free(dim);
+
+  THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
+  src = THCudaTensor_newContiguous(state, src);
+
+  if (dimension == THCudaTensor_nDimension(state, src) - 1) {
+    THCTensor_varInnermostDim<THCudaTensor, float, true>(state, self, src, flag);
+  } else {
+    THCTensor_varOuterDim<THCudaTensor, float, true>(state, self, src, dimension, flag);
+  }
+
+  THCudaTensor_free(state, src);
+  THCudaTensor_freeCopyTo(state, self, self_);
+}
+
 template <int StaticExp>
 struct TensorNormOp
 {
@@ -315,42 +336,6 @@ void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, f
   THCudaCheck(cudaGetLastError());
 }
 
-__global__ void THCudaTensor_kernel_renorm(float *data, const float value, const long size, const float maxnorm)
-{
-  __shared__ float buffer[32];
-  long tx = threadIdx.x;
-  long bx = blockIdx.x;
-  long step = blockDim.x;
-  float *row = data + size*bx;
-
-  buffer[tx] = 0;
-
-  // get norm of axis
-  for (long i=tx; i<size; i+=step)
-  {
-    buffer[tx] += pow(fabs(row[i]), value);
-  }
-  // add (reduce)
-  for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
-  {
-    __syncthreads();
-    if (tx < stride)
-      buffer[tx] += buffer[tx+stride];
-  }
-  // clip norms
-  __syncthreads();
-  float norm = pow(buffer[0], 1/value);
-  if (norm > maxnorm)
-  {
-    norm = maxnorm / (norm + 1e-7);
-    // renormalize
-    for (long i=tx; i<size; i+=step)
-    {
-      row[i] *= norm;
-    }
-  }
-}
-
 void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float maxnorm)
 {
   THAssert(THCudaTensor_checkGPU(state, 2, self, src));
@@ -366,7 +351,7 @@ void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src,
   dim3 grid(data->size[0]);
   dim3 threads(32);
 
-  THCudaTensor_kernel_renorm<<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
+  THCTensor_kernel_renorm<float><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
 
   cudaError errcode = cudaGetLastError();
   if(errcode != cudaSuccess)
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh
index 3bc0837..15cb314 100644
--- a/lib/THC/THCTensorMathReduce.cuh
+++ b/lib/THC/THCTensorMathReduce.cuh
@@ -95,6 +95,54 @@ struct LogicalAny {
   }
 };
 
+template<typename Real>
+__global__ void THCTensor_kernel_renorm(Real *data, const Real value, const long size, const Real maxnorm)
+{
+  __shared__ Real buffer[32];
+  long tx = threadIdx.x;
+  long bx = blockIdx.x;
+  long step = blockDim.x;
+  Real *row = data + size*bx;
+
+  buffer[tx] = ScalarConvert<int, Real>::to(0);
+
+  // get norm of axis
+  for (long i=tx; i<size; i+=step)
+  {
+    buffer[tx] = THCNumerics<Real>::add(
+      buffer[tx],
+      THCNumerics<Real>::pow(
+        THCNumerics<Real>::abs(row[i]),
+        value)
+    );
+  }
+  // add (reduce)
+  for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
+  {
+    __syncthreads();
+    if (tx < stride)
+      buffer[tx] = THCNumerics<Real>::add(buffer[tx], buffer[tx+stride]);
+  }
+  // clip norms
+  __syncthreads();
+  Real norm = THCNumerics<Real>::pow(buffer[0], THCNumerics<Real>::cinv(value));
+  if (THCNumerics<Real>::gt(norm, maxnorm))
+  {
+    norm = THCNumerics<Real>::div(
+      maxnorm,
+      THCNumerics<Real>::add(
+        norm,
+        ScalarConvert<float, Real>::to(1e-7)
+      )
+    );
+    // renormalize
+    for (long i=tx; i<size; i+=step)
+    {
+      row[i] = THCNumerics<Real>::mul(row[i], norm);
+    }
+  }
+}
+
 #include <thrust/functional.h>
 
 
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 8f6b9c4..4ae3c0c 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -38,31 +38,6 @@ THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim)
   THCTensor_(div)(state, self, self, ScalarConvert<long, real>::to(THCTensor_(size)(state, src, dim)));
 }
 
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-void THCTensor_std(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
-{
-  THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
-  THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
-  THLongStorage_set(dim, dimension, 1);
-  THCTensor_(resize)(state, self_, dim, NULL);
-  THLongStorage_free(dim);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  src = THCTensor_(newContiguous)(state, src);
-
-  if (dimension == THCTensor_(nDimension)(state, src) - 1) {
-    THCTensor_varInnermostDim<THCTensor, real, true>(state, self, src, flag);
-  } else {
-    THCTensor_varOuterDim<THCTensor, real, true>(state, self, src, dimension, flag);
-  }
-
-  THCTensor_(free)(state, src);
-  THCTensor_(freeCopyTo)(state, self, self_);
-}
-
-#endif
-
 THC_API accreal
 THCTensor_(sumall)(THCState *state, THCTensor *self) {
   THAssert(THCTensor_(checkGPU)(state, 1, self));
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index bc37f85..500003f 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -2,12 +2,6 @@
 #define THC_GENERIC_FILE "generic/THCTensorMathReduce.h"
 #else
 
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag);
-
-#endif
-
 THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
 THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
 THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim);