diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-08-25 14:31:32 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-25 14:31:32 +0300 |
commit | d0bb7e12cbfbae560b02b4226d7eb861bd7f48af (patch) | |
tree | 17c0f97c8a1e5b3506413c9ec7fd9c2bda09f71c | |
parent | a9f950a6da6567dc3f331a647147925af4196645 (diff) |
cuda 9 hgemm fix
-rw-r--r-- | lib/THC/THCBlas.cu | 58 |
1 files changed, 28 insertions, 30 deletions
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu index 3defa03..6724761 100644 --- a/lib/THC/THCBlas.cu +++ b/lib/THC/THCBlas.cu @@ -263,47 +263,45 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, cublasOperation_t opb = convertTransToCublasOperation(transb); if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_lda = (int)lda; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; + { + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; - cublasHandle_t handle = THCState_getCurrentBlasHandle(state); - cublasSetStream(handle, THCState_getCurrentStream(state)); + cublasHandle_t handle = THCState_getCurrentBlasHandle(state); + cublasSetStream(handle, THCState_getCurrentStream(state)); - // Check for native Hgemm support -/* if (THC_fastHalfInstructions(state)) { - THCublasCheck(cublasHgemm(handle, opa, opb, - i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, - &beta, c, i_ldc)); - } else {*/ // Simulated Hgemm float fAlpha = THC_half2float(alpha); float fBeta = THC_half2float(beta); #if CUDA_VERSION < 9000 THCublasCheck(cublasSgemmEx(handle, opa, opb, - i_m, i_n, i_k, &fAlpha, + i_m, i_n, i_k, &fAlpha, a, CUDA_R_16F, i_lda, b, CUDA_R_16F, - i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); #else - THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - THCublasCheck(cublasGemmEx(handle, opa, opb, - i_m, i_n, i_k, &fAlpha, - a, CUDA_R_16F, i_lda, b, CUDA_R_16F, - i_ldb, &fBeta, c, CUDA_R_16F, i_ldc, - CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)); - THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - + cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state); + if (prop->major >= 5){ + THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + THCublasCheck(cublasGemmEx(handle, opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc, + CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)); + THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + }else{ + THCublasCheck(cublasSgemmEx(handle, opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); + } #endif -// } - - return; - } + return; + } THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc" "with th bound [val] <= %d", INT_MAX); } |