Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Christian Sarofeen <csarofeen@nvidia.com> 2017-07-20 00:10:29 +0300
committer Soumith Chintala <soumith@gmail.com> 2017-08-25 14:27:16 +0300
commit a9f950a6da6567dc3f331a647147925af4196645 (patch)
tree def50e661ea7a78e106a1e7e3139639ebc281856
parent 2de2d2bf67e4a4e5d139c729c26c7d609dbd6349 (diff)
Updates for CUDA 9
-rw-r--r-- lib/THC/CMakeLists.txt 5
-rw-r--r-- lib/THC/THCBlas.cu 18
-rw-r--r-- lib/THC/THCDeviceUtils.cuh 1
-rw-r--r-- lib/THC/THCHalf.h 2
-rw-r--r-- lib/THC/THCScanUtils.cuh 1
5 files changed, 22 insertions, 5 deletions
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 3643904..0fec282 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -41,6 +41,11 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+
+if(CUDA_VERSION VERSION_GREATER "8.0")
+ LIST(APPEND CUDA_NVCC_FLAGS "-D__CUDA_NO_HALF_OPERATORS__")
+endif(CUDA_VERSION VERSION_GREATER "8.0")
+
IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
SET(CXX_VERSION "c++11")
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu
index 79e3b1e..3defa03 100644
--- a/lib/THC/THCBlas.cu
+++ b/lib/THC/THCBlas.cu
@@ -275,20 +275,32 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n,
cublasSetStream(handle, THCState_getCurrentStream(state));
// Check for native Hgemm support
- if (THC_fastHalfInstructions(state)) {
+/* if (THC_fastHalfInstructions(state)) {
THCublasCheck(cublasHgemm(handle, opa, opb,
i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
&beta, c, i_ldc));
- } else {
+ } else {*/
// Simulated Hgemm
float fAlpha = THC_half2float(alpha);
float fBeta = THC_half2float(beta);
+#if CUDA_VERSION < 9000
THCublasCheck(cublasSgemmEx(handle, opa, opb,
i_m, i_n, i_k, &fAlpha,
a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
- }
+#else
+ THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+ THCublasCheck(cublasGemmEx(handle, opa, opb,
+ i_m, i_n, i_k, &fAlpha,
+ a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+ i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
+ CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
+ THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+
+
+#endif
+// }
return;
}
diff --git a/lib/THC/THCDeviceUtils.cuh b/lib/THC/THCDeviceUtils.cuh
index 8052860..4ae2bee 100644
--- a/lib/THC/THCDeviceUtils.cuh
+++ b/lib/THC/THCDeviceUtils.cuh
@@ -43,7 +43,6 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
#endif
}
-
__device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
diff --git a/lib/THC/THCHalf.h b/lib/THC/THCHalf.h
index d5bd5c1..bb21b9d 100644
--- a/lib/THC/THCHalf.h
+++ b/lib/THC/THCHalf.h
@@ -15,7 +15,7 @@
#if CUDA_VERSION >= 9000
#ifndef __cplusplus
- typedef __half_raw half;
+typedef __half_raw half;
#endif
#endif
diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh
index ce9619d..9a487ca 100644
--- a/lib/THC/THCScanUtils.cuh
+++ b/lib/THC/THCScanUtils.cuh
@@ -2,6 +2,7 @@
#define THC_SCAN_UTILS_INC
#include "THCAsmUtils.cuh"
+#include "THCDeviceUtils.cuh"
// Collection of in-kernel scan / prefix sum utilities