Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Christian Sarofeen <csarofeen@nvidia.com> 2017-07-20 00:10:29 +0300
committer Soumith Chintala <soumith@gmail.com> 2017-08-25 14:27:25 +0300
commit a6522cf297e1496cd2bd2271992ec3bcc525bd72 (patch)
tree eae555074c84932d9142da487240773809c797bb
parent 90afcbf93b629b74cbb3bd76b12a0a8c389195e9 (diff)
Updates for CUDA 9
-rw-r--r-- lib/THCUNN/BatchNormalization.cu 4
-rw-r--r-- lib/THCUNN/CMakeLists.txt 4
2 files changed, 6 insertions, 2 deletions
diff --git a/lib/THCUNN/BatchNormalization.cu b/lib/THCUNN/BatchNormalization.cu
index 125e3ff..e6717c7 100644
--- a/lib/THCUNN/BatchNormalization.cu
+++ b/lib/THCUNN/BatchNormalization.cu
@@ -5,7 +5,7 @@
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
-
+#include "THCDeviceUtils.cuh"
const int WARP_SIZE = 32;
// The maximum number of threads in a block
@@ -80,7 +80,7 @@ template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
- val += __shfl_xor(val, 1 << i, WARP_SIZE);
+ val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
}
#else
__shared__ T values[MAX_BLOCK_SIZE];
diff --git a/lib/THCUNN/CMakeLists.txt b/lib/THCUNN/CMakeLists.txt
index 6047d97..46eda7e 100644
--- a/lib/THCUNN/CMakeLists.txt
+++ b/lib/THCUNN/CMakeLists.txt
@@ -51,6 +51,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+if(CUDA_VERSION VERSION_GREATER "8.0")
+ LIST(APPEND CUDA_NVCC_FLAGS "-D__CUDA_NO_HALF_OPERATORS__")
+endif(CUDA_VERSION VERSION_GREATER "8.0")
+
IF(MSVC)
LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler /wd4819")
ADD_DEFINITIONS(-DTH_EXPORTS)