Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Jeaugey <sjeaugey@nvidia.com>2017-06-07 19:57:12 +0300
committerSylvain Jeaugey <sjeaugey@nvidia.com>2017-06-14 21:20:24 +0300
commit29a1a916dc14bb2c00feed3d4820d51fa85be1e6 (patch)
treef03dd425bfca48ce2da7251b018532fce302c1b5
parentccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b (diff)
Add support for CUDA9 half semantics
-rw-r--r--Makefile2
-rw-r--r--src/common_kernel.h44
-rw-r--r--src/copy_kernel.h4
3 files changed, 27 insertions, 23 deletions
diff --git a/Makefile b/Makefile
index 8f34fcb..c37b7f7 100644
--- a/Makefile
+++ b/Makefile
@@ -54,7 +54,7 @@ endif
NCCL_MAJOR := 1
NCCL_MINOR := 3
-NCCL_PATCH := 4
+NCCL_PATCH := 5
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
diff --git a/src/common_kernel.h b/src/common_kernel.h
index cc71f8a..b96519f 100644
--- a/src/common_kernel.h
+++ b/src/common_kernel.h
@@ -35,25 +35,33 @@ T vFetch(const volatile T* ptr) {
return *ptr;
}
+template<typename T> inline __device__
+void vStore(volatile T* ptr, const T val) {
+ *ptr = val;
+}
+
#ifdef CUDA_HAS_HALF
+#if CUDART_VERSION < 9000
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
half r;
r.x = ptr->x;
return r;
}
-#endif
-
-template<typename T> inline __device__
-void vStore(volatile T* ptr, const T val) {
- *ptr = val;
-}
-
-#ifdef CUDA_HAS_HALF
template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
ptr->x = val.x;
}
+#else
+template<> inline __device__
+half vFetch<half>(const volatile half* ptr) {
+ return *((half*)ptr);
+}
+template<> inline __device__
+void vStore<half>(volatile half* ptr, const half val) {
+ *((half*)ptr) = val;
+}
+#endif
#endif
__device__ unsigned int spinct;
@@ -125,24 +133,22 @@ struct MULTI<FUNC, int> {
#ifdef CUDA_HAS_HALF
template<class FUNC>
struct MULTI<FUNC, half> {
- static_assert(sizeof(PackType) == 2 * sizeof(float),
- "PackType must be twice the size of float.");
- union converter {
- PackType storage;
- struct {
- half2 a, b;
- };
+ static_assert(sizeof(PackType) == 4 * sizeof(half),
+ "PackType must be four times the size of half.");
+
+ struct PackHalf2 {
+ half2 a, b;
};
__device__ PackType operator()(const PackType x, const PackType y) const {
- converter cx, cy, cr;
- cx.storage = x;
- cy.storage = y;
+ struct PackHalf2 cx, cy, cr;
+ cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
+ cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
cr.a = FUNC()(cx.a, cy.a);
cr.b = FUNC()(cx.b, cy.b);
- return cr.storage;
+ return *(reinterpret_cast<PackType*>(&cr));
}
};
#endif
diff --git a/src/copy_kernel.h b/src/copy_kernel.h
index 8464699..0f69748 100644
--- a/src/copy_kernel.h
+++ b/src/copy_kernel.h
@@ -24,9 +24,7 @@ struct FuncPassA<half> {
return x;
}
__device__ half operator()(const half x, const half y) const {
- half r;
- r.x = x.x;
- return r;
+ return x;
}
};
#endif