diff options
-rw-r--r-- | lib/THC/THCBlas.cu | 58 |
1 files changed, 28 insertions, 30 deletions
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu index 3defa03..6724761 100644 --- a/lib/THC/THCBlas.cu +++ b/lib/THC/THCBlas.cu @@ -263,47 +263,45 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, cublasOperation_t opb = convertTransToCublasOperation(transb); if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_lda = (int)lda; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; + { + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; - cublasHandle_t handle = THCState_getCurrentBlasHandle(state); - cublasSetStream(handle, THCState_getCurrentStream(state)); + cublasHandle_t handle = THCState_getCurrentBlasHandle(state); + cublasSetStream(handle, THCState_getCurrentStream(state)); - // Check for native Hgemm support -/* if (THC_fastHalfInstructions(state)) { - THCublasCheck(cublasHgemm(handle, opa, opb, - i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, - &beta, c, i_ldc)); - } else {*/ // Simulated Hgemm float fAlpha = THC_half2float(alpha); float fBeta = THC_half2float(beta); #if CUDA_VERSION < 9000 THCublasCheck(cublasSgemmEx(handle, opa, opb, - i_m, i_n, i_k, &fAlpha, + i_m, i_n, i_k, &fAlpha, a, CUDA_R_16F, i_lda, b, CUDA_R_16F, - i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); #else - THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - THCublasCheck(cublasGemmEx(handle, opa, opb, - i_m, i_n, i_k, &fAlpha, - a, CUDA_R_16F, i_lda, b, CUDA_R_16F, - i_ldb, &fBeta, c, CUDA_R_16F, i_ldc, - CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)); - THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - + cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state); + if (prop->major >= 5){ + THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + THCublasCheck(cublasGemmEx(handle, opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc, + CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)); + THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + }else{ + THCublasCheck(cublasSgemmEx(handle, opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); + } #endif -// } - - return; - } + return; + } THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc" "with th bound [val] <= %d", INT_MAX); } |