diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-09-28 19:19:21 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:27 +0300 |
commit | 681c2db9676c98be96c6c2becaea277ee3293604 (patch) | |
tree | 470317545e24b26fffbe1b6c1f5e679e97c7052a | |
parent | b490d8d22ae5cd900001346239a33ba6cde1f3e0 (diff) |
[cutorch refactor] move renorm function into generic
-rw-r--r-- | lib/THC/THCNumerics.cuh | 6 |
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 |
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 28 |
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 32 |
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 6 |
5 files changed, 38 insertions, 35 deletions
```diff
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index 7513764..543a544 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -527,10 +527,7 @@ struct THCNumerics<float> {
   static inline __host__ __device__ float div (float a, float b) { return a / b; }
   static inline __host__ __device__ float mul (float a, float b) { return a * b; }
   static inline __host__ __device__ float sub (float a, float b) { return a - b; }
-<<<<<<< 428317be4d5d9423ccb6ac2701c55baaffb8737b
   static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); }
-=======
->>>>>>> [cutorch refactor] move std function into generic
 };
 
 template <>
@@ -573,10 +570,7 @@ struct THCNumerics<double> {
   static inline __host__ __device__ double div (double a, double b) { return a / b; }
   static inline __host__ __device__ double mul (double a, double b) { return a * b; }
   static inline __host__ __device__ double sub (double a, double b) { return a - b; }
-<<<<<<< 428317be4d5d9423ccb6ac2701c55baaffb8737b
   static inline __host__ __device__ double pow (double a, double b) { return pow(a, b); }
-=======
->>>>>>> [cutorch refactor] move std function into generic
 };
 
 /// `half` has some type conversion issues associated with it, since it
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 439e2e1..544922a 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -78,7 +78,6 @@ THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self);
 THC_API void THCudaTensor_std(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
 THC_API float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value);
 THC_API void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension);
-THC_API void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float max_norm);
 THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
 
 THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index d0d47ad..23346dd 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -336,34 +336,6 @@ void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, f
   THCudaCheck(cudaGetLastError());
 }
 
-void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float maxnorm)
-{
-  THAssert(THCudaTensor_checkGPU(state, 2, self, src));
-  THCudaTensor *self_;
-  THCudaTensor *src_ = THCudaTensor_newTranspose(state, src, dimension, 0);
-  THCudaTensor *data = THCudaTensor_newClone(state, src_);
-  long size = THCudaTensor_nElement(state, data)/data->size[0];
-
-  THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "invalid dimension");
-  THArgCheck(value > 0, 2, "non-positive-norm not supported");
-  THArgCheck(THCudaTensor_nDimension(state, src) > 1, 1, "need at least 2 dimensions");
-
-  dim3 grid(data->size[0]);
-  dim3 threads(32);
-
-  THCTensor_kernel_renorm<float><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
-
-  cudaError errcode = cudaGetLastError();
-  if(errcode != cudaSuccess)
-    THError(cudaGetErrorString(errcode));
-
-  THCudaTensor_free(state, src_);
-  self_ = THCudaTensor_newTranspose(state, data, dimension, 0);
-  THCudaTensor_resizeAs(state, self, self_);
-  THCudaTensor_freeCopyTo(state, self_, self);
-  THCudaTensor_free(state, data);
-}
-
 struct dist_functor
 {
   const float exponent;
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 4ae3c0c..b16d4a5 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -38,6 +38,38 @@ THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim)
   THCTensor_(div)(state, self, self, ScalarConvert<long, real>::to(THCTensor_(size)(state, src, dim)));
 }
 
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
+
+void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real maxnorm)
+{
+  THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+  THCTensor *self_;
+  THCTensor *src_ = THCTensor_(newTranspose)(state, src, dimension, 0);
+  THCTensor *data = THCTensor_(newClone)(state, src_);
+  long size = THCTensor_(nElement)(state, data)/data->size[0];
+
+  THArgCheck(dimension >= 0 && dimension < THCTensor_(nDimension)(state, src), 3, "invalid dimension");
+  THArgCheck(THCNumerics<real>::gt(value, ScalarConvert<int, real>::to(0)), 2, "non-positive-norm not supported");
+  THArgCheck(THCTensor_(nDimension)(state, src) > 1, 1, "need at least 2 dimensions");
+
+  dim3 grid(data->size[0]);
+  dim3 threads(32);
+
+  THCTensor_kernel_renorm<real><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCTensor_(data)(state, data), value, size, maxnorm);
+
+  cudaError errcode = cudaGetLastError();
+  if(errcode != cudaSuccess)
+    THError(cudaGetErrorString(errcode));
+
+  THCTensor_(free)(state, src_);
+  self_ = THCTensor_(newTranspose)(state, data, dimension, 0);
+  THCTensor_(resizeAs)(state, self, self_);
+  THCTensor_(freeCopyTo)(state, self_, self);
+  THCTensor_(free)(state, data);
+}
+
+#endif
+
 THC_API accreal
 THCTensor_(sumall)(THCState *state, THCTensor *self) {
   THAssert(THCTensor_(checkGPU)(state, 1, self));
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index 500003f..507850d 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -2,6 +2,12 @@
 #define THC_GENERIC_FILE "generic/THCTensorMathReduce.h"
 #else
 
+#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
+
+THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm);
+
+#endif
+
 THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
 THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
 THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim);
```