diff options
author | soumith <soumith@fb.com> | 2016-11-03 02:40:54 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-11-03 02:40:54 +0300 |
commit | a06a5e1e1bb35fede112a096030ab96b9526e7ed (patch) | |
tree | b9cd86a9addbe353c1fe4e33216e12c1a4530e8a | |
parent | 45a229db8bcb823217f0c64547418bc58507e91d (diff) |
making dot to have an accreal return type (consistent with CPU)dotfix
-rw-r--r-- | TensorMath.lua | 2 | ||||
-rw-r--r-- | lib/THC/THCBlas.cu | 8 | ||||
-rw-r--r-- | lib/THC/THCBlas.h | 2 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 10 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.h | 2 |
5 files changed, 12 insertions, 12 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 802565e..a5d436f 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1134,7 +1134,7 @@ for k, Tensor_ in pairs(handledTypenames) do cname("dot"), {{name=Tensor}, {name=Tensor}, - {name=real, creturned=true}}) + {name=accreal, creturned=true}}) method:register("m_cutorch_" .. Tensor .. "Math__") interface:print(method:tostring()) diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu index ed3d2e3..26c9c8d 100644 --- a/lib/THC/THCBlas.cu +++ b/lib/THC/THCBlas.cu @@ -49,7 +49,7 @@ double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, } #ifdef CUDA_HALF_TENSOR -half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy) +float THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy) { #if CUDA_VERSION >= 8000 if (n == 1) { @@ -64,16 +64,16 @@ half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long half result; cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); - THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_16F, CUDA_R_32F)); + THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_32F, CUDA_R_32F)); return result; } THError("Cublas_Hdot only supports n, incx and incy " "up to signed integer limits: %d", INT_MAX); - return THC_float2half(0); + return 0; #else THError("Cublas_Hdot requires CUDA 8.0+"); - return THC_float2half(0); + return 0; #endif } #endif diff --git a/lib/THC/THCBlas.h b/lib/THC/THCBlas.h index 435d4f5..bf91f93 100644 --- a/lib/THC/THCBlas.h +++ b/lib/THC/THCBlas.h @@ -8,7 +8,7 @@ THC_API float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy); THC_API double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy); #ifdef CUDA_HALF_TENSOR -THC_API half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy); +THC_API float THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy); #endif /* Level 2 */ diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index 7c4ba1d..d4bd3c2 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -2,7 +2,7 @@ #define THC_GENERIC_FILE "generic/THCTensorMathBlas.cu" #else -THC_API real +THC_API accreal THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src) { #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) @@ -14,17 +14,17 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src) src = THCTensor_(newContiguous)(state, src); #ifdef THC_REAL_IS_FLOAT - real result = THCudaBlas_Sdot(state, + accreal result = THCudaBlas_Sdot(state, THCTensor_(nElement)(state, self), THCTensor_(data)(state, self), 1, THCTensor_(data)(state, src), 1); #elif defined(THC_REAL_IS_DOUBLE) - real result = THCudaBlas_Ddot(state, + accreal result = THCudaBlas_Ddot(state, THCTensor_(nElement)(state, self), THCTensor_(data)(state, self), 1, THCTensor_(data)(state, src), 1); #elif defined(THC_REAL_IS_HALF) - real result = THCudaBlas_Hdot(state, + accreal result = THCudaBlas_Hdot(state, THCTensor_(nElement)(state, self), THCTensor_(data)(state, self), 1, THCTensor_(data)(state, src), 1); @@ -36,7 +36,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src) #else THError("unimplemented data type"); - return ScalarConvert<int, real>::to(0); + return ScalarConvert<int, accreal>::to(0); #endif } diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h index 68f95e3..f37910c 100644 --- a/lib/THC/generic/THCTensorMathBlas.h +++ b/lib/THC/generic/THCTensorMathBlas.h @@ -2,7 +2,7 @@ #define THC_GENERIC_FILE "generic/THCTensorMathBlas.h" #else -THC_API real THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src); +THC_API accreal THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(addmv)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec); THC_API void THCTensor_(addmm)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat1, THCTensor *mat2); THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2); |