
github.com/marian-nmt/nccl.git
author    Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-12-02 02:17:50 +0300
committer Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-12-02 02:17:50 +0300
commit    34d27771c6dc988889d8ac857b62932a79bf1210 (patch)
tree      5f7442026aa34f56a5fce5bbc6bac52c6b2c9844
parent    1093821c335437b399035f3ebf3b67a3e960de8f (diff)

1.3.2 release

Broadcast tuning
Better checking of inputs
Copy/reduce code simplification
-rw-r--r--  Makefile                 2
-rw-r--r--  src/all_gather.cu        9
-rw-r--r--  src/all_reduce.cu        7
-rw-r--r--  src/broadcast.cu         7
-rw-r--r--  src/common_kernel.h    290
-rw-r--r--  src/core.cu             37
-rw-r--r--  src/core.h              29
-rw-r--r--  src/enqueue.h            4
-rw-r--r--  src/reduce.cu            7
-rw-r--r--  src/reduce_scatter.cu    7
10 files changed, 120 insertions, 279 deletions
diff --git a/Makefile b/Makefile
index a706879..8fc02b6 100644
--- a/Makefile
+++ b/Makefile
@@ -52,7 +52,7 @@ endif
NCCL_MAJOR := 1
NCCL_MINOR := 3
-NCCL_PATCH := 1
+NCCL_PATCH := 2
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
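
Note: the three -D flags above are how the bumped version macros reach the NCCL
sources at build time. A minimal sketch of stringifying them into a printable
version (illustrative only, not code from this commit):

    /* Illustrative stringification of the injected version macros. */
    #define NCCL_STR2(v) #v
    #define NCCL_STR(v)  NCCL_STR2(v)
    static const char versionString[] =
        NCCL_STR(NCCL_MAJOR) "." NCCL_STR(NCCL_MINOR) "." NCCL_STR(NCCL_PATCH);
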
diff --git a/src/all_gather.cu b/src/all_gather.cu
index 2dd6246..cb36b71 100644
--- a/src/all_gather.cu
+++ b/src/all_gather.cu
@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
+#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -164,18 +165,15 @@ __global__ void AllGatherKernel(const KernelArgs<T> args) {
}
}
-#define THREADS 384
+#define THREADS 512
#define UNROLL 8
template<class FUNC, typename T>
ncclResult_t RingAllGather(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -198,6 +196,7 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, int count, ncclDataT
void* recvbuff, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, int count, ncclDataType_t datatype,
void* recvbuff, ncclComm_t comm, cudaStream_t stream) {
+ NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, "AllGather"));
return enqueue<AllGather, FuncNull>(sendbuff, recvbuff, count, datatype, 0, comm, stream);
}
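
Note: the pattern above -- NCCLCHECK(ArgsCheck(...)) at the public entry point,
with the old count == 0 early-out deleted from the ring implementation --
repeats for every collective in this commit. A minimal sketch of the idea
(hypothetical simplification; the real ArgsCheck lives in common_coll.h, which
is included above but not part of this diff, and also validates datatype, op,
and root):

    // Hypothetical sketch of the entry-point validation; names are illustrative.
    static ncclResult_t ArgsCheckSketch(const void* sendbuff, void* recvbuff,
        int count, ncclComm* comm) {
      if (comm == NULL)
        return ncclInvalidArgument;
      if (count < 0)
        return ncclInvalidArgument;
      if (sendbuff == NULL || recvbuff == NULL)
        return ncclInvalidDevicePointer;
      return ncclSuccess;
    }

With NCCLCHECK wrapping the call, a bad argument now surfaces to the caller as
an ncclResult_t code instead of undefined behavior inside the kernels.
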
diff --git a/src/all_reduce.cu b/src/all_reduce.cu
index a81ee62..2f38d6e 100644
--- a/src/all_reduce.cu
+++ b/src/all_reduce.cu
@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
+#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -202,12 +203,9 @@ __global__ void AllReduceKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingAllReduce(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -230,6 +228,7 @@ NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, int
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
+ NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, 0, comm, "AllReduce"));
return enqueue<AllReduce>(sendbuff, recvbuff, count, datatype, op, 0, comm, stream);
}
diff --git a/src/broadcast.cu b/src/broadcast.cu
index 5843922..3a7cb11 100644
--- a/src/broadcast.cu
+++ b/src/broadcast.cu
@@ -5,10 +5,11 @@
************************************************************************/
#include "core.h"
+#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
-#define NUM_SUBSTEPS 2
+#define NUM_SUBSTEPS 4
#define NUM_BUFCHUNKS 2
// Increase Step and boffset for buffer sync
@@ -135,9 +136,6 @@ __global__ void BroadcastKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingBroadcast(void* buff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
if (comm->nRanks != 1) {
KernelArgs<T> args;
ArgsSetup(&args, buff, buff, root, count, comm);
@@ -160,6 +158,7 @@ NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
+ NCCLCHECK(ArgsCheck(buff, buff, count, datatype, ncclSum, root, comm, "Bcast"));
return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream);
}
diff --git a/src/common_kernel.h b/src/common_kernel.h
index 7a17e36..28fbc85 100644
--- a/src/common_kernel.h
+++ b/src/common_kernel.h
@@ -174,32 +174,46 @@ struct MULTI<FUNC, long long> {
}
};
-template<typename T, bool FETCHTWO>
-__device__ inline void FetchOneOrTwo64b(PackType& s0,
- const volatile T * __restrict__ const src0, PackType& s1,
- const volatile T * __restrict__ const src1, const int idx) {
- s0 = (reinterpret_cast<const volatile PackType *>(src0))[idx];
- if (FETCHTWO) {
- s1 = (reinterpret_cast<const volatile PackType *>(src1))[idx];
+template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS>
+__device__ inline void ReduceCopy(
+ const volatile T * __restrict__ const src0,
+ const volatile T * __restrict__ const src1,
+ volatile T * __restrict__ const dest0,
+ volatile T * __restrict__ const dest1, const int idx) {
+ T val = vFetch(src0+idx);
+ if (TWO_INPUTS) {
+ val = FUNC()(val, vFetch(src1+idx));
}
-}
-
-template<typename T, bool STORETWO>
-__device__ inline void StoreOneOrTwo64b(volatile T * __restrict__ const dest0,
- volatile T * __restrict__ const dest1, PackType val, const int idx) {
- (reinterpret_cast<volatile PackType *>(dest0))[idx] = val;
- if (STORETWO) {
- (reinterpret_cast<volatile PackType *>(dest1))[idx] = val;
+ vStore(dest0+idx, val);
+ if (TWO_OUTPUTS) {
+ vStore(dest1+idx, val);
}
}
-template<class FUNC, typename T, bool ISREDUCE>
-__device__ inline PackType ReduceOrCopy64b(const PackType s0,
- const PackType s1) {
- if (ISREDUCE) {
- return MULTI<FUNC, T>()(s0, s1);
- } else {
- return s0;
+template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS, int UNROLL, int THREADS>
+__device__ inline void ReduceCopy64b(
+ const volatile T * __restrict__ const src0,
+ const volatile T * __restrict__ const src1,
+ volatile T * __restrict__ const dest0,
+ volatile T * __restrict__ const dest1, const int offset) {
+ PackType t0[UNROLL];
+ PackType t1[UNROLL];
+ #pragma unroll
+ for (int u = 0; u < UNROLL; ++u) {
+ int idx = offset + u*THREADS;
+ t0[u] = (reinterpret_cast<const volatile PackType *>(src0))[idx];
+ if (TWO_INPUTS) {
+ t1[u] = (reinterpret_cast<const volatile PackType *>(src1))[idx];
+ }
+ }
+ #pragma unroll
+ for (int u = 0; u < UNROLL; ++u) {
+ int idx = offset + u*THREADS;
+ PackType val = TWO_INPUTS ? MULTI<FUNC, T>()(t0[u], t1[u]) : t0[u];
+ (reinterpret_cast<volatile PackType *>(dest0))[idx] = val;
+ if (TWO_OUTPUTS) {
+ (reinterpret_cast<volatile PackType *>(dest1))[idx] = val;
+ }
}
}
@@ -251,9 +265,6 @@ __device__ inline void ReduceOrCopy(const int tid,
return;
}
- const int UNROLL2 = (UNROLL >= 2) ? (UNROLL / 2) : 1;
- const bool NOUNROLL2 = ((UNROLL / 2) == 0);
-
int Npreamble = (N<alignof(PackType)) ? N : AlignUp(dest0, alignof(PackType)) - dest0;
// stage 0: check if we'll be able to use the fast, 64-bit aligned path.
@@ -266,247 +277,60 @@ __device__ inline void ReduceOrCopy(const int tid,
Npreamble = N;
}
-/*
- if (threadIdx.x == 0) {
- printf("** alignable: %s", (alignable ? "YES" : " NO"));
- printf(", dest0 = 0x%08X", dest0);
- printf(", src0 = 0x%08X", src0);
- if (HAS_DEST1) printf(", dest1 = 0x%08X", dest1);
- if (HAS_SRC1) printf(", src1 = 0x%08X", src1);
- printf("\n");
- }
-*/
-
// stage 1: preamble: handle any elements up to the point of everything coming
// into alignment
for (int idx = tid; idx < Npreamble; idx += THREADS) {
// ought to be no way this is ever more than one iteration, except when
// alignable is false
- T val = vFetch(src0+idx);
- if (HAS_SRC1) {
- val = FUNC()(val, vFetch(src1+idx));
- }
-
- vStore(dest0+idx, val);
- if (HAS_DEST1) {
- vStore(dest1+idx, val);
- }
+ ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx);
}
- // reduce N by however many elements we've handled already
- int Ndone = Npreamble;
- int Nrem = N - Ndone;
-
// stage 2: fast path: use 64b loads/stores to do the bulk of the work,
// assuming the pointers we have are all 64-bit alignable.
if (alignable) {
- if (Ndone > 0) {
- // align up pointers
- dest0 += Ndone; if (HAS_DEST1) { dest1 += Ndone; }
- src0 += Ndone; if (HAS_SRC1) { src1 += Ndone; }
- }
+ const int PackFactor = sizeof(PackType) / sizeof(T);
+ int Nrem = N - Npreamble;
+ dest0 += Npreamble; if (HAS_DEST1) { dest1 += Npreamble; }
+ src0 += Npreamble; if (HAS_SRC1) { src1 += Npreamble; }
// stage 2a: main loop
- int Nalign = (Nrem / (sizeof(PackType) / sizeof(T)) / (UNROLL * THREADS))
+ int Nalign2a = (Nrem / (PackFactor * UNROLL * THREADS))
* (UNROLL * THREADS); // round down
#pragma unroll 1 // don't unroll this loop
- for (int idx = tid; idx < Nalign; idx += UNROLL * THREADS) {
- PackType t0[UNROLL2];
- PackType t1[UNROLL2];
- PackType t2[UNROLL2];
-
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1,
- idx + j * THREADS);
-
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]);
-
- if (!NOUNROLL2) {
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- FetchOneOrTwo64b<T, HAS_SRC1>(t0[j], src0, t1[j], src1,
- idx + (UNROLL2 + j) * THREADS);
- }
-
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j], idx + j * THREADS);
-
- if (!NOUNROLL2) {
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- t2[j] = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0[j], t1[j]);
-
- #pragma unroll
- for (int j = 0; j < UNROLL2; ++j)
- StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2[j],
- idx + (UNROLL2 + j) * THREADS);
- }
+ for (int idx = tid; idx < Nalign2a; idx += UNROLL * THREADS) {
+ ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, UNROLL, THREADS>(src0, src1, dest0, dest1, idx);
}
+ int Ndone2a = Nalign2a * PackFactor;
+ Nrem -= Ndone2a;
+
// stage 2b: slightly less optimized for section when we don't have full
// UNROLLs
- int Ndone2a = Nalign * (sizeof(PackType)/sizeof(T));
- Ndone += Ndone2a;
- Nrem = N - Ndone;
-
- // TODO: This kind of pointer update arithmetic is expensive. Should
- // probably find a better way.
- if (Nrem > 0) {
- dest0 += Ndone2a; if (HAS_DEST1) { dest1 += Ndone2a; }
- src0 += Ndone2a; if (HAS_SRC1) { src1 += Ndone2a; }
- }
- Nalign = Nrem / (sizeof(PackType)/sizeof(T));
+ int Nalign2b = Nrem / PackFactor;
#pragma unroll 4
- for (int idx = tid; idx < Nalign; idx += THREADS) {
- PackType t0, t1, t2;
-
- FetchOneOrTwo64b<T, HAS_SRC1>(t0, src0, t1, src1, idx);
- t2 = ReduceOrCopy64b<FUNC, T, HAS_SRC1>(t0, t1);
- StoreOneOrTwo64b<T, HAS_DEST1>(dest0, dest1, t2, idx);
+ for (int idx = Nalign2a + tid; idx < Nalign2a + Nalign2b; idx += THREADS) {
+ ReduceCopy64b<FUNC, T, HAS_SRC1, HAS_DEST1, 1, 0>(src0, src1, dest0, dest1, idx);
}
- // stage 2c: tail
- int Ndone2b = Nalign * (sizeof(PackType)/sizeof(T));
- Ndone += Nalign * (sizeof(PackType)/sizeof(T));
- Nrem = N - Ndone;
+ int Ndone2b = Nalign2b * PackFactor;
+ Nrem -= Ndone2b;
+ int Ndone2 = Ndone2a + Ndone2b;
+ dest0 += Ndone2; if (HAS_DEST1) { dest1 += Ndone2; }
+ src0 += Ndone2; if (HAS_SRC1) { src1 += Ndone2; }
- if (Nrem > 0) {
- dest0 += Ndone2b; if (HAS_DEST1) { dest1 += Ndone2b; }
- src0 += Ndone2b; if (HAS_SRC1) { src1 += Ndone2b; }
- }
+ // stage 2c: tail
for (int idx = tid; idx < Nrem; idx += THREADS) {
// never ought to make it more than one time through this loop. only a
// few threads should even participate
- T val = vFetch(src0+idx);
- if (HAS_SRC1) {
- val = FUNC()(val, vFetch(src1+idx));
- }
-
- vStore(dest0+idx, val);
- if (HAS_DEST1) {
- vStore(dest1+idx, val);
- }
+ ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(src0, src1, dest0, dest1, idx);
}
} // done fast path
}
-template<int THREADS, int UNROLL, typename T>
-__device__ inline void CalcLastChunk(int * const bigSliceN,
- int * const smallSliceN, int * const lastSliceN, int * const numSlices,
- int * const numBigSlices, int * const numSmallSlices, const int N,
- const int numChunks, const int chunkSize) {
- int Nleft = N - ((numChunks - 1) * chunkSize);
- // semi-equally split up the remaining work into numslices slices.
- // it's "semi"-equal because we want the divisions to land as neatly as we
- // can on alignable boundaries
- int NperTile = UNROLL * THREADS * (sizeof(PackType)/sizeof(T));
- int numTiles = (Nleft + NperTile - 1) / NperTile;
- int numTilesPerBigSlice = (numTiles + *numSlices - 1)
- / *numSlices;
- int numTilesPerSmallSlice = numTiles / *numSlices;
-
- *bigSliceN = NperTile * numTilesPerBigSlice;
- *smallSliceN = NperTile * numTilesPerSmallSlice;
- *numBigSlices = numTiles % *numSlices;
- *numSmallSlices = (*smallSliceN > 0) ?
- *numSlices - *numBigSlices : 0;
-
- // the lastSlice will take the place of one of the small slices unless
- // there are no small slices (because this is a very small reduction), in
- // which case we replace one of the big slices and leave the small slices
- // as 0.
- if (*numSmallSlices > 0) {
- --*numSmallSlices;
- if (*numSmallSlices == 0)
- *smallSliceN = 0;
- }
- else {
- --*numBigSlices;
- if (*numBigSlices == 0)
- *bigSliceN = 0;
- }
-
- *lastSliceN = Nleft -
- (*numBigSlices * *bigSliceN
- + *numSmallSlices * *smallSliceN);
-
- // in cases where args.N % numSlices is pretty small, we'd rather have one
- // slightly big last slice than one big slice, a bunch of small slices,
- // and one smaller last slice
- if ((*numBigSlices == 1) &&
- (*numSmallSlices == *numSlices - 2) &&
- (*lastSliceN < *smallSliceN)) {
- *numBigSlices += *numSmallSlices;
- *numSmallSlices = 0;
- *bigSliceN = *smallSliceN;
- *smallSliceN = 0;
- *lastSliceN = Nleft -
- *numBigSlices * *bigSliceN;
- }
-
- // done recalculating
- *numSlices = *numBigSlices +
- *numSmallSlices + 1;
-}
-
-// Kernel launch
-template<typename T>
-struct KernelArgs {
- // general parameters
- int nRanks;
- int root;
- int buffSize;
- int N;
- int opIndex;
- volatile int * __restrict__ opCounter;
- bool pushrecv;
-
- // some pre-computed sizes
- int SliceSize;
- int SliceOffset;
- int ChunkSize;
- int NumChunks;
-
- // local and remote input, output, and buffer
- const T * __restrict__ ThisInput;
- T * __restrict__ ThisOutput;
-
- DevRing<char>* ring;
-};
-
-template<typename T>
-void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff,
- const int root, const int count, ncclComm *comm) {
- args->nRanks = comm->nRanks;
- args->root = root;
- args->buffSize = comm->buffSize;
- args->N = count;
- args->opIndex = comm->opSched;
- args->opCounter = comm->opCounter;
- args->ThisInput = (const T*)sendbuff;
- args->ThisOutput = (T*)recvbuff;
- args->ring = comm->devRing;
- args->pushrecv = comm->globalMemSpace;
-}
-
-#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \
- args, stream) do { \
- dim3 grid(1, 1, 1); \
- dim3 block(THREADS+1, 1, 1); \
- void* argptrs[] = {&args}; \
- CUDACHECK(cudaLaunchKernel( \
- (void*)K<THREADS, UNROLL, FUNC, T>, \
- grid, block, argptrs, 0, stream)); \
-} while (0)
-
template <typename T>
__device__ inline void incrementOpCounter(const KernelArgs<T> *args) {
// increment comm's operation counts
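
Note: the rewrite above collapses the separate fetch/reduce/store helpers into
ReduceCopy and ReduceCopy64b and replaces the running Ndone bookkeeping with a
single pointer bump per stage. For concreteness, a worked host-side sketch of
the new stage-2 arithmetic, assuming T = float and an 8-byte PackType (so
PackFactor = 2) with THREADS = 512 and UNROLL = 8:

    #include <cstdio>

    int main() {
      const int THREADS = 512, UNROLL = 8, PackFactor = 2;
      int N = 1000000;                   // elements of T
      int Nrem = N;                      // assume Npreamble == 0 (aligned input)
      // stage 2a: whole UNROLL*THREADS blocks of packs, rounded down
      int Nalign2a = (Nrem / (PackFactor * UNROLL * THREADS)) * (UNROLL * THREADS);
      Nrem -= Nalign2a * PackFactor;     // 1000000 - 999424 = 576
      // stage 2b: remaining whole packs, one per thread per iteration
      int Nalign2b = Nrem / PackFactor;  // 288 packs
      Nrem -= Nalign2b * PackFactor;     // 0 elements left for the stage-2c tail
      printf("2a: %d packs, 2b: %d packs, tail: %d elements\n",
             Nalign2a, Nalign2b, Nrem);  // -> 2a: 499712, 2b: 288, tail: 0
      return 0;
    }
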
diff --git a/src/core.cu b/src/core.cu
index be4be06..bf1dc6e 100644
--- a/src/core.cu
+++ b/src/core.cu
@@ -8,6 +8,7 @@
#include <stdlib.h>
#include "core.h"
#include "libwrap.h"
+#include "common_coll.h"
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
@@ -22,6 +23,7 @@ DebugLevel ncclDebugLevel;
NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out);
ncclResult_t ncclGetUniqueId(ncclUniqueId* out) {
+ NCCLCHECK(PtrCheck(out, "GetUniqueId", "out"));
pid_t pid = getpid();
static int count = 0;
int commId = __sync_fetch_and_add(&count, 1);
@@ -578,15 +580,6 @@ static void commFree(ncclComm_t comm) {
}
static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId* commId, int rank) {
- if (ndev < 1) {
- WARN("invalid device count (%d) requested", ndev);
- return ncclUnsupportedDeviceCount;
- }
- if (rank >= ndev || rank < 0) {
- WARN("rank %d exceeds ndev=%d", rank, ndev);
- return ncclInvalidRank;
- }
-
size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(NodeRef);
struct ncclComm* comm = (struct ncclComm*)malloc(commBytes);
if (comm == NULL) {
@@ -731,6 +724,17 @@ NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniq
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
if (myrank == 0) showVersion();
+ NCCLCHECK(PtrCheck(newcomm, "CommInitRank", "newcomm"));
+
+ if (ndev < 1) {
+ WARN("Invalid device count requested : %d", ndev);
+ return ncclUnsupportedDeviceCount;
+ }
+ if (myrank >= ndev || myrank < 0) {
+ WARN("Invalid rank %d, should be in the range 0..%d", myrank, ndev-1);
+ return ncclInvalidRank;
+ }
+
if (strlen(commId.internal) < 1 ||
strlen(commId.internal) >= NCCL_UNIQUE_ID_BYTES) {
WARN("rank %d invalid commId", myrank);
@@ -819,6 +823,13 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
showVersion();
+ NCCLCHECK(PtrCheck(comms, "CommInitRank", "comms"));
+
+ if (ndev < 1) {
+ WARN("Invalid device count requested : %d", ndev);
+ return ncclUnsupportedDeviceCount;
+ }
+
ncclResult_t res;
int savedDevice;
RankEntry* ranks = NULL;
@@ -949,7 +960,7 @@ void ncclCommDestroy(ncclComm_t comm) {
int commDevice = comm->cudaDev;
if (savedDevice != commDevice) {
- CUDACHECK(cudaSetDevice(commDevice));
+ CUDACHECK(cudaSetDevice(commDevice), void());
}
commFree(comm);
@@ -982,18 +993,24 @@ const char* ncclGetErrorString(ncclResult_t code) {
NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
+ NCCLCHECK(PtrCheck(comm, "CommCount", "comm"));
+ NCCLCHECK(PtrCheck(count, "CommCount", "count"));
*count = comm->nRanks;
return ncclSuccess;
}
NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid);
ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
+ NCCLCHECK(PtrCheck(comm, "CommCuDevice", "comm"));
+ NCCLCHECK(PtrCheck(devid, "CommCuDevice", "devid"));
*devid = comm->cudaDev;
return ncclSuccess;
}
NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank);
ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
+ NCCLCHECK(PtrCheck(comm, "CommUserRank", "comm"));
+ NCCLCHECK(PtrCheck(rank, "CommUserRank", "rank"));
*rank = comm->rank;
return ncclSuccess;
}
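
Note: every public accessor now guards its arguments with PtrCheck before
dereferencing. PtrCheck is presumably a null-pointer guard along these lines
(sketch; the real helper is in common_coll.h, which this diff does not show):

    // Assumed shape of PtrCheck: warn and return an error on NULL arguments.
    static ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname) {
      if (ptr == NULL) {
        WARN("%s : %s argument is NULL", opname, ptrname);
        return ncclInvalidArgument;
      }
      return ncclSuccess;
    }

So a call like ncclCommCount(NULL, &n) now returns an error code instead of
crashing on comm->nRanks.
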
diff --git a/src/core.h b/src/core.h
index bbabf49..17794d7 100644
--- a/src/core.h
+++ b/src/core.h
@@ -12,18 +12,6 @@
#include <cstdio>
#include <cuda_runtime.h>
-
-// DIE on error
-#define CUDACHECK(cmd) do { \
- cudaError_t e = cmd; \
- if( e != cudaSuccess ) { \
- printf("Cuda failure %s:%d '%s'\n", \
- __FILE__,__LINE__,cudaGetErrorString(e)); \
- exit(EXIT_FAILURE); \
- } \
-} while(false)
-
-
#define MAXRANKS 32
#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
#define NCCL_MEM_PAD_ALIGN 65536
@@ -136,6 +124,23 @@ extern DebugLevel ncclDebugLevel;
} \
} while(0)
+// Check CUDA calls
+#define CUDACHECK(cmd, retcode) do { \
+ cudaError_t e = cmd; \
+ if( e != cudaSuccess ) { \
+ WARN("Cuda failure '%s'\n", cudaGetErrorString(e)); \
+ return retcode; \
+ } \
+} while(false)
+
+// Propagate errors up
+#define NCCLCHECK(call) do { \
+ ncclResult_t res = call; \
+ if (res != ncclSuccess) { \
+ return res; \
+ } \
+} while (0);
+
#ifdef PROFAPI
#define NCCL_API(ret, func, args...) \
__attribute__ ((visibility("default"))) \
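
Note: CUDACHECK changes meaning here, from die-on-error to warn-and-return,
with the caller choosing the value to return. A usage sketch in a function
returning ncclResult_t (myCopy is hypothetical; compare the cudaMemcpyAsync
call sites elsewhere in this diff):

    ncclResult_t myCopy(void* dst, const void* src, size_t bytes, cudaStream_t s) {
      // On failure, WARN(...) fires and the macro returns ncclUnhandledCudaError
      // from myCopy, instead of printing and calling exit() as before.
      CUDACHECK(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, s),
                ncclUnhandledCudaError);
      return ncclSuccess;
    }

In a void function the second argument can be void(), as in ncclCommDestroy
above; NCCLCHECK then bubbles the ncclResult_t codes up through the public API.
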
diff --git a/src/enqueue.h b/src/enqueue.h
index 01c44c2..43d570e 100644
--- a/src/enqueue.h
+++ b/src/enqueue.h
@@ -34,7 +34,7 @@ ncclResult_t enqueue(const void* sendbuff,
{
if (stream != comm->prevStream) { // sync required for calls in different streams
comm->prevStream = stream;
- CUDACHECK( cudaStreamWaitEvent(stream, comm->doneEvent, 0) );
+ CUDACHECK(cudaStreamWaitEvent(stream, comm->doneEvent, 0), ncclUnhandledCudaError);
}
ncclResult_t ret;
@@ -42,7 +42,7 @@ ncclResult_t enqueue(const void* sendbuff,
// Always have to record done event because we don't know what stream next
// collective will be in.
- CUDACHECK( cudaEventRecord(comm->doneEvent, stream) );
+ CUDACHECK(cudaEventRecord(comm->doneEvent, stream), ncclUnhandledCudaError);
comm->opSched += 1;
return ret;
}
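
Note: the two CUDA calls above implement cross-stream ordering for
back-to-back collectives on the same communicator. A self-contained sketch of
that prevStream/doneEvent protocol (illustrative helper, not NCCL code):

    #include <cuda_runtime.h>

    // Order a new collective after the previous one when streams differ.
    cudaError_t launchOrdered(cudaStream_t stream, cudaStream_t* prevStream,
                              cudaEvent_t doneEvent) {
      if (stream != *prevStream) {        // new stream: wait on last completion
        *prevStream = stream;
        cudaError_t e = cudaStreamWaitEvent(stream, doneEvent, 0);
        if (e != cudaSuccess) return e;
      }
      /* ... enqueue the collective kernel on 'stream' here ... */
      return cudaEventRecord(doneEvent, stream);  // mark this op's completion
    }
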
diff --git a/src/reduce.cu b/src/reduce.cu
index f281ce8..9effbe9 100644
--- a/src/reduce.cu
+++ b/src/reduce.cu
@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
+#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -117,12 +118,9 @@ __global__ void ReduceKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingReduce(const void* sendbuff, void* recvbuff, const int count, const int root,
ncclComm* comm, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, root, count, comm);
@@ -145,6 +143,7 @@ NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int cou
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
+ NCCLCHECK(ArgsCheck(sendbuff, recvbuff, count, datatype, op, root, comm, "Reduce"));
return enqueue<ReduceFunctor>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}
diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu
index f13cbfb..b1100dd 100644
--- a/src/reduce_scatter.cu
+++ b/src/reduce_scatter.cu
@@ -5,6 +5,7 @@
************************************************************************/
#include "core.h"
+#include "common_coll.h"
#include "enqueue.h"
#include "primitives.h"
@@ -133,12 +134,9 @@ __global__ void ReduceScatterKernel(const KernelArgs<T> args) {
template<class FUNC, typename T>
ncclResult_t RingReduceScatter(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream), ncclUnhandledCudaError);
} else {
KernelArgs<T> args;
ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
@@ -161,6 +159,7 @@ NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
+ NCCLCHECK(ArgsCheck(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, "ReduceScatter"));
return enqueue<ReduceScatter>(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream);
}