diff options
Diffstat (limited to 'src/collectives/device/prims_ll.h')
-rw-r--r-- | src/collectives/device/prims_ll.h | 46 |
1 files changed, 16 insertions, 30 deletions
diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index f919493..9e362f9 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -1,9 +1,16 @@ +/************************************************************************* + * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + template <typename T, class FUNC, int NRECV, int NSEND> class ncclLLPrimitives { private: const int tid; const int nthreads; const int wid; + const int stepLines; int nrecv = 0; int nsend = 0; struct ncclConnInfo* recvConn = NULL; @@ -22,8 +29,8 @@ class ncclLLPrimitives { union ncclLLFifoLine* sendBuff[NSEND]; struct ncclDevComm* comm; - inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } - inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } + inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; } + inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; } inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); } @@ -33,19 +40,6 @@ class ncclLLPrimitives { asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); } - uint32_t mismatch = 0; - const uint64_t opCount; - - inline __device__ void checkMismatch(struct ncclConnInfo* conn) { - if (mismatch > 20) { - // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch - // Note that we are not using _threadfence_system in LL so the error cannot be asserted - *(comm->fatalDevError) = ncclDevSuspectedMismatch; - } else if (conn && *conn->opCountRem > opCount) { - mismatch += 1; - } - } - uint32_t spins = 0; uint32_t abort = 0; @@ -53,7 +47,6 @@ class ncclLLPrimitives { spins++; if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { abort = *(comm->abortFlag); - if (wid == i) checkMismatch(send ? sendConn : recvConn); spins = 0; } return abort; @@ -61,14 +54,13 @@ class ncclLLPrimitives { inline __device__ void waitSend(int nbytes) { spins = 0; - mismatch = 0; if (sendConnHeadPtr) { while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { sendConnHeadCache = *sendConnHeadPtr; if (checkAbort(wid, 1)) break; } if (sendConnFifoPtr) { - int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes; + int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size; } sendConnHead += 1; @@ -88,7 +80,7 @@ class ncclLLPrimitives { // LL Cleanup : write all flags in the slice to make sure we don't have // data corruption when flag loops over. if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { - for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i)); + for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i)); } sendStep[i]++; } @@ -98,7 +90,6 @@ class ncclLLPrimitives { uint32_t flag = recvFlag(i); uint32_t data1, flag1, data2, flag2; spins = 0; - mismatch = 0; do { asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); if (checkAbort(i, 0)) break; @@ -164,7 +155,7 @@ class ncclLLPrimitives { } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { - recvBuff[i] = conn->llBuff; + recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; recvStep[i] = conn->step; if (wid == i) recvConn = conn; nrecv++; @@ -173,13 +164,11 @@ class ncclLLPrimitives { if (tid >= nthreads-WARP_SIZE && wid < nrecv) { recvConnHeadPtr = recvConn->head; recvConnHead = recvConn->step; - // Update opCount in case we skipped some operations - *(recvConn->opCountLoc) = opCount; } } __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { - sendBuff[i] = conn->llBuff; + sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL]; sendStep[i] = conn->step; if (wid == i) sendConn = conn; nsend++; @@ -189,15 +178,13 @@ class ncclLLPrimitives { sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; - sendConnFifoPtr = sendConn->fifo; - *(sendConn->opCountLoc) = opCount; + sendConnFifoPtr = sendConn->sizesFifo; } } __device__ __forceinline__ void saveRecvSync() { if (tid >= nthreads-WARP_SIZE && wid < nrecv) { recvConn->step = recvConnHead; - *(recvConn->opCountLoc) = opCount+1; __threadfence_block(); } } @@ -205,15 +192,14 @@ class ncclLLPrimitives { __device__ __forceinline__ void saveSendSync() { if (tid < nsend) { sendConn->step = sendConnHead; - *(sendConn->opCountLoc) = opCount+1; __threadfence_block(); } } public: __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), opCount(opCount) { + ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm) + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) { // Make sure step is updated before we read it. barrier(); |