diff options
author | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2016-12-02 02:17:50 +0300 |
---|---|---|
committer | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2016-12-02 02:17:50 +0300 |
commit | 34d27771c6dc988889d8ac857b62932a79bf1210 (patch) | |
tree | 5f7442026aa34f56a5fce5bbc6bac52c6b2c9844 | |
parent | 1093821c335437b399035f3ebf3b67a3e960de8f (diff) |
1.3.2 release
Broadcast tuning
Better checking of inputs
Copy/reduce code simplification
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | src/all_gather.cu | 9 | ||||
-rw-r--r-- | src/all_reduce.cu | 7 | ||||
-rw-r--r-- | src/broadcast.cu | 7 | ||||
-rw-r--r-- | src/common_kernel.h | 290 | ||||
-rw-r--r-- | src/core.cu | 37 | ||||
-rw-r--r-- | src/core.h | 29 | ||||
-rw-r--r-- | src/enqueue.h | 4 | ||||
-rw-r--r-- | src/reduce.cu | 7 | ||||
-rw-r--r-- | src/reduce_scatter.cu | 7 |
10 files changed, 120 insertions, 279 deletions
@@ -52,7 +52,7 @@ endif NCCL_MAJOR := 1 NCCL_MINOR := 3 -NCCL_PATCH := 1 +NCCL_PATCH := 2 CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH) CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev) diff --git a/src/all_gather.cu b/src/all_gather.cu index 2dd6246..cb36b71 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "core.h" +#include "common_coll.h" #include "enqueue.h" #include "primitives.h" @@ -164,18 +165,15 @@ __global__ void AllGatherKernel(const KernelArgs<T> args) { } } -#define THREADS 384 +#define THREADS 512 #define UNROLL 8 template<class FUNC, typename T> ncclResult_t RingAllGather(const void* sendbuff, void* recvbuff, const int count, ncclComm* comm, cudaStream_t stream) { - if (count == 0) - return ncclSuccess; - if (comm->nRanks == 1) { if (sendbuff != recvbuff) - CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError); } else { KernelArgs<T> args; ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm); @@ -198,6 +196,7 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, int count, ncclDataT void* recvbuff, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclAllGather(const void* sendbuff, int count, ncclDataType_t datatype, void* recvbuff, ncclComm_t comm, cudaStream_t stream) { + NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, "AllGather")); return enqueue<AllGather, FuncNull>(sendbuff, recvbuff, count, datatype, 0, comm, stream); } diff --git a/src/all_reduce.cu b/src/all_reduce.cu index a81ee62..2f38d6e 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "core.h" +#include "common_coll.h" #include "enqueue.h" #include "primitives.h" @@ -202,12 +203,9 @@ __global__ void AllReduceKernel(const KernelArgs<T> args) { template<class FUNC, typename T> ncclResult_t RingAllReduce(const void* sendbuff, void* recvbuff, const int count, ncclComm* comm, cudaStream_t stream) { - if (count == 0) - return ncclSuccess; - if (comm->nRanks == 1) { if (sendbuff != recvbuff) - CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError); } else { KernelArgs<T> args; ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm); @@ -230,6 +228,7 @@ NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, int ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) { + NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, 0, comm, "AllReduce")); return enqueue<AllReduce>(sendbuff, recvbuff, count, datatype, op, 0, comm, stream); } diff --git a/src/broadcast.cu b/src/broadcast.cu index 5843922..3a7cb11 100644 --- a/src/broadcast.cu +++ b/src/broadcast.cu @@ -5,10 +5,11 @@ ************************************************************************/ #include "core.h" +#include "common_coll.h" #include "enqueue.h" #include "primitives.h" -#define NUM_SUBSTEPS 2 +#define NUM_SUBSTEPS 4 #define NUM_BUFCHUNKS 2 // Increase Step and boffset for buffer sync @@ -135,9 +136,6 @@ __global__ void BroadcastKernel(const KernelArgs<T> args) { template<class FUNC, typename T> ncclResult_t RingBroadcast(void* buff, const int count, const int root, ncclComm* comm, cudaStream_t stream) { - if (count == 0) - return ncclSuccess; - if (comm->nRanks != 1) { KernelArgs<T> args; ArgsSetup(&args, buff, buff, root, count, comm); @@ -160,6 +158,7 @@ NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) { + NCCLCHECK(ArgsCheck(buff, buff, count, datatype, ncclSum, root, comm, "Bcast")); return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream); } diff --git a/src/common_kernel.h b/src/common_kernel.h index 7a17e36..28fbc85 100644 --- a/src/common_kernel.h +++ b/src/common_kernel.h @@ -174,32 +174,46 @@ struct MULTI<FUNC, long long> { } }; -template<typename T, bool FETCHTWO> -__device__ inline void FetchOneOrTwo64b(PackType& s0, - const volatile T * __restrict__ const src0, PackType& s1, - const volatile T * __restrict__ const src1, const int idx) { - s0 = (reinterpret_cast<const volatile PackType *>(src0))[idx]; - if (FETCHTWO) { - s1 = (reinterpret_cast<const volatile PackType *>(src1))[idx]; +template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS> +__device__ inline void ReduceCopy( + const volatile T * __restrict__ const src0, + const volatile T * __restrict__ const src1, + volatile T * __restrict__ const dest0, + volatile T * __restrict__ const dest1, const int idx) { + T val = vFetch(src0+idx); + if (TWO_INPUTS) { + val = FUNC()(val, vFetch(src1+idx)); } -} - -template<typename T, bool STORETWO> -__device__ inline void StoreOneOrTwo64b(volatile T * __restrict__ const dest0, - volatile T * __restrict__ const dest1, PackType val, const int idx) { - (reinterpret_cast<volatile PackType *>(dest0))[idx] = val; - if (STORETWO) { - (reinterpret_cast<volatile PackType *>(dest1))[idx] = val; + vStore(dest0+idx, val); + if (TWO_OUTPUTS) { + vStore(dest1+idx, val); } } -template<class FUNC, typename T, bool ISREDUCE> -__device__ inline PackType ReduceOrCopy64b(const PackType s0, - const PackType s1) { - if (ISREDUCE) { - return MULTI<FUNC, T>()(s0, s1); - } else { - return s0; +template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS, int UNROLL, int THREADS> +__device__ inline void ReduceCopy64b( + const volatile T * __restrict__ const src0, + const volatile T * __restrict__ const src1, + volatile T * __restrict__ const dest0, + volatile T * __restrict__ const dest1, const int offset) { + PackType t0[UNROLL]; + PackType t1[UNROLL]; + #pragma unroll + for (int u = 0; u < UNROLL; ++u) { + int idx = offset + u*THREADS; + t0[u] = (reinterpret_cast<const volatile PackType *>(src0))[idx]; + if (TWO_INPUTS) { + t1[u] = (reinterpret_cast<const volatile PackType *>(src1))[idx]; + } + } + #pragma unroll + for (int u = 0; u < UNROLL; ++u) { + int idx = offset + u*THREADS; + PackType val = TWO_INPUTS ? MULTI<FUNC, T>()(t0[u], t1[u]) : t0[u]; + (reinterpret_cast<volatile PackType *>(dest0))[idx] = val; + if (TWO_OUTPUTS) { + (reinterpret_cast<volatile PackType *>(dest1))[idx] = val; + } } } @@ -251,9 +265,6 @@ __device__ inline void ReduceOrCopy(const int tid, return; } - const int UNROLL2 = (UNROLL >= 2) ? (UNROLL / 2) : 1; - const bool NOUNROLL2 = ((UNROLL / 2) == 0); - int Npreamble = (N<alignof(PackType)) ? N : AlignUp(dest0, alignof(PackType)) - dest0; // stage 0: check if we'll be able to use the fast, 64-bit aligned path. @@ -266,247 +277,60 @@ __device__ inline void ReduceOrCopy(const int tid, Npreamble = N; } -/* - if (threadIdx.x == 0) { - printf("** alignable: %s", (alignable ? "YES" : " NO")); - printf(", dest0 = 0x%08X", dest0); - printf(", src0 = 0x%08X", src0); - if (HAS_DEST1) printf(", dest1 = 0x%08X", dest1); - if (HAS_SRC1) printf(", src1 = 0x%08X", src1); - printf("\n"); - } -*/ - // stage 1: preamble: handle any elements up to the point of everything coming // into alignment for (int idx = tid; idx < Npreamble; idx += THREADS) { // ought to be no way this is ever more than one iteration, except when // alignable is false - T val = vFetch(src0+idx); - if (HAS_SRC1) { - val = FUNC()(val, vFetch(src1+idx)); - } - - vStore(dest0+idx, val); - if (HAS_DEST1) { - vStore(dest1+idx, val); - } + ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx); } - // reduce N by however many elements we've handled already - int Ndone = Npreamble; - int Nrem = N - Ndone; - // stage 2: fast path: use 64b loads/stores to do the bulk of the work, // assuming the pointers we have are all 64-bit alignable. if (alignable) { - if (Ndone > 0) { - // align up pointers - dest0 += Ndone; if (HAS_DEST1) { dest1 += Ndone; } - src0 += Ndone; if (HAS_SRC1) { src1 += Ndone; } - } + const int PackFactor = sizeof(PackType) / sizeof(T); + int Nrem = N - Npreamble; + dest0 += Npreamble; if (HAS_DEST1) { dest1 += Npreamble; } + src0 += Npreamble; if (HAS_SRC1) { src1 += Npreamble; } // stage 2a: main loop - int Nalign = (Nrem / (sizeof(PackType) / sizeof(T)) / (UNROLL * THREADS)) + int Nalign2a = (Nrem / (PackFactor * UNROLL * THREADS)) * (UNROLL * THREADS); // round down #pragma unroll 1 // don't unroll this loop - for (int idx = tid; idx < Nalign; idx += UNROLL * THREADS) { - PackType t0[UNROLL2]; - PackType t1[UNROLL2]; - PackType t2[UNROLL2]; - - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1, - idx + j * THREADS); - - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]); - - if (!NOUNROLL2) { - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1, - idx + (UNROLL2 + j) * THREADS); - } - - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j], idx + j * THREADS); - - if (!NOUNROLL2) { - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]); - - #pragma unroll - for (int j = 0; j < UNROLL2; ++j) - StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j], - idx + (UNROLL2 + j) * THREADS); - } + for (int idx = tid; idx < Nalign2a; idx += UNROLL * THREADS) { + ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, UNROLL, THREADS>(src0, src1, dest0, dest1, idx); } + int Ndone2a = Nalign2a * PackFactor; + Nrem -= Ndone2a; + // stage 2b: slightly less optimized for section when we don't have full // UNROLLs - int Ndone2a = Nalign * (sizeof(PackType)/sizeof(T)); - Ndone += Ndone2a; - Nrem = N - Ndone; - - // TODO: This kind of pointer update arithmetic is expensive. Should - // probably find a better way. - if (Nrem > 0) { - dest0 += Ndone2a; if (HAS_DEST1) { dest1 += Ndone2a; } - src0 += Ndone2a; if (HAS_SRC1) { src1 += Ndone2a; } - } - Nalign = Nrem / (sizeof(PackType)/sizeof(T)); + int Nalign2b = Nrem / PackFactor; #pragma unroll 4 - for (int idx = tid; idx < Nalign; idx += THREADS) { - PackType t0, t1, t2; - - FetchOneOrTwo64b<T, HAS_SRC1>(t0, src0, t1, src1, idx); - t2 = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0, t1); - StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2, idx); + for (int idx = Nalign2a + tid; idx < Nalign2a + Nalign2b; idx += THREADS) { + ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, 1, 0>(src0, src1, dest0, dest1, idx); } - // stage 2c: tail - int Ndone2b = Nalign * (sizeof(PackType)/sizeof(T)); - Ndone += Nalign * (sizeof(PackType)/sizeof(T)); - Nrem = N - Ndone; + int Ndone2b = Nalign2b * PackFactor; + Nrem -= Ndone2b; + int Ndone2 = Ndone2a + Ndone2b; + dest0 += Ndone2; if (HAS_DEST1) { dest1 += Ndone2; } + src0 += Ndone2; if (HAS_SRC1) { src1 += Ndone2; } - if (Nrem > 0) { - dest0 += Ndone2b; if (HAS_DEST1) { dest1 += Ndone2b; } - src0 += Ndone2b; if (HAS_SRC1) { src1 += Ndone2b; } - } + // stage 2c: tail for (int idx = tid; idx < Nrem; idx += THREADS) { // never ought to make it more than one time through this loop. only a // few threads should even participate - T val = vFetch(src0+idx); - if (HAS_SRC1) { - val = FUNC()(val, vFetch(src1+idx)); - } - - vStore(dest0+idx, val); - if (HAS_DEST1) { - vStore(dest1+idx, val); - } + ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx); } } // done fast path } -template<int THREADS, int UNROLL, typename T> -__device__ inline void CalcLastChunk(int * const bigSliceN, - int * const smallSliceN, int * const lastSliceN, int * const numSlices, - int * const numBigSlices, int * const numSmallSlices, const int N, - const int numChunks, const int chunkSize) { - int Nleft = N - ((numChunks - 1) * chunkSize); - // semi-equally split up the remaining work into numslices slices. - // it's "semi"-equal because we want the divisions to land as neatly as we - // can on alignable boundaries - int NperTile = UNROLL * THREADS * (sizeof(PackType)/sizeof(T)); - int numTiles = (Nleft + NperTile - 1) / NperTile; - int numTilesPerBigSlice = (numTiles + *numSlices - 1) - / *numSlices; - int numTilesPerSmallSlice = numTiles / *numSlices; - - *bigSliceN = NperTile * numTilesPerBigSlice; - *smallSliceN = NperTile * numTilesPerSmallSlice; - *numBigSlices = numTiles % *numSlices; - *numSmallSlices = (*smallSliceN > 0) ? - *numSlices - *numBigSlices : 0; - - // the lastSlice will take the place of one of the small slices unless - // there are no small slices (because this is a very small reduction), in - // which case we replace one of the big slices and leave the small slices - // as 0. - if (*numSmallSlices > 0) { - --*numSmallSlices; - if (*numSmallSlices == 0) - *smallSliceN = 0; - } - else { - --*numBigSlices; - if (*numBigSlices == 0) - *bigSliceN = 0; - } - - *lastSliceN = Nleft - - (*numBigSlices * *bigSliceN - + *numSmallSlices * *smallSliceN); - - // in cases where args.N % numSlices is pretty small, we'd rather have one - // slightly big last slice than one big slice, a bunch of small slices, - // and one smaller last slice - if ((*numBigSlices == 1) && - (*numSmallSlices == *numSlices - 2) && - (*lastSliceN < *smallSliceN)) { - *numBigSlices += *numSmallSlices; - *numSmallSlices = 0; - *bigSliceN = *smallSliceN; - *smallSliceN = 0; - *lastSliceN = Nleft - - *numBigSlices * *bigSliceN; - } - - // done recalculating - *numSlices = *numBigSlices + - *numSmallSlices + 1; -} - -// Kernel launch -template<typename T> -struct KernelArgs { - // general parameters - int nRanks; - int root; - int buffSize; - int N; - int opIndex; - volatile int * __restrict__ opCounter; - bool pushrecv; - - // some pre-computed sizes - int SliceSize; - int SliceOffset; - int ChunkSize; - int NumChunks; - - // local and remote input, output, and buffer - const T * __restrict__ ThisInput; - T * __restrict__ ThisOutput; - - DevRing<char>* ring; -}; - -template<typename T> -void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff, - const int root, const int count, ncclComm *comm) { - args->nRanks = comm->nRanks; - args->root = root; - args->buffSize = comm->buffSize; - args->N = count; - args->opIndex = comm->opSched; - args->opCounter = comm->opCounter; - args->ThisInput = (const T*)sendbuff; - args->ThisOutput = (T*)recvbuff; - args->ring = comm->devRing; - args->pushrecv = comm->globalMemSpace; -} - -#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \ - args, stream) do { \ - dim3 grid(1, 1, 1); \ - dim3 block(THREADS+1, 1, 1); \ - void* argptrs[] = {&args}; \ - CUDACHECK(cudaLaunchKernel( \ - (void*)K<THREADS, UNROLL, FUNC, T>, \ - grid, block, argptrs, 0, stream)); \ -} while (0) - template <typename T> __device__ inline void incrementOpCounter(const KernelArgs<T> *args) { // increment comm's operation counts diff --git a/src/core.cu b/src/core.cu index be4be06..bf1dc6e 100644 --- a/src/core.cu +++ b/src/core.cu @@ -8,6 +8,7 @@ #include <stdlib.h> #include "core.h" #include "libwrap.h" +#include "common_coll.h" #include <sys/mman.h> #include <sys/stat.h> #include <sys/types.h> @@ -22,6 +23,7 @@ DebugLevel ncclDebugLevel; NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out); ncclResult_t ncclGetUniqueId(ncclUniqueId* out) { + NCCLCHECK(PtrCheck(out, "GetUniqueId", "out")); pid_t pid = getpid(); static int count = 0; int commId = __sync_fetch_and_add(&count, 1); @@ -578,15 +580,6 @@ static void commFree(ncclComm_t comm) { } static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId* commId, int rank) { - if (ndev < 1) { - WARN("invalid device count (%d) requested", ndev); - return ncclUnsupportedDeviceCount; - } - if (rank >= ndev || rank < 0) { - WARN("rank %d exceeds ndev=%d", rank, ndev); - return ncclInvalidRank; - } - size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(NodeRef); struct ncclComm* comm = (struct ncclComm*)malloc(commBytes); if (comm == NULL) { @@ -731,6 +724,17 @@ NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniq ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) { if (myrank == 0) showVersion(); + NCCLCHECK(PtrCheck(newcomm, "CommInitRank", "newcomm")); + + if (ndev < 1) { + WARN("Invalid device count requested : %d", ndev); + return ncclUnsupportedDeviceCount; + } + if (myrank >= ndev || myrank < 0) { + WARN("Invalid rank %d, should be in the range 0..%d", myrank, ndev-1); + return ncclInvalidRank; + } + if (strlen(commId.internal) < 1 || strlen(commId.internal) >= NCCL_UNIQUE_ID_BYTES) { WARN("rank %d invalid commId", myrank); @@ -819,6 +823,13 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { showVersion(); + NCCLCHECK(PtrCheck(comms, "CommInitRank", "comms")); + + if (ndev < 1) { + WARN("Invalid device count requested : %d", ndev); + return ncclUnsupportedDeviceCount; + } + ncclResult_t res; int savedDevice; RankEntry* ranks = NULL; @@ -949,7 +960,7 @@ void ncclCommDestroy(ncclComm_t comm) { int commDevice = comm->cudaDev; if (savedDevice != commDevice) { - CUDACHECK(cudaSetDevice(commDevice)); + CUDACHECK(cudaSetDevice(commDevice), void()); } commFree(comm); @@ -982,18 +993,24 @@ const char* ncclGetErrorString(ncclResult_t code) { NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count); ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) { + NCCLCHECK(PtrCheck(comm, "CommCount", "comm")); + NCCLCHECK(PtrCheck(count, "CommCount", "count")); *count = comm->nRanks; return ncclSuccess; } NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid); ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) { + NCCLCHECK(PtrCheck(comm, "CommCuDevice", "comm")); + NCCLCHECK(PtrCheck(devid, "CommCuDevice", "devid")); *devid = comm->cudaDev; return ncclSuccess; } NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank); ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) { + NCCLCHECK(PtrCheck(comm, "CommUserRank", "comm")); + NCCLCHECK(PtrCheck(rank, "CommUserRank", "rank")); *rank = comm->rank; return ncclSuccess; } @@ -12,18 +12,6 @@ #include <cstdio> #include <cuda_runtime.h> - -// DIE on error -#define CUDACHECK(cmd) do { \ - cudaError_t e = cmd; \ - if( e != cudaSuccess ) { \ - printf("Cuda failure %s:%d '%s'\n", \ - __FILE__,__LINE__,cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ -} while(false) - - #define MAXRANKS 32 #define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21) #define NCCL_MEM_PAD_ALIGN 65536 @@ -136,6 +124,23 @@ extern DebugLevel ncclDebugLevel; } \ } while(0) +// Check CUDA calls +#define CUDACHECK(cmd, retcode) do { \ + cudaError_t e = cmd; \ + if( e != cudaSuccess ) { \ + WARN("Cuda failure '%s'\n", cudaGetErrorString(e)); \ + return retcode; \ + } \ +} while(false) + +// Propagate errors up +#define NCCLCHECK(call) do { \ + ncclResult_t res = call; \ + if (res != ncclSuccess) { \ + return res; \ + } \ +} while (0); + #ifdef PROFAPI #define NCCL_API(ret, func, args...) \ __attribute__ ((visibility("default"))) \ diff --git a/src/enqueue.h b/src/enqueue.h index 01c44c2..43d570e 100644 --- a/src/enqueue.h +++ b/src/enqueue.h @@ -34,7 +34,7 @@ ncclResult_t enqueue(const void* sendbuff, { if (stream != comm->prevStream) { // sync required for calls in different streams comm->prevStream = stream; - CUDACHECK( cudaStreamWaitEvent(stream, comm->doneEvent, 0) ); + CUDACHECK(cudaStreamWaitEvent(stream, comm->doneEvent, 0), ncclUnhandledCudaError); } ncclResult_t ret; @@ -42,7 +42,7 @@ ncclResult_t enqueue(const void* sendbuff, // Always have to record done event because we don't know what stream next // collective will be in. - CUDACHECK( cudaEventRecord(comm->doneEvent, stream) ); + CUDACHECK(cudaEventRecord(comm->doneEvent, stream), ncclUnhandledCudaError); comm->opSched += 1; return ret; } diff --git a/src/reduce.cu b/src/reduce.cu index f281ce8..9effbe9 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "core.h" +#include "common_coll.h" #include "enqueue.h" #include "primitives.h" @@ -117,12 +118,9 @@ __global__ void ReduceKernel(const KernelArgs<T> args) { template<class FUNC, typename T> ncclResult_t RingReduce(const void* sendbuff, void* recvbuff, const int count, const int root, ncclComm* comm, cudaStream_t stream) { - if (count == 0) - return ncclSuccess; - if (comm->nRanks == 1) { if (sendbuff != recvbuff) - CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError); } else { KernelArgs<T> args; ArgsSetup(&args, sendbuff, recvbuff, root, count, comm); @@ -145,6 +143,7 @@ NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int cou ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) { + NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, root, comm, "Reduce")); return enqueue<ReduceFunctor>(sendbuff, recvbuff, count, datatype, op, root, comm, stream); } diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index f13cbfb..b1100dd 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "core.h" +#include "common_coll.h" #include "enqueue.h" #include "primitives.h" @@ -133,12 +134,9 @@ __global__ void ReduceScatterKernel(const KernelArgs<T> args) { template<class FUNC, typename T> ncclResult_t RingReduceScatter(const void* sendbuff, void* recvbuff, const int count, ncclComm* comm, cudaStream_t stream) { - if (count == 0) - return ncclSuccess; - if (comm->nRanks == 1) { if (sendbuff != recvbuff) - CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError); } else { KernelArgs<T> args; ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm); @@ -161,6 +159,7 @@ NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) { + NCCLCHECK(ArgsCheck(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, "ReduceScatter")); return enqueue<ReduceScatter>(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream); } |