Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author soumith <soumith@fb.com> 2016-11-03 02:40:54 +0300
committer soumith <soumith@fb.com> 2016-11-03 02:40:54 +0300
commit a06a5e1e1bb35fede112a096030ab96b9526e7ed (patch)
tree b9cd86a9addbe353c1fe4e33216e12c1a4530e8a
parent 45a229db8bcb823217f0c64547418bc58507e91d (diff)
making dot have an accreal return type (consistent with CPU) [dotfix]
-rw-r--r-- TensorMath.lua 2
-rw-r--r-- lib/THC/THCBlas.cu 8
-rw-r--r-- lib/THC/THCBlas.h 2
-rw-r--r-- lib/THC/generic/THCTensorMathBlas.cu 10
-rw-r--r-- lib/THC/generic/THCTensorMathBlas.h 2
5 files changed, 12 insertions, 12 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 802565e..a5d436f 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1134,7 +1134,7 @@ for k, Tensor_ in pairs(handledTypenames) do
cname("dot"),
{{name=Tensor},
{name=Tensor},
- {name=real, creturned=true}})
+ {name=accreal, creturned=true}})
method:register("m_cutorch_" .. Tensor .. "Math__")
interface:print(method:tostring())
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu
index ed3d2e3..26c9c8d 100644
--- a/lib/THC/THCBlas.cu
+++ b/lib/THC/THCBlas.cu
@@ -49,7 +49,7 @@ double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y,
}
#ifdef CUDA_HALF_TENSOR
-half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy)
+float THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy)
{
#if CUDA_VERSION >= 8000
if (n == 1) {
@@ -64,16 +64,16 @@ half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long
half result;
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
- THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_16F, CUDA_R_32F));
+ THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_32F, CUDA_R_32F));
return result;
}
THError("Cublas_Hdot only supports n, incx and incy "
"up to signed integer limits: %d", INT_MAX);
- return THC_float2half(0);
+ return 0;
#else
THError("Cublas_Hdot requires CUDA 8.0+");
- return THC_float2half(0);
+ return 0;
#endif
}
#endif
diff --git a/lib/THC/THCBlas.h b/lib/THC/THCBlas.h
index 435d4f5..bf91f93 100644
--- a/lib/THC/THCBlas.h
+++ b/lib/THC/THCBlas.h
@@ -8,7 +8,7 @@
THC_API float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy);
THC_API double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy);
#ifdef CUDA_HALF_TENSOR
-THC_API half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy);
+THC_API float THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy);
#endif
/* Level 2 */
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index 7c4ba1d..d4bd3c2 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -2,7 +2,7 @@
#define THC_GENERIC_FILE "generic/THCTensorMathBlas.cu"
#else
-THC_API real
+THC_API accreal
THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
@@ -14,17 +14,17 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
src = THCTensor_(newContiguous)(state, src);
#ifdef THC_REAL_IS_FLOAT
- real result = THCudaBlas_Sdot(state,
+ accreal result = THCudaBlas_Sdot(state,
THCTensor_(nElement)(state, self),
THCTensor_(data)(state, self), 1,
THCTensor_(data)(state, src), 1);
#elif defined(THC_REAL_IS_DOUBLE)
- real result = THCudaBlas_Ddot(state,
+ accreal result = THCudaBlas_Ddot(state,
THCTensor_(nElement)(state, self),
THCTensor_(data)(state, self), 1,
THCTensor_(data)(state, src), 1);
#elif defined(THC_REAL_IS_HALF)
- real result = THCudaBlas_Hdot(state,
+ accreal result = THCudaBlas_Hdot(state,
THCTensor_(nElement)(state, self),
THCTensor_(data)(state, self), 1,
THCTensor_(data)(state, src), 1);
@@ -36,7 +36,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
#else
THError("unimplemented data type");
- return ScalarConvert<int, real>::to(0);
+ return ScalarConvert<int, accreal>::to(0);
#endif
}
diff --git a/lib/THC/generic/THCTensorMathBlas.h b/lib/THC/generic/THCTensorMathBlas.h
index 68f95e3..f37910c 100644
--- a/lib/THC/generic/THCTensorMathBlas.h
+++ b/lib/THC/generic/THCTensorMathBlas.h
@@ -2,7 +2,7 @@
#define THC_GENERIC_FILE "generic/THCTensorMathBlas.h"
#else
-THC_API real THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src);
+THC_API accreal THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(addmv)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec);
THC_API void THCTensor_(addmm)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *mat1, THCTensor *mat2);
THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2);