diff options
Diffstat (limited to 'lib/THC/generic/THCTensorMathReduce.cu')
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 90cc3e3..0bdd540 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -203,29 +203,29 @@ THCTensor_(min)(THCState *state, #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) +THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) { THAssert(THCTensor_(checkGPU)(state, 2, self, src)); - if (value == 0.0) { + if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) { THC_reduceDim(state, self, src, - TensorNonZeroOp<real>(), thrust::plus<real>(), - 0.0, dimension); - } else if (value == 1.0) { + TensorNonZeroOp<real>(), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) { THC_reduceDim(state, self, src, - TensorNormOp<real, 1>(), thrust::plus<real>(), - 0.0, dimension); + TensorNormOp<real, 1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); - } else if (value == 2.0) { + } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) { THC_reduceDim(state, self, src, - TensorNormOp<real, 2>(), thrust::plus<real>(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 0.5); + TensorNormOp<real, 2>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5)); } else { THC_reduceDim(state, self, src, - TensorNormOp<real, -1>(), thrust::plus<real>(), - 0.0, dimension); - THCTensor_(pow)(state, self, self, 1.0 / value); + TensorNormOp<real, -1>(value), ReduceAdd<real, real>(), + ScalarConvert<float, real>::to(0.0), dimension); + THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value)); } THCudaCheck(cudaGetLastError()); |