
github.com/marian-nmt/nccl.git
author     Wojtek Wasko <wwasko@nvidia.com>  2018-10-09 12:02:01 +0300
committer  Wojtek Wasko <wwasko@nvidia.com>  2018-10-11 21:23:09 +0300
commit     6f4f4e92688bb9a32327b1be031bf1b87a3bc429 (patch)
tree       d838cf3529b733ac89aa73cbfcd55336d43f3840
parent     f93fe9bfd94884cec2ba711897222e0df5569a53 (diff)

Abort mechanism and API for querying async errors. (branch: dev/wwasko/abort)

Change-Id: If1f8fadc719b136788609a10416658f3ef76cf35
-rw-r--r--  src/collectives/device/all_gather.h      | 18
-rw-r--r--  src/collectives/device/all_reduce.h      | 19
-rw-r--r--  src/collectives/device/broadcast.h       | 18
-rw-r--r--  src/collectives/device/ll_kernel.h       | 76
-rw-r--r--  src/collectives/device/primitives.h      | 62
-rw-r--r--  src/collectives/device/reduce.h          | 17
-rw-r--r--  src/collectives/device/reduce_scatter.h  | 17
-rw-r--r--  src/include/core.h                       |  8
-rw-r--r--  src/init.cu                              | 21
-rw-r--r--  src/misc/enqueue.cu                      |  3
-rw-r--r--  src/misc/group.cu                        | 12
-rw-r--r--  src/nccl.h.in                            |  7
-rw-r--r--  src/transport.cu                         |  1
13 files changed, 213 insertions, 66 deletions
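
Taken together, the pieces below work like this: device kernels poll a host-mapped
abortFlag while spinning, the proxy thread records fatal errors in comm->fatalError,
and the new ncclCommGetAsyncError call exposes that state to the application. A
minimal sketch of how a host watchdog might consume the API (watchComm and the
polling cadence are hypothetical; in this commit, ncclCommDestroy doubles as the
abort path):

    #include <nccl.h>

    /* Hypothetical watchdog: poll the communicator for asynchronous errors
     * and abort outstanding device work. Per this commit, ncclCommDestroy()
     * sets *comm->abortFlag = 1 and synchronizes, so spinning kernels exit
     * instead of hanging forever. */
    void watchComm(ncclComm_t comm) {
      ncclResult_t asyncErr = ncclSuccess;
      for (;;) {
        if (ncclCommGetAsyncError(comm, &asyncErr) != ncclSuccess) break;
        if (asyncErr != ncclSuccess) {
          ncclCommDestroy(comm);   /* flips the abort flag, then frees */
          break;
        }
        /* sleep or yield between polls in a real application */
      }
    }
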
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index a30e575..cef6097 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -26,8 +26,8 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
int prevdirect = ring->recv.conn.direct;
int nextdirect = ring->send.conn.direct;
- WaitFlag waitDoneFromNext(ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLGATHER_SUBSTEPS);
+ WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, ALLGATHER_SUBSTEPS);
PostFlag postDoneToPrev(ring->recv.conn.head, ALLGATHER_SUBSTEPS, NULL, 0);
PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS);
@@ -38,13 +38,14 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
const int buffSize = ring->buffSize / sizeof(T);
const int sliceSize = buffSize / ALLGATHER_BUFCHUNKS;
const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
+ uint32_t shouldExit = 0;
if (tid == 0) {
// Update in case we skipped some collectives
*ring->recv.conn.opCount = args->opCount;
// Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
+ WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0);
+ waitOpCountNext.wait(&shouldExit, args->opCount);
if (prevdirect) {
*ring->recv.conn.ptrExchange = args->ThisOutput;
}
@@ -55,7 +56,7 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
*ptr = nullptr;
}
}
- __syncthreads();
+ exitIfAbortBarrier(shouldExit);
uint64_t step = 0ULL;
int poffset, noffset = 0;
@@ -157,12 +158,13 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
}
if (tid == 0) {
- waitDoneFromNext.wait(ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS));
+ waitDoneFromNext.wait(&shouldExit, ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS));
*ring->send.conn.head = 0ULL;
*ring->recv.conn.tail = 0ULL;
__threadfence_system();
*ring->recv.conn.opCount = args->opCount+1;
}
+ exitIfAbortBarrier(shouldExit);
}
#include "ll_kernel.h"
@@ -223,11 +225,13 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
if (thisInput + chunkOffset == thisOutput + offset) { // In place
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + chunkOffset,
nextOutput + noffset,
maxOffset, nflag, llNthreads);
} else {
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + chunkOffset,
thisOutput + offset,
nextOutput + noffset,
@@ -244,6 +248,7 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
prevInput + poffset,
thisOutput + offset,
nextOutput + noffset,
@@ -259,6 +264,7 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) {
offset = chunkOffset + rankDest * size;
LL::ReduceCopy(
+ comm->abortFlag,
prevInput + poffset,
thisOutput + offset,
maxOffset, pflag, llNthreads);
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index d7abc64..6d6e2a0 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -26,8 +26,8 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) {
int prevdirect = ring->recv.conn.direct;
int nextdirect = ring->send.conn.direct;
- WaitFlag waitDoneFromNext(ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLREDUCE_SUBSTEPS);
+ WaitFlag waitDoneFromNext(args->comm->abortFlag, ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(args->comm->abortFlag, ring->recv.conn.tail, ALLREDUCE_SUBSTEPS);
PostFlag postDoneToPrev(ring->recv.conn.head, ALLREDUCE_SUBSTEPS, NULL, 0);
PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS);
@@ -39,13 +39,14 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) {
const int buffSize = ring->buffSize / sizeof(T);
const int sliceSize = buffSize / ALLREDUCE_BUFCHUNKS;
const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
+ uint32_t shouldExit = 0;
if (tid == 0) {
// Update in case we skipped some collectives
*ring->recv.conn.opCount = args->opCount;
// Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
+ WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0);
+ waitOpCountNext.wait(&shouldExit, args->opCount);
if (prevdirect) {
*ring->recv.conn.ptrExchange = args->ThisOutput;
}
@@ -56,7 +57,7 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) {
*ptr = nullptr;
}
}
- __syncthreads();
+ exitIfAbortBarrier(shouldExit);
uint64_t step = 0ULL;
int poffset, noffset = 0;
@@ -188,12 +189,13 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) {
if (tid == 0) {
// Wait for next to have consumed all data before we reset the flag
- waitDoneFromNext.wait(ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS));
+ waitDoneFromNext.wait(&shouldExit, ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS));
*ring->send.conn.head = 0ULL;
*ring->recv.conn.tail = 0ULL;
__threadfence_system();
*ring->recv.conn.opCount = args->opCount+1;
}
+ exitIfAbortBarrier(shouldExit);
}
#include "ll_kernel.h"
@@ -254,6 +256,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ args->comm->abortFlag,
thisInput + offset,
nextOutput + noffset,
maxOffset, nflag, llNthreads);
@@ -269,6 +272,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ args->comm->abortFlag,
thisInput + offset,
prevInput + poffset,
nextOutput + noffset,
@@ -287,6 +291,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ args->comm->abortFlag,
thisInput + offset,
prevInput + poffset,
thisOutput + offset,
@@ -305,6 +310,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ args->comm->abortFlag,
prevInput + poffset,
thisOutput + offset,
nextOutput + noffset,
@@ -322,6 +328,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
// Here we need to copy from buffer to this output.
LL::ReduceCopy(
+ args->comm->abortFlag,
prevInput + poffset,
thisOutput + offset,
maxOffset, pflag, llNthreads);
diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h
index c2f6d00..8f89e95 100644
--- a/src/collectives/device/broadcast.h
+++ b/src/collectives/device/broadcast.h
@@ -25,8 +25,8 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) {
int prevdirect = ring->recv.conn.direct;
int nextdirect = ring->send.conn.direct;
- WaitFlag waitDoneFromNext(ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0);
+ WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, 0);
PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0);
PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, BROADCAST_BUFCHUNKS*BROADCAST_SUBSTEPS);
@@ -39,14 +39,15 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) {
const int rank = ring->devUserRanks[0];
const int nextRank = ring->devUserRanks[1];
const int root = args->root;
+ uint32_t shouldExit = 0;
if (tid == 0) {
// Update in case we skipped some collectives
*ring->recv.conn.opCount = args->opCount;
if (nextRank != root) {
// Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
+ WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0);
+ waitOpCountNext.wait(&shouldExit, args->opCount);
}
if (rank != root && prevdirect) {
*ring->recv.conn.ptrExchange = args->ThisOutput;
@@ -58,7 +59,7 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) {
*ptr = nullptr;
}
}
- __syncthreads();
+ exitIfAbortBarrier(shouldExit);
uint64_t step = 0ULL;
int boffset = 0;
@@ -129,13 +130,14 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) {
if (tid == 0) {
if (nextRank != root) {
// Wait for next to have consumed data before resetting the flag
- waitDoneFromNext.wait(BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1));
+ waitDoneFromNext.wait(&shouldExit, BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1));
*ring->send.conn.head = 0ULL;
}
*ring->recv.conn.tail = 0ULL;
__threadfence_system();
*ring->recv.conn.opCount = args->opCount+1;
}
+ exitIfAbortBarrier(shouldExit);
}
#include "ll_kernel.h"
@@ -188,11 +190,13 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
if (thisInput == thisOutput) {
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
nextOutput + boffset,
maxOffset, flag, llNthreads);
} else {
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
thisOutput + offset,
nextOutput + boffset,
@@ -202,6 +206,7 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) {
NEXT_STEP_LL;
} else if (nextRank == root) {
LL::ReduceCopy(
+ comm->abortFlag,
prevInput + boffset,
thisOutput + offset,
maxOffset, flag, llNthreads);
@@ -210,6 +215,7 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) {
} else {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
prevInput + boffset,
thisOutput + offset,
nextOutput + boffset,
diff --git a/src/collectives/device/ll_kernel.h b/src/collectives/device/ll_kernel.h
index 5ec3c9a..f8f70ed 100644
--- a/src/collectives/device/ll_kernel.h
+++ b/src/collectives/device/ll_kernel.h
@@ -7,10 +7,32 @@
#ifndef NCCL_LL_KERNEL_H_
#define NCCL_LL_KERNEL_H_
-static __device__ uint64_t readLL(union ncclLLFifoLine* src, uint32_t flag) {
+#define LL_SPINS_BEFORE_CHECK_ABORT 100000
+
+// Each thread sets a predicate to true if val == 1
+// nthreads threads enter the barrier and do a popc on their predicates being True
+// If any of the thread's predicate was True, all the threads call exit()
+inline __device__ void exitIfAbortBarrier(uint32_t abortFlag, int nthreads) {
+ uint32_t popc;
+ asm ("{");
+ asm volatile (" .reg .pred barr_pred;");
+ asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abortFlag));
+ asm volatile (" bar.red.popc.u32 %0, 14, %1, barr_pred;" : "=r"(popc) : "r"(nthreads));
+ asm ("}");
+ if (popc) { asm volatile ("exit;"); }
+}
+
+static __device__ uint64_t readLL(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, uint32_t flag) {
uint32_t data1, flag1, data2, flag2;
+ size_t spins = 0;
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
+ if (++spins == LL_SPINS_BEFORE_CHECK_ABORT) {
+ if (*abortFlag != 0) {
+ asm volatile ("exit;");
+ }
+ spins = 0;
+ }
} while ((flag1 != flag) || (flag2 != flag));
uint64_t val64 = data1 + (((uint64_t)data2) << 32);
return val64;
@@ -34,7 +56,7 @@ template <typename T, class FUNC>
class LLPrimitives {
private:
template <int HAS_SRC1, int HAS_SRC2, int HAS_DST1, int HAS_DST2>
- static __device__ void ReduceCopyGeneric(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
+ static __device__ void ReduceCopyGeneric(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
if (size <= 0) return;
size_t size64 = size * sizeof(T) / sizeof(uint64_t);
uint64_t* src1A = (uint64_t*)src1;
@@ -46,9 +68,9 @@ class LLPrimitives {
uint64_t val;
if (HAS_SRC1) {
val = readAL(src1A+offset);
- if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(src2+offset, iflag), val);
+ if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(abortFlag, src2+offset, iflag), val);
} else if (HAS_SRC2) {
- val = readLL(src2+offset, iflag);
+ val = readLL(abortFlag, src2+offset, iflag);
}
if (HAS_DST1) storeAL(dst1A+offset, val);
if (HAS_DST2) storeLL(dst2+offset, val, oflag);
@@ -64,7 +86,7 @@ class LLPrimitives {
T* vals = (T*)&lastVal;
if (HAS_SRC2) {
- uint64_t lastVal2 = readLL(src2+size64, iflag);
+ uint64_t lastVal2 = readLL(abortFlag, src2+size64, iflag);
T* src2B = (T*)&lastVal2;
for (int offset = 0; offset < sizeRem; offset++) {
vals[offset] = HAS_SRC1 ? FUNC()(src2B[offset], src1B[offset]) : src2B[offset];
@@ -83,32 +105,32 @@ class LLPrimitives {
}
}
public:
- static __device__ void ReduceCopy(const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 0, 0, 1>(src, NULL, NULL, dst, size, 0, oflag, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) {
+ return ReduceCopyGeneric<1, 0, 0, 1>(abortFlag, src, NULL, NULL, dst, size, 0, oflag, nthreads);
}
- static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) {
- return ReduceCopyGeneric<0, 1, 1, 0>(NULL, src, dst, NULL, size, iflag, 0, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) {
+ return ReduceCopyGeneric<0, 1, 1, 0>(abortFlag, NULL, src, dst, NULL, size, iflag, 0, nthreads);
}
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 0, 1>(src1, src2, NULL, dst, size, iflag, oflag, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
+ return ReduceCopyGeneric<1, 1, 0, 1>(abortFlag, src1, src2, NULL, dst, size, iflag, oflag, nthreads);
}
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 1, 0>(src1, src2, dst, NULL, size, iflag, 0, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) {
+ return ReduceCopyGeneric<1, 1, 1, 0>(abortFlag, src1, src2, dst, NULL, size, iflag, 0, nthreads);
}
- static __device__ void ReduceCopy(const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 0, 1, 1>(src, NULL, dst1, dst2, size, 0, oflag, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) {
+ return ReduceCopyGeneric<1, 0, 1, 1>(abortFlag, src, NULL, dst1, dst2, size, 0, oflag, nthreads);
}
- static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<0, 1, 1, 1>(NULL, src, dst1, dst2, size, iflag, oflag, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
+ return ReduceCopyGeneric<0, 1, 1, 1>(abortFlag, NULL, src, dst1, dst2, size, iflag, oflag, nthreads);
}
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 1, 1>(src1, src2, dst1, dst2, size, iflag, oflag, nthreads);
+ static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
+ return ReduceCopyGeneric<1, 1, 1, 1>(abortFlag, src1, src2, dst1, dst2, size, iflag, oflag, nthreads);
}
};
@@ -118,12 +140,22 @@ class LLPrimitives {
(step % NCCL_LL_CHUNKS)
#define WAIT_NEXT \
- if (tid == 0) { \
+ { \
+ uint32_t abortFlag = 0; \
+ size_t spins = 0; \
while (sendHead + NCCL_LL_CHUNKS <= step) { \
sendHead = sendHeadPtr[0]; \
+ ++spins; \
+ if (spins == LL_SPINS_BEFORE_CHECK_ABORT) { \
+ abortFlag = *args->comm->abortFlag; \
+ if (abortFlag != 0) { \
+ break; \
+ } \
+ spins = 0; \
+ } \
} \
- } \
- asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads));
+ exitIfAbortBarrier(abortFlag, llNthreads); \
+ }
#define POST_SIZE \
if (tid == 0 && sizesFifo) sizesFifo[step % NCCL_LL_CHUNKS] = (maxOffset <= 0) ? -1 : (maxOffset*2*(int)sizeof(T));
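
The LL path checks the abort flag in two ways: readLL() lets the spinning thread
exit immediately once it sees the flag, while WAIT_NEXT takes a block-wide vote
across llNthreads through exitIfAbortBarrier(). A minimal sketch of the first
pattern (pollLine and the line layout are illustrative stand-ins, not the
commit's code):

    #include <cstddef>
    #include <cstdint>

    #define LL_SPINS_BEFORE_CHECK_ABORT 100000

    // Illustrative stand-in for readLL's loop: spin on the flag word of an
    // LL fifo line and, every LL_SPINS_BEFORE_CHECK_ABORT iterations, read
    // the host-mapped abort flag; if it is set, terminate this thread.
    __device__ uint32_t pollLine(volatile uint32_t* line, uint32_t flag,
                                 volatile uint32_t* abortFlag) {
      size_t spins = 0;
      while (line[1] != flag) {               // line[0] data, line[1] flag
        if (++spins == LL_SPINS_BEFORE_CHECK_ABORT) {
          if (*abortFlag != 0) asm volatile("exit;");  // abort: kill thread
          spins = 0;
        }
      }
      return line[0];                 // data is valid once the flag matches
    }
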
diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h
index 8df152e..ae9fab2 100644
--- a/src/collectives/device/primitives.h
+++ b/src/collectives/device/primitives.h
@@ -24,17 +24,39 @@
* corresponding substep by previous step) before executing the transfer.
* After each substep is transfered, all PostFlag arguments get updated to
* the value SUBSTEPS*step+substep+1.
+ *
+ * The wait operation will read the abortFlag after spinning for
+ * SPINS_BEFORE_CHECK_ABORT iterations. If it reads the abort flag and the
+ * flag is non-zero, it will leave without waiting for the wait condition
+ * to be achieved. The output abortFlag argument will be updated in case
+ * the abort flag is read.
*/
+#define SPINS_BEFORE_CHECK_ABORT 100000
class WaitFlag {
+ volatile uint32_t * abortFlag;
volatile uint64_t * const flag;
const int shift;
public:
__device__ __forceinline__
- WaitFlag(volatile uint64_t * const flag, const int shift) : flag(flag), shift(shift) { }
+ WaitFlag(volatile uint32_t *abortFlag, volatile uint64_t * const flag, const int shift)
+ : abortFlag(abortFlag), flag(flag), shift(shift) { }
__device__ __forceinline__
- void wait(uint64_t val) { while ((*flag + shift) < val) /*SPIN*/; }
+ void wait(uint32_t *outAbortFlag, uint64_t val) {
+ size_t spins = 0;
+ while ((*flag + shift) < val) {
+ /*SPIN*/
+ ++spins;
+ if (spins == SPINS_BEFORE_CHECK_ABORT) {
+ *outAbortFlag = *abortFlag;
+ if (*outAbortFlag != 0) {
+ return;
+ }
+ spins = 0;
+ }
+ }
+ }
};
@@ -67,17 +89,17 @@ bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
// Wait on all WaitFlags, ignore PostFlags
__device__ __forceinline__
-void WaitOnFlags(uint64_t val) { }
+void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val) { }
template <typename... TAIL_Ts> __device__ __forceinline__
-void WaitOnFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
- flag.wait(val);
- WaitOnFlags(val, tail...);
+void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
+ flag.wait(outAbortFlag, val);
+ WaitOnFlags(outAbortFlag, val, tail...);
}
template <typename... TAIL_Ts> __device__ __forceinline__
-void WaitOnFlags(uint64_t val, PostFlag, TAIL_Ts... tail) {
- WaitOnFlags(val, tail...);
+void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val, PostFlag, TAIL_Ts... tail) {
+ WaitOnFlags(outAbortFlag, val, tail...);
}
@@ -125,6 +147,20 @@ nullptr_t ptradd(nullptr_t ptr, int i) {
}
+// Each thread sets a predicate to true if val == 1
+// all CTA's threads enter the barrier and do a popc on their predicates being True
+// If any of the thread's predicate was True, all the threads call exit()
+static inline __device__
+void exitIfAbortBarrier(uint32_t val) {
+ uint32_t popc;
+ asm ("{");
+ asm volatile (" .reg .pred barr_pred;");
+ asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(val));
+ asm volatile (" bar.red.popc.u32 %0, 14, barr_pred;" : "=r"(popc));
+ asm ("}");
+ if (popc) { asm volatile ("exit;"); }
+}
+
// Implementation of primitive types
template <int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
class Primitives {
@@ -140,6 +176,8 @@ class Primitives {
DST2_T dst2,
int len, int maxoffset, uint64_t step, SYNC_Ts... flags) {
+ uint32_t abort = 0;
+
enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value };
enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value };
static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
@@ -158,7 +196,7 @@ class Primitives {
if (tid < nthreads) {
if (AnyAre<WaitFlag>(flags...)) {
if (tid == 0) {
- WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
+ WaitOnFlags(&abort, SUBSTEPS*step + sub + 1, flags...);
}
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
}
@@ -178,12 +216,10 @@ class Primitives {
ptradd(src2, sliceOffset),
realSize
);
- if (AnyAre<PostFlag>(flags...)) {
- __syncthreads();
- }
+ exitIfAbortBarrier(abort);
} else {
+ exitIfAbortBarrier(abort);
if (AnyAre<PostFlag>(flags...)) {
- __syncthreads();
PostSizeToFlags(SUBSTEPS*step+sub, realSize*sizeof(T), flags...);
__threadfence_system();
PostToFlags(SUBSTEPS*step + sub + 1, flags...);
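
This is why the plain __syncthreads() calls become exitIfAbortBarrier():
WaitFlag::wait() can now return without its condition being met, and only
thread 0 runs WaitOnFlags(), so the abort it observed has to reach the whole
block at the next barrier; the popc reduction does barrier and broadcast in one
step. An equivalent sketch in plain CUDA C++, using the built-in
__syncthreads_or() instead of the named-barrier PTX (so it votes across the
whole block rather than the first nthreads; waitThenWork is hypothetical):

    #include <cstdint>

    __global__ void waitThenWork(volatile uint64_t* flag, uint64_t val,
                                 volatile uint32_t* abortFlag) {
      uint32_t shouldExit = 0;
      if (threadIdx.x == 0) {
        size_t spins = 0;
        while (*flag < val) {              // WaitFlag::wait(), simplified
          if (++spins == 100000) {         // SPINS_BEFORE_CHECK_ABORT
            shouldExit = *abortFlag;       // amortized read of mapped memory
            if (shouldExit) break;         // give up without the condition
            spins = 0;
          }
        }
      }
      // Barrier plus vote: if any thread (here, thread 0) saw the abort
      // flag, every thread returns together instead of hanging at a
      // plain barrier.
      if (__syncthreads_or(shouldExit)) return;
      // ... the substep's copy/reduce work would follow ...
    }
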
diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h
index f5694b1..f815104 100644
--- a/src/collectives/device/reduce.h
+++ b/src/collectives/device/reduce.h
@@ -22,8 +22,8 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) {
struct ncclComm* comm = args->comm;
struct ncclRing* ring = comm->rings+blockIdx.x;
- WaitFlag waitDoneFromNext(ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0);
+ WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, 0);
PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0);
PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCE_BUFCHUNKS*REDUCE_SUBSTEPS);
@@ -37,6 +37,7 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) {
const int rank = ring->devUserRanks[0];
const int prevRank = ring->devUserRanks[nranks-1];
const int root = args->root;
+ uint32_t shouldExit = 0;
if (tid == 0) {
// Update in case we skipped some collectives
@@ -44,11 +45,11 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) {
if (rank != root) {
// Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
+ WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0);
+ waitOpCountNext.wait(&shouldExit, args->opCount);
}
}
- __syncthreads();
+ exitIfAbortBarrier(shouldExit);
uint64_t step = 0ULL;
int boffset = 0;
@@ -97,13 +98,14 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) {
if (tid == 0) {
if (rank != root) {
// Wait for next to have consumed data before resetting the flag
- waitDoneFromNext.wait(REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1));
+ waitDoneFromNext.wait(&shouldExit, REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1));
*ring->send.conn.head = 0ULL;
}
*ring->recv.conn.tail = 0ULL;
__threadfence_system();
*ring->recv.conn.opCount = args->opCount+1;
}
+ exitIfAbortBarrier(shouldExit);
}
#include "ll_kernel.h"
@@ -156,6 +158,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) {
if (prevRank == root) {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
nextOutput + boffset,
maxOffset, flag, llNthreads);
@@ -163,6 +166,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) {
NEXT_STEP_LL;
} else if (rank == root) {
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
prevInput + boffset,
thisOutput + offset,
@@ -172,6 +176,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) {
} else {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
prevInput + boffset,
nextOutput + boffset,
diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h
index cad011b..842cf76 100644
--- a/src/collectives/device/reduce_scatter.h
+++ b/src/collectives/device/reduce_scatter.h
@@ -23,8 +23,8 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) {
struct ncclComm* comm = args->comm;
struct ncclRing* ring = comm->rings+blockIdx.x;
- WaitFlag waitDoneFromNext(ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS);
+ WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS);
PostFlag postDoneToPrev(ring->recv.conn.head, REDUCESCATTER_SUBSTEPS, NULL, 0);
PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS);
@@ -35,15 +35,16 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) {
const int buffSize = ring->buffSize / sizeof(T);
const int sliceSize = buffSize / REDUCESCATTER_BUFCHUNKS;
const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
+ uint32_t shouldExit = 0;
if (tid == 0) {
// Update in case we skipped some collectives
*ring->recv.conn.opCount = args->opCount;
// Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
+ WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0);
+ waitOpCountNext.wait(&shouldExit, args->opCount);
}
- __syncthreads();
+ exitIfAbortBarrier(shouldExit);
uint64_t step = 0ULL;
int poffset, noffset = 0;
@@ -111,12 +112,13 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) {
}
if (tid == 0) {
- waitDoneFromNext.wait(REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS));
+ waitDoneFromNext.wait(&shouldExit, REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS));
*ring->send.conn.head = 0ULL;
*ring->recv.conn.tail = 0ULL;
__threadfence_system();
*ring->recv.conn.opCount = args->opCount+1;
}
+ exitIfAbortBarrier(shouldExit);
}
#include "ll_kernel.h"
@@ -176,6 +178,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
nextOutput + noffset,
maxOffset, nflag, llNthreads);
@@ -190,6 +193,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
WAIT_NEXT;
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
prevInput + poffset,
nextOutput + noffset,
@@ -206,6 +210,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
offset = chunkOffset + rankDest * size;
LL::ReduceCopy(
+ comm->abortFlag,
thisInput + offset,
prevInput + poffset,
thisOutput + chunkOffset,
diff --git a/src/include/core.h b/src/include/core.h
index 66b353c..20762c8 100644
--- a/src/include/core.h
+++ b/src/include/core.h
@@ -226,6 +226,14 @@ struct ncclComm {
int groupCudaStream;
cudaStream_t groupStream;
+ // Whether there has been a fatal error on the communicator
+ ncclResult_t fatalError;
+
+ // Whether the operations enqueued on this communicator should abort
+ // On host: this pointer has been obtained from cudaHostAlloc(cudaHostAllocMapped)
+ // On device: this pointer has been obtained from cudaHostGetDevicePointer()
+ volatile uint32_t *abortFlag;
+
// Device copy of the communicator
struct ncclComm *devComm;
diff --git a/src/init.cu b/src/init.cu
index 6669251..77807ca 100644
--- a/src/init.cu
+++ b/src/init.cu
@@ -136,6 +136,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
free(comm->intraCGMode);
free(comm->intraCC);
}
+ CUDACHECK(cudaFreeHost((void *)comm->abortFlag));
free(comm);
return ncclSuccess;
@@ -173,6 +174,10 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
// Don't allow the user to overload the default setting in older CUDA builds
comm->groupCudaStream = NCCL_GROUP_CUDA_STREAM;
#endif
+ comm->fatalError = ncclSuccess;
+
+ CUDACHECK(cudaHostAlloc((void**) &comm->abortFlag, sizeof(uint32_t), cudaHostAllocMapped));
+ *comm->abortFlag = 0;
comm->argsptr = &comm->args;
@@ -189,6 +194,10 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
for (int r=0; r<comm->nRings; r++) {
NCCLCHECK(ncclCudaMemcpy(comm->rings[r].devUserRanks, comm->rings[r].userRanks, comm->nRanks));
}
+ // Copy the device-accessible pointer to comm->abortFlag
+ void *devAbortFlag;
+ CUDACHECK(cudaHostGetDevicePointer(&devAbortFlag, (uint32_t *)comm->abortFlag, 0));
+ CUDACHECK(cudaMemcpy(&comm->devComm->abortFlag, &devAbortFlag, sizeof(uint32_t *), cudaMemcpyHostToDevice));
return ncclSuccess;
}
@@ -769,6 +778,10 @@ ncclResult_t ncclCommDestroy(ncclComm_t comm) {
CUDACHECK(cudaSetDevice(commDevice));
}
+ // Ask anything that might still be running on the device to quit
+ *comm->abortFlag = 1;
+ CUDACHECK(cudaStreamSynchronize(comm->groupStream));
+
NCCLCHECK(commFree(comm));
if (savedDevice != commDevice)
@@ -790,6 +803,14 @@ const char* ncclGetErrorString(ncclResult_t code) {
}
}
+NCCL_API(ncclResult_t, ncclCommGetAsyncError, ncclComm_t comm, ncclResult_t *asyncError);
+ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) {
+ NCCLCHECK(PtrCheck(comm, "ncclGetAsyncError", "comm"));
+ NCCLCHECK(PtrCheck(asyncError, "ncclGetAsyncError", "asyncError"));
+ *asyncError = comm->fatalError;
+ return ncclSuccess;
+}
+
NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
NCCLCHECK(PtrCheck(comm, "CommCount", "comm"));
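
The flag allocated in commAlloc() is mapped (zero-copy) host memory: the host
writes through the cudaHostAlloc() pointer, and kernels poll the alias returned
by cudaHostGetDevicePointer(). A self-contained toy showing that handshake
(spinUntilAbort is illustrative, not NCCL code):

    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void spinUntilAbort(volatile uint32_t* abortFlag) {
      while (*abortFlag == 0) { /* spin, reading host memory */ }
    }

    int main() {
      volatile uint32_t* hostFlag;
      uint32_t* devFlag;
      // Same allocation pattern as commAlloc()/devCommSetup() above.
      cudaHostAlloc((void**)&hostFlag, sizeof(uint32_t), cudaHostAllocMapped);
      *hostFlag = 0;
      cudaHostGetDevicePointer((void**)&devFlag, (void*)hostFlag, 0);

      spinUntilAbort<<<1, 1>>>(devFlag);
      *hostFlag = 1;                    // host-side "abort"
      cudaDeviceSynchronize();          // returns because the kernel exited
      printf("kernel observed the flag and returned\n");
      cudaFreeHost((void*)hostFlag);
      return 0;
    }
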
diff --git a/src/misc/enqueue.cu b/src/misc/enqueue.cu
index dc5d9cc..57d1808 100644
--- a/src/misc/enqueue.cu
+++ b/src/misc/enqueue.cu
@@ -219,6 +219,9 @@ ncclResult_t ncclEnqueueCheck(ncclFunc_t func, const char* primName, const void*
void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root,
ncclComm_t comm, cudaStream_t stream) {
if (comm == NULL) return ncclInvalidArgument;
+ if (comm->fatalError != ncclSuccess) {
+ return ncclInvalidUsage;
+ }
// Launch asynchronously if needed
if (ncclAsyncMode()) {
ncclResult_t ret = ncclSuccess;
diff --git a/src/misc/group.cu b/src/misc/group.cu
index c7b31cf..c8e2d6d 100644
--- a/src/misc/group.cu
+++ b/src/misc/group.cu
@@ -124,6 +124,18 @@ ncclResult_t ncclGroupEnd() {
ncclResult_t ret = ncclGroupError;
if (ret != ncclSuccess) goto group_cleanup;
+ // check if any of the communicators used in this group has
+ // encountered a fatal error.
+ for (int i=0; i<ncclGroupIndex; i++) {
+ struct ncclAsyncArgs* args = ncclGroupArgs+i;
+ if (args->funcType == ASYNC_FUNC_COLL) {
+ if (args->coll.comm->fatalError != ncclSuccess) {
+ ret = ncclInvalidUsage;
+ goto group_cleanup;
+ }
+ }
+ }
+
/* Collectives are done in three steps :
* 1. Barrier Check In. Only the last call may call cudaLaunchKernel[cooperative]
* 2. Barrier Wait. No CUDA call is permitted
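
With that check, a grouped launch fails fast rather than enqueueing onto a
broken communicator. A sketch of the caller-visible behavior (groupedAllReduce
and its arguments are hypothetical):

    #include <nccl.h>

    // Hypothetical helper: grouped all-reduce across nDev devices. After
    // this change, ncclGroupEnd() returns ncclInvalidUsage up front if any
    // communicator has already recorded a fatal error; nothing is launched.
    ncclResult_t groupedAllReduce(float** sendbuf, float** recvbuf,
                                  size_t count, ncclComm_t* comms,
                                  cudaStream_t* streams, int nDev) {
      ncclGroupStart();
      for (int i = 0; i < nDev; ++i) {
        ncclAllReduce(sendbuf[i], recvbuf[i], count, ncclFloat, ncclSum,
                      comms[i], streams[i]);
      }
      return ncclGroupEnd();
    }
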
diff --git a/src/nccl.h.in b/src/nccl.h.in
index 7227625..26e9969 100644
--- a/src/nccl.h.in
+++ b/src/nccl.h.in
@@ -68,7 +68,8 @@ ncclResult_t pncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId
ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist);
ncclResult_t pncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist);
-/* Frees resources associated with communicator object. */
+/* Frees resources associated with communicator object and aborts any operations
+ * that might still be running on the device. */
ncclResult_t ncclCommDestroy(ncclComm_t comm);
ncclResult_t pncclCommDestroy(ncclComm_t comm);
@@ -76,6 +77,10 @@ ncclResult_t pncclCommDestroy(ncclComm_t comm);
const char* ncclGetErrorString(ncclResult_t result);
const char* pncclGetErrorString(ncclResult_t result);
+/* Checks whether the communicator has encountered any asynchronous errors. */
+ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError);
+ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError);
+
/* Gets the number of ranks in the communicator clique. */
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count);
ncclResult_t pncclCommCount(const ncclComm_t comm, int* count);
diff --git a/src/transport.cu b/src/transport.cu
index f5f9d75..7160d04 100644
--- a/src/transport.cu
+++ b/src/transport.cu
@@ -150,6 +150,7 @@ void* persistentThread(void *opaqueInfo) {
}
ncclResult_t res = info->func(&args);
if (res != ncclSuccess) {
+ info->comm->fatalError = res;
WARN("%s:%d -> %d [Proxy thread error]", __FILE__, __LINE__, res);
}
}