diff options
author | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2017-06-07 19:57:12 +0300 |
---|---|---|
committer | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2017-06-14 21:20:24 +0300 |
commit | 29a1a916dc14bb2c00feed3d4820d51fa85be1e6 (patch) | |
tree | f03dd425bfca48ce2da7251b018532fce302c1b5 | |
parent | ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b (diff) |
Add support for CUDA9 half semantics
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | src/common_kernel.h | 44 | ||||
-rw-r--r-- | src/copy_kernel.h | 4 |
3 files changed, 27 insertions, 23 deletions
@@ -54,7 +54,7 @@ endif
 
 NCCL_MAJOR := 1
 NCCL_MINOR := 3
-NCCL_PATCH := 4
+NCCL_PATCH := 5
 CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
 
 CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
diff --git a/src/common_kernel.h b/src/common_kernel.h
index cc71f8a..b96519f 100644
--- a/src/common_kernel.h
+++ b/src/common_kernel.h
@@ -35,25 +35,33 @@ T vFetch(const volatile T* ptr) {
   return *ptr;
 }
 
+template<typename T> inline __device__
+void vStore(volatile T* ptr, const T val) {
+  *ptr = val;
+}
+
 #ifdef CUDA_HAS_HALF
+#if CUDART_VERSION < 9000
 template<> inline __device__
 half vFetch<half>(const volatile half* ptr) {
   half r;
   r.x = ptr->x;
   return r;
 }
-#endif
-
-template<typename T> inline __device__
-void vStore(volatile T* ptr, const T val) {
-  *ptr = val;
-}
-
-#ifdef CUDA_HAS_HALF
 template<> inline __device__
 void vStore<half>(volatile half* ptr, const half val) {
   ptr->x = val.x;
 }
+#else
+template<> inline __device__
+half vFetch<half>(const volatile half* ptr) {
+  return *((half*)ptr);
+}
+template<> inline __device__
+void vStore<half>(volatile half* ptr, const half val) {
+  *((half*)ptr) = val;
+}
+#endif
 #endif
 
 __device__ unsigned int spinct;
@@ -125,24 +133,22 @@ struct MULTI<FUNC, int> {
 #ifdef CUDA_HAS_HALF
 template<class FUNC>
 struct MULTI<FUNC, half> {
-  static_assert(sizeof(PackType) == 2 * sizeof(float),
-      "PackType must be twice the size of float.");
-  union converter {
-    PackType storage;
-    struct {
-      half2 a, b;
-    };
+  static_assert(sizeof(PackType) == 4 * sizeof(half),
+      "PackType must be four times the size of half.");
+
+  struct PackHalf2 {
+    half2 a, b;
   };
 
   __device__ PackType operator()(const PackType x, const PackType y) const {
-    converter cx, cy, cr;
-    cx.storage = x;
-    cy.storage = y;
+    struct PackHalf2 cx, cy, cr;
+    cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
+    cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
 
     cr.a = FUNC()(cx.a, cy.a);
     cr.b = FUNC()(cx.b, cy.b);
 
-    return cr.storage;
+    return *(reinterpret_cast<PackType*>(&cr));
   }
 };
 #endif
diff --git a/src/copy_kernel.h b/src/copy_kernel.h
index 8464699..0f69748 100644
--- a/src/copy_kernel.h
+++ b/src/copy_kernel.h
@@ -24,9 +24,7 @@ struct FuncPassA<half> {
     return x;
   }
   __device__ half operator()(const half x, const half y) const {
-    half r;
-    r.x = x.x;
-    return r;
+    return x;
   }
 };
 #endif