diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-03 23:32:00 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | 12076f677505257ac945ea0b092cb54e42ccecaa (patch) | |
tree | 3626bdee91bc86bcd6fbda73389e4e543656df53 | |
parent | 167fb1025c65e25c25e78ee8017d52f1f29d349c (diff) |
[cutorch refactor] cleanup code in prep for review
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 1 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 159 | ||||
-rw-r--r-- | test/test.lua | 2 |
3 files changed, 80 insertions, 82 deletions
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index df8e290..4b66ac3 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -17,7 +17,6 @@ struct ReduceAdd { }; #ifdef CUDA_HALF_TENSOR - template <> struct ReduceAdd<half, half> { inline __device__ half operator()(half a, half b) const { diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 1b86027..5e1ea00 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -40,7 +40,8 @@ THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim) #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real maxnorm) +THC_API void +THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real maxnorm) { THAssert(THCTensor_(checkGPU)(state, 2, self, src)); THCTensor *self_; @@ -68,7 +69,8 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real v THCTensor_(free)(state, data); } -void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag) +THC_API void +THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag) { THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); THLongStorage *dim = THCTensor_(newSizeOf)(state, src); @@ -89,7 +91,8 @@ void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dim THCTensor_(freeCopyTo)(state, self, self_); } -accreal THCTensor_(stdall)(THCState *state, THCTensor *self) +THC_API accreal +THCTensor_(stdall)(THCState *state, THCTensor *self) { THAssert(THCTensor_(checkGPU)(state, 1, self)); return THCNumerics<accreal>::sqrt((THCTensor_(varall)(state, self))); @@ -120,6 +123,80 @@ THCTensor_(varall)(THCState *state, THCTensor *self) return val; } +THC_API void +THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self, src)); + if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { + THC_reduceDim(state, self, src, + TensorNonZeroOp<real>(), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { + THC_reduceDim(state, self, src, + TensorNormOp<real, 1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { + THC_reduceDim(state, self, src, + TensorNormOp<real, 2>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5)); + + } else { + THC_reduceDim(state, self, src, + TensorNormOp<real, -1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value)); + } + + THCudaCheck(cudaGetLastError()); +} + +THC_API accreal +THCTensor_(normall)(THCState *state, THCTensor *self, real value) +{ + THAssert(THCTensor_(checkGPU)(state, 1, self)); + accreal result; + + if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { + THC_reduceAll(state, self, + TensorNonZeroOp<real>(), + ReduceAdd<real, accreal>(), + ReduceAdd<accreal, accreal>(), + ScalarConvert<float, accreal>::to(0.0f), + &result, 0); + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { + THC_reduceAll(state, self, + TensorNormOp<real, 1>(value), + ReduceAdd<real, accreal>(), + ReduceAdd<accreal, accreal>(), + ScalarConvert<float, accreal>::to(0.0f), + &result, 0); + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { + THC_reduceAll(state, self, + TensorNormOp<real, 2>(value), + ReduceAdd<real, accreal>(), + ReduceAdd<accreal, accreal>(), + ScalarConvert<float, accreal>::to(0.0f), + &result, 0); + result = THCNumerics<accreal>::sqrt(result); + } else { + THC_reduceAll(state, self, + TensorNormOp<real, -1>(value), + ReduceAdd<real, accreal>(), + ReduceAdd<accreal, accreal>(), + ScalarConvert<float, accreal>::to(0.0f), + &result, 0); + result = THCNumerics<accreal>::pow( + result, + ScalarConvert<real, accreal>::to(THCNumerics<real>::cinv(value)) + ); + } + + THCudaCheck(cudaGetLastError()); + return result; +} + #endif THC_API accreal @@ -237,80 +314,4 @@ THCTensor_(min)(THCState *state, MinValuePair<typename TensorUtils<THCTensor>::DataType, long>()); } -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) - -THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) -{ - THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { - THC_reduceDim(state, self, src, - TensorNonZeroOp<real>(), ReduceAdd<real, real>(), - ScalarConvert<float, real>::to(0.0), dimension); - } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { - THC_reduceDim(state, self, src, - TensorNormOp<real, 1>(value), ReduceAdd<real, real>(), - ScalarConvert<float, real>::to(0.0), dimension); - - } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { - THC_reduceDim(state, self, src, - TensorNormOp<real, 2>(value), ReduceAdd<real, real>(), - ScalarConvert<float, real>::to(0.0), dimension); - THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5)); - - } else { - THC_reduceDim(state, self, src, - TensorNormOp<real, -1>(value), ReduceAdd<real, real>(), - ScalarConvert<float, real>::to(0.0), dimension); - THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value)); - } - - THCudaCheck(cudaGetLastError()); -} - -THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value) -{ - THAssert(THCTensor_(checkGPU)(state, 1, self)); - accreal result; - - if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { - THC_reduceAll(state, self, - TensorNonZeroOp<real>(), - ReduceAdd<real, accreal>(), - ReduceAdd<accreal, accreal>(), - ScalarConvert<float, accreal>::to(0.0f), - &result, 0); - } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { - THC_reduceAll(state, self, - TensorNormOp<real, 1>(value), - ReduceAdd<real, accreal>(), - ReduceAdd<accreal, accreal>(), - ScalarConvert<float, accreal>::to(0.0f), - &result, 0); - } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { - THC_reduceAll(state, self, - TensorNormOp<real, 2>(value), - ReduceAdd<real, accreal>(), - ReduceAdd<accreal, accreal>(), - ScalarConvert<float, accreal>::to(0.0f), - &result, 0); - result = THCNumerics<accreal>::sqrt(result); - } else { - THC_reduceAll(state, self, - TensorNormOp<real, -1>(value), - ReduceAdd<real, accreal>(), - ReduceAdd<accreal, accreal>(), - ScalarConvert<float, accreal>::to(0.0f), - &result, 0); - result = THCNumerics<accreal>::pow( - result, - ScalarConvert<real, accreal>::to(THCNumerics<real>::cinv(value)) - ); - } - - THCudaCheck(cudaGetLastError()); - return result; -} - -#endif - #endif diff --git a/test/test.lua b/test/test.lua index ddf67e4..078dd46 100644 --- a/test/test.lua +++ b/test/test.lua @@ -233,8 +233,6 @@ local function compareFloatAndCuda(x, fn, ...) .. "are different for function '%s'", tostring(fn))) for k, _ in ipairs(rcpu) do if not isEqual(rcpu[k], rcuda[k], tolerance) then - print(string.format("cpu results: %s", tostring(rcpu[k]))) - print(string.format("cuda results: %s", tostring(rcuda[k]))) print(args) tester:assert(false, errstr) end |