Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Soumith Chintala <soumith@gmail.com> 2017-08-25 14:31:32 +0300
committer: Soumith Chintala <soumith@gmail.com> 2017-08-25 14:31:32 +0300
commit: d0bb7e12cbfbae560b02b4226d7eb861bd7f48af (patch)
tree: 17c0f97c8a1e5b3506413c9ec7fd9c2bda09f71c
parent: a9f950a6da6567dc3f331a647147925af4196645 (diff)
cuda 9 hgemm fix
-rw-r--r-- lib/THC/THCBlas.cu 58
1 files changed, 28 insertions, 30 deletions
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu
index 3defa03..6724761 100644
--- a/lib/THC/THCBlas.cu
+++ b/lib/THC/THCBlas.cu
@@ -263,47 +263,45 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n,
cublasOperation_t opb = convertTransToCublasOperation(transb);
if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
- {
- int i_m = (int)m;
- int i_n = (int)n;
- int i_k = (int)k;
- int i_lda = (int)lda;
- int i_ldb = (int)ldb;
- int i_ldc = (int)ldc;
+ {
+ int i_m = (int)m;
+ int i_n = (int)n;
+ int i_k = (int)k;
+ int i_lda = (int)lda;
+ int i_ldb = (int)ldb;
+ int i_ldc = (int)ldc;
- cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
- cublasSetStream(handle, THCState_getCurrentStream(state));
+ cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
+ cublasSetStream(handle, THCState_getCurrentStream(state));
- // Check for native Hgemm support
-/* if (THC_fastHalfInstructions(state)) {
- THCublasCheck(cublasHgemm(handle, opa, opb,
- i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
- &beta, c, i_ldc));
- } else {*/
// Simulated Hgemm
float fAlpha = THC_half2float(alpha);
float fBeta = THC_half2float(beta);
#if CUDA_VERSION < 9000
THCublasCheck(cublasSgemmEx(handle, opa, opb,
- i_m, i_n, i_k, &fAlpha,
+ i_m, i_n, i_k, &fAlpha,
a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
- i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+ i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
#else
- THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
- THCublasCheck(cublasGemmEx(handle, opa, opb,
- i_m, i_n, i_k, &fAlpha,
- a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
- i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
- CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
- THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-
-
+ cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
+ if (prop->major >= 5){
+ THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+ THCublasCheck(cublasGemmEx(handle, opa, opb,
+ i_m, i_n, i_k, &fAlpha,
+ a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+ i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
+ CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
+ THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+ }else{
+ THCublasCheck(cublasSgemmEx(handle, opa, opb,
+ i_m, i_n, i_k, &fAlpha,
+ a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+ i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+ }
#endif
-// }
-
- return;
- }
+ return;
+ }
THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc"
"with th bound [val] <= %d", INT_MAX);
}