From 6f4f4e92688bb9a32327b1be031bf1b87a3bc429 Mon Sep 17 00:00:00 2001 From: Wojtek Wasko Date: Tue, 9 Oct 2018 02:02:01 -0700 Subject: Abort mechanism and API for querying async errors. Change-Id: If1f8fadc719b136788609a10416658f3ef76cf35 --- src/collectives/device/all_gather.h | 18 +++++--- src/collectives/device/all_reduce.h | 19 ++++++--- src/collectives/device/broadcast.h | 18 +++++--- src/collectives/device/ll_kernel.h | 76 +++++++++++++++++++++++---------- src/collectives/device/primitives.h | 62 +++++++++++++++++++++------ src/collectives/device/reduce.h | 17 +++++--- src/collectives/device/reduce_scatter.h | 17 +++++--- src/include/core.h | 8 ++++ src/init.cu | 21 +++++++++ src/misc/enqueue.cu | 3 ++ src/misc/group.cu | 12 ++++++ src/nccl.h.in | 7 ++- src/transport.cu | 1 + 13 files changed, 213 insertions(+), 66 deletions(-) diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index a30e575..cef6097 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -26,8 +26,8 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) { int prevdirect = ring->recv.conn.direct; int nextdirect = ring->send.conn.direct; - WaitFlag waitDoneFromNext(ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS); - WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLGATHER_SUBSTEPS); + WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS); + WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, ALLGATHER_SUBSTEPS); PostFlag postDoneToPrev(ring->recv.conn.head, ALLGATHER_SUBSTEPS, NULL, 0); PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS); @@ -38,13 +38,14 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) { const int buffSize = ring->buffSize / sizeof(T); const int sliceSize = buffSize / ALLGATHER_BUFCHUNKS; const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + uint32_t shouldExit = 0; if (tid == 0) { // Update in case we skipped some collectives *ring->recv.conn.opCount = args->opCount; // Wait for next to be ready - WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); - waitOpCountNext.wait(args->opCount); + WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0); + waitOpCountNext.wait(&shouldExit, args->opCount); if (prevdirect) { *ring->recv.conn.ptrExchange = args->ThisOutput; } @@ -55,7 +56,7 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) { *ptr = nullptr; } } - __syncthreads(); + exitIfAbortBarrier(shouldExit); uint64_t step = 0ULL; int poffset, noffset = 0; @@ -157,12 +158,13 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) { } if (tid == 0) { - waitDoneFromNext.wait(ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS)); + waitDoneFromNext.wait(&shouldExit, ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS)); *ring->send.conn.head = 0ULL; *ring->recv.conn.tail = 0ULL; __threadfence_system(); *ring->recv.conn.opCount = args->opCount+1; } + exitIfAbortBarrier(shouldExit); } #include "ll_kernel.h" @@ -223,11 +225,13 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; if (thisInput + chunkOffset == thisOutput + offset) { // In place LL::ReduceCopy( + comm->abortFlag, thisInput + chunkOffset, nextOutput + noffset, maxOffset, nflag, llNthreads); } else { LL::ReduceCopy( + comm->abortFlag, thisInput + chunkOffset, thisOutput + offset, nextOutput + noffset, @@ -244,6 +248,7 @@ __device__ void 
ncclAllGatherLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, prevInput + poffset, thisOutput + offset, nextOutput + noffset, @@ -259,6 +264,7 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) { offset = chunkOffset + rankDest * size; LL::ReduceCopy( + comm->abortFlag, prevInput + poffset, thisOutput + offset, maxOffset, pflag, llNthreads); diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index d7abc64..6d6e2a0 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -26,8 +26,8 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) { int prevdirect = ring->recv.conn.direct; int nextdirect = ring->send.conn.direct; - WaitFlag waitDoneFromNext(ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS); - WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLREDUCE_SUBSTEPS); + WaitFlag waitDoneFromNext(args->comm->abortFlag, ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS); + WaitFlag waitReadyFromPrev(args->comm->abortFlag, ring->recv.conn.tail, ALLREDUCE_SUBSTEPS); PostFlag postDoneToPrev(ring->recv.conn.head, ALLREDUCE_SUBSTEPS, NULL, 0); PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS); @@ -39,13 +39,14 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) { const int buffSize = ring->buffSize / sizeof(T); const int sliceSize = buffSize / ALLREDUCE_BUFCHUNKS; const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + uint32_t shouldExit = 0; if (tid == 0) { // Update in case we skipped some collectives *ring->recv.conn.opCount = args->opCount; // Wait for next to be ready - WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); - waitOpCountNext.wait(args->opCount); + WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0); + waitOpCountNext.wait(&shouldExit, args->opCount); if (prevdirect) { *ring->recv.conn.ptrExchange = args->ThisOutput; } @@ -56,7 +57,7 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) { *ptr = nullptr; } } - __syncthreads(); + exitIfAbortBarrier(shouldExit); uint64_t step = 0ULL; int poffset, noffset = 0; @@ -188,12 +189,13 @@ __device__ void ncclAllReduceKernel(struct CollectiveArgs* args) { if (tid == 0) { // Wait for next to have consumed all data before we reset the flag - waitDoneFromNext.wait(ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS)); + waitDoneFromNext.wait(&shouldExit, ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS)); *ring->send.conn.head = 0ULL; *ring->recv.conn.tail = 0ULL; __threadfence_system(); *ring->recv.conn.opCount = args->opCount+1; } + exitIfAbortBarrier(shouldExit); } #include "ll_kernel.h" @@ -254,6 +256,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + args->comm->abortFlag, thisInput + offset, nextOutput + noffset, maxOffset, nflag, llNthreads); @@ -269,6 +272,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + args->comm->abortFlag, thisInput + offset, prevInput + poffset, nextOutput + noffset, @@ -287,6 +291,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + args->comm->abortFlag, thisInput + offset, prevInput + poffset, thisOutput + offset, @@ -305,6 +310,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + args->comm->abortFlag, prevInput + poffset, thisOutput + 
offset, nextOutput + noffset, @@ -322,6 +328,7 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { // Here we need to copy from buffer to this output. LL::ReduceCopy( + args->comm->abortFlag, prevInput + poffset, thisOutput + offset, maxOffset, pflag, llNthreads); diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index c2f6d00..8f89e95 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -25,8 +25,8 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) { int prevdirect = ring->recv.conn.direct; int nextdirect = ring->send.conn.direct; - WaitFlag waitDoneFromNext(ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS); - WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0); + WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS); + WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, 0); PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0); PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, BROADCAST_BUFCHUNKS*BROADCAST_SUBSTEPS); @@ -39,14 +39,15 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) { const int rank = ring->devUserRanks[0]; const int nextRank = ring->devUserRanks[1]; const int root = args->root; + uint32_t shouldExit = 0; if (tid == 0) { // Update in case we skipped some collectives *ring->recv.conn.opCount = args->opCount; if (nextRank != root) { // Wait for next to be ready - WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); - waitOpCountNext.wait(args->opCount); + WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0); + waitOpCountNext.wait(&shouldExit, args->opCount); } if (rank != root && prevdirect) { *ring->recv.conn.ptrExchange = args->ThisOutput; @@ -58,7 +59,7 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) { *ptr = nullptr; } } - __syncthreads(); + exitIfAbortBarrier(shouldExit); uint64_t step = 0ULL; int boffset = 0; @@ -129,13 +130,14 @@ __device__ void ncclBroadcastKernel(struct CollectiveArgs* args) { if (tid == 0) { if (nextRank != root) { // Wait for next to have consumed data before resetting the flag - waitDoneFromNext.wait(BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1)); + waitDoneFromNext.wait(&shouldExit, BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1)); *ring->send.conn.head = 0ULL; } *ring->recv.conn.tail = 0ULL; __threadfence_system(); *ring->recv.conn.opCount = args->opCount+1; } + exitIfAbortBarrier(shouldExit); } #include "ll_kernel.h" @@ -188,11 +190,13 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; if (thisInput == thisOutput) { LL::ReduceCopy( + comm->abortFlag, thisInput + offset, nextOutput + boffset, maxOffset, flag, llNthreads); } else { LL::ReduceCopy( + comm->abortFlag, thisInput + offset, thisOutput + offset, nextOutput + boffset, @@ -202,6 +206,7 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) { NEXT_STEP_LL; } else if (nextRank == root) { LL::ReduceCopy( + comm->abortFlag, prevInput + boffset, thisOutput + offset, maxOffset, flag, llNthreads); @@ -210,6 +215,7 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) { } else { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, prevInput + boffset, thisOutput + offset, nextOutput + boffset, diff --git a/src/collectives/device/ll_kernel.h b/src/collectives/device/ll_kernel.h index 5ec3c9a..f8f70ed 100644 --- a/src/collectives/device/ll_kernel.h +++ 
b/src/collectives/device/ll_kernel.h @@ -7,10 +7,32 @@ #ifndef NCCL_LL_KERNEL_H_ #define NCCL_LL_KERNEL_H_ -static __device__ uint64_t readLL(union ncclLLFifoLine* src, uint32_t flag) { +#define LL_SPINS_BEFORE_CHECK_ABORT 100000 + +// Each thread sets a predicate to true if abortFlag == 1 +// nthreads threads enter the barrier and do a popc on their predicates being True +// If any thread's predicate was True, all the threads call exit() +inline __device__ void exitIfAbortBarrier(uint32_t abortFlag, int nthreads) { + uint32_t popc; + asm ("{"); + asm volatile (" .reg .pred barr_pred;"); + asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abortFlag)); + asm volatile (" bar.red.popc.u32 %0, 14, %1, barr_pred;" : "=r"(popc) : "r"(nthreads)); + asm ("}"); + if (popc) { asm volatile ("exit;"); } +} + +static __device__ uint64_t readLL(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, uint32_t flag) { uint32_t data1, flag1, data2, flag2; + size_t spins = 0; do { asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); + if (++spins == LL_SPINS_BEFORE_CHECK_ABORT) { + if (*abortFlag != 0) { + asm volatile ("exit;"); + } + spins = 0; + } } while ((flag1 != flag) || (flag2 != flag)); uint64_t val64 = data1 + (((uint64_t)data2) << 32); return val64; @@ -34,7 +56,7 @@ template <typename T, class FUNC> class LLPrimitives { private: template <int HAS_SRC1, int HAS_SRC2, int HAS_DST1, int HAS_DST2> - static __device__ void ReduceCopyGeneric(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + static __device__ void ReduceCopyGeneric(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { if (size <= 0) return; size_t size64 = size * sizeof(T) / sizeof(uint64_t); uint64_t* src1A = (uint64_t*)src1; @@ -46,9 +68,9 @@ class LLPrimitives { uint64_t val; if (HAS_SRC1) { val = readAL(src1A+offset); - if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(src2+offset, iflag), val); + if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(abortFlag, src2+offset, iflag), val); } else if (HAS_SRC2) { - val = readLL(src2+offset, iflag); + val = readLL(abortFlag, src2+offset, iflag); } if (HAS_DST1) storeAL(dst1A+offset, val); if (HAS_DST2) storeLL(dst2+offset, val, oflag); @@ -64,7 +86,7 @@ class LLPrimitives { T* vals = (T*)&lastVal; if (HAS_SRC2) { - uint64_t lastVal2 = readLL(src2+size64, iflag); + uint64_t lastVal2 = readLL(abortFlag, src2+size64, iflag); T* src2B = (T*)&lastVal2; for (int offset = 0; offset < sizeRem; offset++) { vals[offset] = HAS_SRC1 ?
FUNC()(src2B[offset], src1B[offset]) : src2B[offset]; @@ -83,32 +105,32 @@ class LLPrimitives { } } public: - static __device__ void ReduceCopy(const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) { - return ReduceCopyGeneric<1, 0, 0, 1>(src, NULL, NULL, dst, size, 0, oflag, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 0, 0, 1>(abortFlag, src, NULL, NULL, dst, size, 0, oflag, nthreads); } - static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) { - return ReduceCopyGeneric<0, 1, 1, 0>(NULL, src, dst, NULL, size, iflag, 0, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) { + return ReduceCopyGeneric<0, 1, 1, 0>(abortFlag, NULL, src, dst, NULL, size, iflag, 0, nthreads); } - static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) { - return ReduceCopyGeneric<1, 1, 0, 1>(src1, src2, NULL, dst, size, iflag, oflag, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 1, 0, 1>(abortFlag, src1, src2, NULL, dst, size, iflag, oflag, nthreads); } - static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) { - return ReduceCopyGeneric<1, 1, 1, 0>(src1, src2, dst, NULL, size, iflag, 0, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) { + return ReduceCopyGeneric<1, 1, 1, 0>(abortFlag, src1, src2, dst, NULL, size, iflag, 0, nthreads); } - static __device__ void ReduceCopy(const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) { - return ReduceCopyGeneric<1, 0, 1, 1>(src, NULL, dst1, dst2, size, 0, oflag, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 0, 1, 1>(abortFlag, src, NULL, dst1, dst2, size, 0, oflag, nthreads); } - static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { - return ReduceCopyGeneric<0, 1, 1, 1>(NULL, src, dst1, dst2, size, iflag, oflag, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<0, 1, 1, 1>(abortFlag, NULL, src, dst1, dst2, size, iflag, oflag, nthreads); } - static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { - return ReduceCopyGeneric<1, 1, 1, 1>(src1, src2, dst1, dst2, size, iflag, oflag, nthreads); + static __device__ void ReduceCopy(volatile uint32_t *abortFlag, const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return 
ReduceCopyGeneric<1, 1, 1, 1>(abortFlag, src1, src2, dst1, dst2, size, iflag, oflag, nthreads); } }; @@ -118,12 +140,22 @@ class LLPrimitives { (step % NCCL_LL_CHUNKS) #define WAIT_NEXT \ - if (tid == 0) { \ + { \ + uint32_t abortFlag = 0; \ + size_t spins = 0; \ while (sendHead + NCCL_LL_CHUNKS <= step) { \ sendHead = sendHeadPtr[0]; \ + ++spins; \ + if (spins == LL_SPINS_BEFORE_CHECK_ABORT) { \ + abortFlag = *args->comm->abortFlag; \ + if (abortFlag != 0) { \ + break; \ + } \ + spins = 0; \ + } \ } \ - } \ - asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads)); + exitIfAbortBarrier(abortFlag, llNthreads); \ + } #define POST_SIZE \ if (tid == 0 && sizesFifo) sizesFifo[step % NCCL_LL_CHUNKS] = (maxOffset <= 0) ? -1 : (maxOffset*2*(int)sizeof(T)); diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 8df152e..ae9fab2 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -24,17 +24,39 @@ * corresponding substep by previous step) before executing the transfer. * After each substep is transfered, all PostFlag arguments get updated to * the value SUBSTEPS*step+substep+1. + * + * The wait operation will read the abortFlag after spinning for + * SPINS_BEFORE_CHECK_ABORT iterations. If it reads the abort flag and the + * flag is non-zero, it will return without waiting for the wait condition + * to be satisfied. The output abortFlag argument is updated whenever + * the abort flag is read. */ +#define SPINS_BEFORE_CHECK_ABORT 100000 class WaitFlag { + volatile uint32_t * abortFlag; volatile uint64_t * const flag; const int shift; public: __device__ __forceinline__ - WaitFlag(volatile uint64_t * const flag, const int shift) : flag(flag), shift(shift) { } + WaitFlag(volatile uint32_t *abortFlag, volatile uint64_t * const flag, const int shift) + : abortFlag(abortFlag), flag(flag), shift(shift) { } __device__ __forceinline__ - void wait(uint64_t val) { while ((*flag + shift) < val) /*SPIN*/; } + void wait(uint32_t *outAbortFlag, uint64_t val) { + size_t spins = 0; + while ((*flag + shift) < val) { + /*SPIN*/ + ++spins; + if (spins == SPINS_BEFORE_CHECK_ABORT) { + *outAbortFlag = *abortFlag; + if (*outAbortFlag != 0) { + return; + } + spins = 0; + } + } + } }; @@ -67,17 +89,17 @@ bool AnyAre(FIRST_T first, TAIL_Ts... tail) { // Wait on all WaitFlags, ignore PostFlags __device__ __forceinline__ -void WaitOnFlags(uint64_t val) { } +void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val) { } template <typename... TAIL_Ts> __device__ __forceinline__ -void WaitOnFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) { - flag.wait(val); - WaitOnFlags(val, tail...); +void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val, WaitFlag flag, TAIL_Ts... tail) { + flag.wait(outAbortFlag, val); + WaitOnFlags(outAbortFlag, val, tail...); } template <typename... TAIL_Ts> __device__ __forceinline__ -void WaitOnFlags(uint64_t val, PostFlag, TAIL_Ts... tail) { - WaitOnFlags(val, tail...); +void WaitOnFlags(uint32_t *outAbortFlag, uint64_t val, PostFlag, TAIL_Ts...
tail) { + WaitOnFlags(outAbortFlag, val, tail...); } @@ -125,6 +147,20 @@ nullptr_t ptradd(nullptr_t ptr, int i) { } +// Each thread sets a predicate to true if val == 1 +// all the CTA's threads enter the barrier and do a popc on their predicates being True +// If any thread's predicate was True, all the threads call exit() +static inline __device__ +void exitIfAbortBarrier(uint32_t val) { + uint32_t popc; + asm ("{"); + asm volatile (" .reg .pred barr_pred;"); + asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(val)); + asm volatile (" bar.red.popc.u32 %0, 14, barr_pred;" : "=r"(popc)); + asm ("}"); + if (popc) { asm volatile ("exit;"); } +} + // Implementation of primitive types template > class Primitives { @@ -140,6 +176,8 @@ class Primitives { DST2_T dst2, int len, int maxoffset, uint64_t step, SYNC_Ts... flags) { + uint32_t abort = 0; + enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value }; enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value }; static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value, @@ -158,7 +196,7 @@ class Primitives { if (tid < nthreads) { if (AnyAre<WaitFlag>(flags...)) { if (tid == 0) { - WaitOnFlags(SUBSTEPS*step + sub + 1, flags...); + WaitOnFlags(&abort, SUBSTEPS*step + sub + 1, flags...); } asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); } @@ -178,12 +216,10 @@ class Primitives { ptradd(src2, sliceOffset), realSize ); - if (AnyAre<PostFlag>(flags...)) { - __syncthreads(); - } + exitIfAbortBarrier(abort); } else { + exitIfAbortBarrier(abort); if (AnyAre<PostFlag>(flags...)) { - __syncthreads(); PostSizeToFlags(SUBSTEPS*step+sub, realSize*sizeof(T), flags...); __threadfence_system(); PostToFlags(SUBSTEPS*step + sub + 1, flags...); diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index f5694b1..f815104 100644 --- a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -22,8 +22,8 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) { struct ncclComm* comm = args->comm; struct ncclRing* ring = comm->rings+blockIdx.x; - WaitFlag waitDoneFromNext(ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS); - WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0); + WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS); + WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, 0); PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0); PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCE_BUFCHUNKS*REDUCE_SUBSTEPS); @@ -37,6 +37,7 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) { const int rank = ring->devUserRanks[0]; const int prevRank = ring->devUserRanks[nranks-1]; const int root = args->root; + uint32_t shouldExit = 0; if (tid == 0) { // Update in case we skipped some collectives @@ -44,11 +45,11 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) { if (rank != root) { // Wait for next to be ready - WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); - waitOpCountNext.wait(args->opCount); + WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0); + waitOpCountNext.wait(&shouldExit, args->opCount); } } - __syncthreads(); + exitIfAbortBarrier(shouldExit); uint64_t step = 0ULL; int boffset = 0; @@ -97,13 +98,14 @@ __device__ void ncclReduceKernel(struct CollectiveArgs* args) { if (tid == 0) { if (rank != root) { // Wait for next to have consumed data before resetting the flag - waitDoneFromNext.wait(REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1)); + waitDoneFromNext.wait(&shouldExit, REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1));
*ring->send.conn.head = 0ULL; } *ring->recv.conn.tail = 0ULL; __threadfence_system(); *ring->recv.conn.opCount = args->opCount+1; } + exitIfAbortBarrier(shouldExit); } #include "ll_kernel.h" @@ -156,6 +158,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) { if (prevRank == root) { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, thisInput + offset, nextOutput + boffset, maxOffset, flag, llNthreads); @@ -163,6 +166,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) { NEXT_STEP_LL; } else if (rank == root) { LL::ReduceCopy( + comm->abortFlag, thisInput + offset, prevInput + boffset, thisOutput + offset, @@ -172,6 +176,7 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) { } else { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, thisInput + offset, prevInput + boffset, nextOutput + boffset, diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index cad011b..842cf76 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -23,8 +23,8 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) { struct ncclComm* comm = args->comm; struct ncclRing* ring = comm->rings+blockIdx.x; - WaitFlag waitDoneFromNext(ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS); - WaitFlag waitReadyFromPrev(ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS); + WaitFlag waitDoneFromNext(comm->abortFlag, ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS); + WaitFlag waitReadyFromPrev(comm->abortFlag, ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS); PostFlag postDoneToPrev(ring->recv.conn.head, REDUCESCATTER_SUBSTEPS, NULL, 0); PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS); @@ -35,15 +35,16 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) { const int buffSize = ring->buffSize / sizeof(T); const int sliceSize = buffSize / REDUCESCATTER_BUFCHUNKS; const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + uint32_t shouldExit = 0; if (tid == 0) { // Update in case we skipped some collectives *ring->recv.conn.opCount = args->opCount; // Wait for next to be ready - WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); - waitOpCountNext.wait(args->opCount); + WaitFlag waitOpCountNext(comm->abortFlag, ring->send.conn.opCount, 0); + waitOpCountNext.wait(&shouldExit, args->opCount); } - __syncthreads(); + exitIfAbortBarrier(shouldExit); uint64_t step = 0ULL; int poffset, noffset = 0; @@ -111,12 +112,13 @@ __device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) { } if (tid == 0) { - waitDoneFromNext.wait(REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS)); + waitDoneFromNext.wait(&shouldExit, REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS)); *ring->send.conn.head = 0ULL; *ring->recv.conn.tail = 0ULL; __threadfence_system(); *ring->recv.conn.opCount = args->opCount+1; } + exitIfAbortBarrier(shouldExit); } #include "ll_kernel.h" @@ -176,6 +178,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, thisInput + offset, nextOutput + noffset, maxOffset, nflag, llNthreads); @@ -190,6 +193,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) { WAIT_NEXT; LL::ReduceCopy( + comm->abortFlag, thisInput + offset, prevInput + poffset, nextOutput + noffset, @@ -206,6 +210,7 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) { offset = 
chunkOffset + rankDest * size; LL::ReduceCopy( + comm->abortFlag, thisInput + offset, prevInput + poffset, thisOutput + chunkOffset, diff --git a/src/include/core.h b/src/include/core.h index 66b353c..20762c8 100644 --- a/src/include/core.h +++ b/src/include/core.h @@ -226,6 +226,14 @@ struct ncclComm { int groupCudaStream; cudaStream_t groupStream; + // Whether there has been a fatal error on the communicator + ncclResult_t fatalError; + + // Whether the operations enqueued on this communicator should abort + // On host: this pointer has been obtained from cudaHostAlloc(cudaHostAllocMapped) + // On device: this pointer has been obtained from cudaHostGetDevicePointer() + volatile uint32_t *abortFlag; + // Device copy of the communicator struct ncclComm *devComm; diff --git a/src/init.cu b/src/init.cu index 6669251..77807ca 100644 --- a/src/init.cu +++ b/src/init.cu @@ -136,6 +136,7 @@ static ncclResult_t commFree(ncclComm_t comm) { free(comm->intraCGMode); free(comm->intraCC); } + CUDACHECK(cudaFreeHost((void *)comm->abortFlag)); free(comm); return ncclSuccess; @@ -173,6 +174,10 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { // Don't allow the user to overload the default setting in older CUDA builds comm->groupCudaStream = NCCL_GROUP_CUDA_STREAM; #endif + comm->fatalError = ncclSuccess; + + CUDACHECK(cudaHostAlloc((void**) &comm->abortFlag, sizeof(uint32_t), cudaHostAllocMapped)); + *comm->abortFlag = 0; comm->argsptr = &comm->args; @@ -189,6 +194,10 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { for (int r=0; r<comm->nRings; r++) { NCCLCHECK(ncclCudaMemcpy(comm->rings[r].devUserRanks, comm->rings[r].userRanks, comm->nRanks)); } + // Copy the device-accessible pointer to comm->abortFlag + void *devAbortFlag; + CUDACHECK(cudaHostGetDevicePointer(&devAbortFlag, (uint32_t *)comm->abortFlag, 0)); + CUDACHECK(cudaMemcpy(&comm->devComm->abortFlag, &devAbortFlag, sizeof(uint32_t *), cudaMemcpyHostToDevice)); return ncclSuccess; } @@ -769,6 +778,10 @@ ncclResult_t ncclCommDestroy(ncclComm_t comm) { CUDACHECK(cudaSetDevice(commDevice)); } + // Ask anything that might still be running on the device to quit + *comm->abortFlag = 1; + CUDACHECK(cudaStreamSynchronize(comm->groupStream)); + NCCLCHECK(commFree(comm)); if (savedDevice != commDevice) @@ -790,6 +803,14 @@ const char* ncclGetErrorString(ncclResult_t code) { } } +NCCL_API(ncclResult_t, ncclCommGetAsyncError, ncclComm_t comm, ncclResult_t *asyncError); +ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) { + NCCLCHECK(PtrCheck(comm, "ncclGetAsyncError", "comm")); + NCCLCHECK(PtrCheck(asyncError, "ncclGetAsyncError", "asyncError")); + *asyncError = comm->fatalError; + return ncclSuccess; +} + NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count); ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) { NCCLCHECK(PtrCheck(comm, "CommCount", "comm")); diff --git a/src/misc/enqueue.cu b/src/misc/enqueue.cu index dc5d9cc..57d1808 100644 --- a/src/misc/enqueue.cu +++ b/src/misc/enqueue.cu @@ -219,6 +219,9 @@ ncclResult_t ncclEnqueueCheck(ncclFunc_t func, const char* primName, const void* void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) { if (comm == NULL) return ncclInvalidArgument; + if (comm->fatalError != ncclSuccess) { + return ncclInvalidUsage; + } // Launch asynchronously if needed if (ncclAsyncMode()) { ncclResult_t ret = ncclSuccess; diff --git a/src/misc/group.cu b/src/misc/group.cu index
c7b31cf..c8e2d6d 100644 --- a/src/misc/group.cu +++ b/src/misc/group.cu @@ -124,6 +124,18 @@ ncclResult_t ncclGroupEnd() { ncclResult_t ret = ncclGroupError; if (ret != ncclSuccess) goto group_cleanup; + // check if any of the communicators used in this group has + // encountered a fatal error. + for (int i=0; i<ncclGroupIndex; i++) { + struct ncclAsyncArgs* args = ncclGroupArgs+i; + if (args->funcType == ASYNC_FUNC_COLL) { + if (args->coll.comm->fatalError != ncclSuccess) { + ret = ncclInvalidUsage; + goto group_cleanup; + } + } + } + /* Collectives are done in three steps : * 1. Barrier Check In. Only the last call may call cudaLaunchKernel[cooperative] * 2. Barrier Wait. No CUDA call is permitted diff --git a/src/nccl.h.in b/src/nccl.h.in index 7227625..26e9969 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -68,7 +68,8 @@ ncclResult_t pncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); ncclResult_t pncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); -/* Frees resources associated with communicator object. */ +/* Frees resources associated with communicator object and aborts any operations + * that might still be running on the device. */ ncclResult_t ncclCommDestroy(ncclComm_t comm); ncclResult_t pncclCommDestroy(ncclComm_t comm); @@ -76,6 +77,10 @@ ncclResult_t pncclCommDestroy(ncclComm_t comm); const char* ncclGetErrorString(ncclResult_t result); const char* pncclGetErrorString(ncclResult_t result); +/* Checks whether the communicator has encountered any asynchronous errors. */ +ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); +ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); + /* Gets the number of ranks in the communicator clique. */ ncclResult_t ncclCommCount(const ncclComm_t comm, int* count); ncclResult_t pncclCommCount(const ncclComm_t comm, int* count); diff --git a/src/transport.cu b/src/transport.cu index f5f9d75..7160d04 100644 --- a/src/transport.cu +++ b/src/transport.cu @@ -150,6 +150,7 @@ void* persistentThread(void *opaqueInfo) { } ncclResult_t res = info->func(&args); if (res != ncclSuccess) { + info->comm->fatalError = res; WARN("%s:%d -> %d [Proxy thread error]", __FILE__, __LINE__, res); } } -- cgit v1.2.3
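
Note on the bar.red.popc sequence used by the exitIfAbortBarrier() variants above: only a subset of threads (often just tid == 0) actually polls the abort flag, so a thread that has observed the flag cannot simply exit on its own without leaving its peers stuck at the next bar.sync. The predicate reduction lets every participating thread learn whether any of them saw a non-zero flag, so they all leave the kernel together. A minimal portable sketch of the same idea follows; the function name is illustrative, and it uses the standard __syncthreads_or() intrinsic over the whole thread block rather than the named barrier over nthreads threads that the patch uses.

// Sketch only: every thread contributes its local abort observation; if any
// thread saw the flag, all threads of the block exit the kernel together so
// nobody is left waiting at a later barrier. Must be called by all threads.
__device__ void exitIfAbortBarrierPortable(uint32_t localAbort) {
  if (__syncthreads_or(localAbort != 0)) {
    asm volatile ("exit;");  // terminate this thread; all threads take this path
  }
}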
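
The host/device plumbing added in core.h and init.cu relies on zero-copy mapped host memory: comm->abortFlag lives in pinned host memory allocated with cudaHostAlloc(cudaHostAllocMapped), and kernels read it through the alias returned by cudaHostGetDevicePointer(), so a plain host store (as ncclCommDestroy() now does) becomes visible to a kernel that is already spinning on the device. Below is a self-contained sketch of that mechanism outside NCCL; the kernel and variable names are illustrative and error checking is omitted for brevity.

#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

// Spin until the host raises the abort flag. checkPeriod mirrors
// LL_SPINS_BEFORE_CHECK_ABORT: the mapped flag is only read every N
// iterations to keep the spin loop cheap.
__global__ void spinUntilAbort(volatile uint32_t* abortFlag, int checkPeriod) {
  size_t spins = 0;
  while (true) {                        // stands in for a wait that never completes
    if (++spins == (size_t)checkPeriod) {
      if (*abortFlag != 0) return;      // give up instead of spinning forever
      spins = 0;
    }
  }
}

int main() {
  // Mapped (zero-copy) host allocation; assumes the device supports host-mapped
  // memory, which the patch relies on as well.
  volatile uint32_t* hostFlag;
  cudaHostAlloc((void**)&hostFlag, sizeof(uint32_t), cudaHostAllocMapped);
  *hostFlag = 0;

  uint32_t* devFlag;
  cudaHostGetDevicePointer((void**)&devFlag, (void*)hostFlag, 0);

  spinUntilAbort<<<1, 1>>>(devFlag, 100000);

  // The kernel is stuck by construction; a host store to the mapped flag is
  // what lets it terminate, exactly like *comm->abortFlag = 1 in ncclCommDestroy().
  *hostFlag = 1;
  cudaDeviceSynchronize();
  printf("kernel exited after abort\n");

  cudaFreeHost((void*)hostFlag);
  return 0;
}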
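
On the API side, ncclCommGetAsyncError() exposes comm->fatalError, which is set asynchronously (for example by the proxy thread in transport.cu), and ncclCommDestroy() now raises the abort flag so that kernels stuck in the collectives above can terminate. A hedged sketch of how an application might use the new query while waiting for an outstanding collective; the polling loop and the teardown-on-error policy are illustrative choices, not something the patch mandates.

#include <sched.h>
#include <cuda_runtime.h>
#include "nccl.h"

// Illustrative only: poll the stream and the communicator's async error state
// instead of blocking in cudaStreamSynchronize(), so a failure on another rank
// can be detected and the communicator torn down. ncclCommDestroy() raises the
// abort flag, which is what allows a stuck device kernel to exit.
ncclResult_t allreduceWithWatchdog(ncclComm_t comm, cudaStream_t stream,
                                   const float* sendbuf, float* recvbuf, size_t count) {
  ncclResult_t res = ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream);
  if (res != ncclSuccess) return res;

  while (cudaStreamQuery(stream) == cudaErrorNotReady) {
    ncclResult_t asyncErr;
    res = ncclCommGetAsyncError(comm, &asyncErr);
    if (res != ncclSuccess) return res;
    if (asyncErr != ncclSuccess) {
      ncclCommDestroy(comm);   // sets *comm->abortFlag = 1, then frees the communicator
      return asyncErr;
    }
    sched_yield();             // don't burn a core while polling
  }
  return ncclSuccess;
}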