diff options
| | | |
|---|---|---|
| author | Christian Sarofeen <csarofeen@nvidia.com> | 2017-07-20 00:10:29 +0300 |
| committer | Soumith Chintala <soumith@gmail.com> | 2017-08-25 14:27:16 +0300 |
| commit | a9f950a6da6567dc3f331a647147925af4196645 (patch) | |
| tree | def50e661ea7a78e106a1e7e3139639ebc281856 | |
| parent | 2de2d2bf67e4a4e5d139c729c26c7d609dbd6349 (diff) | |
Updates for CUDA 9
-rw-r--r-- | lib/THC/CMakeLists.txt | 5 | ||||
-rw-r--r-- | lib/THC/THCBlas.cu | 18 | ||||
-rw-r--r-- | lib/THC/THCDeviceUtils.cuh | 1 | ||||
-rw-r--r-- | lib/THC/THCHalf.h | 2 | ||||
-rw-r--r-- | lib/THC/THCScanUtils.cuh | 1 |
5 files changed, 22 insertions, 5 deletions
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index 3643904..0fec282 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -41,6 +41,11 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3") endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + +if(CUDA_VERSION VERSION_GREATER "8.0") + LIST(APPEND CUDA_NVCC_FLAGS "-D__CUDA_NO_HALF_OPERATORS__") +endif(CUDA_VERSION VERSION_GREATER "8.0") + IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" ) SET(CXX_VERSION "c++11") diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu index 79e3b1e..3defa03 100644 --- a/lib/THC/THCBlas.cu +++ b/lib/THC/THCBlas.cu @@ -275,20 +275,32 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, cublasSetStream(handle, THCState_getCurrentStream(state)); // Check for native Hgemm support - if (THC_fastHalfInstructions(state)) { +/* if (THC_fastHalfInstructions(state)) { THCublasCheck(cublasHgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc)); - } else { + } else {*/ // Simulated Hgemm float fAlpha = THC_half2float(alpha); float fBeta = THC_half2float(beta); +#if CUDA_VERSION < 9000 THCublasCheck(cublasSgemmEx(handle, opa, opb, i_m, i_n, i_k, &fAlpha, a, CUDA_R_16F, i_lda, b, CUDA_R_16F, i_ldb, &fBeta, c, CUDA_R_16F, i_ldc)); - } +#else + THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + THCublasCheck(cublasGemmEx(handle, opa, opb, + i_m, i_n, i_k, &fAlpha, + a, CUDA_R_16F, i_lda, b, CUDA_R_16F, + i_ldb, &fBeta, c, CUDA_R_16F, i_ldc, + CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)); + THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + + +#endif +// } return; } diff --git a/lib/THC/THCDeviceUtils.cuh b/lib/THC/THCDeviceUtils.cuh index 8052860..4ae2bee 100644 --- a/lib/THC/THCDeviceUtils.cuh +++ b/lib/THC/THCDeviceUtils.cuh @@ -43,7 +43,6 @@ __device__ 
__forceinline__ unsigned int ACTIVE_MASK() #endif } - __device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 diff --git a/lib/THC/THCHalf.h b/lib/THC/THCHalf.h index d5bd5c1..bb21b9d 100644 --- a/lib/THC/THCHalf.h +++ b/lib/THC/THCHalf.h @@ -15,7 +15,7 @@ #if CUDA_VERSION >= 9000 #ifndef __cplusplus - typedef __half_raw half; +typedef __half_raw half; #endif #endif diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh index ce9619d..9a487ca 100644 --- a/lib/THC/THCScanUtils.cuh +++ b/lib/THC/THCScanUtils.cuh @@ -2,6 +2,7 @@ #define THC_SCAN_UTILS_INC #include "THCAsmUtils.cuh" +#include "THCDeviceUtils.cuh" // Collection of in-kernel scan / prefix sum utilities |