Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/collectives/device/prims_ll.h')
-rw-r--r--src/collectives/device/prims_ll.h46
1 files changed, 16 insertions, 30 deletions
diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h
index f919493..9e362f9 100644
--- a/src/collectives/device/prims_ll.h
+++ b/src/collectives/device/prims_ll.h
@@ -1,9 +1,16 @@
+/*************************************************************************
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
template <typename T, class FUNC, int NRECV, int NSEND>
class ncclLLPrimitives {
private:
const int tid;
const int nthreads;
const int wid;
+ const int stepLines;
int nrecv = 0;
int nsend = 0;
struct ncclConnInfo* recvConn = NULL;
@@ -22,8 +29,8 @@ class ncclLLPrimitives {
union ncclLLFifoLine* sendBuff[NSEND];
struct ncclDevComm* comm;
- inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
- inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
@@ -33,19 +40,6 @@ class ncclLLPrimitives {
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
}
- uint32_t mismatch = 0;
- const uint64_t opCount;
-
- inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
- if (mismatch > 20) {
- // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
- // Note that we are not using _threadfence_system in LL so the error cannot be asserted
- *(comm->fatalDevError) = ncclDevSuspectedMismatch;
- } else if (conn && *conn->opCountRem > opCount) {
- mismatch += 1;
- }
- }
-
uint32_t spins = 0;
uint32_t abort = 0;
@@ -53,7 +47,6 @@ class ncclLLPrimitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
- if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@@ -61,14 +54,13 @@ class ncclLLPrimitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
- mismatch = 0;
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
sendConnHeadCache = *sendConnHeadPtr;
if (checkAbort(wid, 1)) break;
}
if (sendConnFifoPtr) {
- int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes;
+ int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size;
}
sendConnHead += 1;
@@ -88,7 +80,7 @@ class ncclLLPrimitives {
// LL Cleanup : write all flags in the slice to make sure we don't have
// data corruption when flag loops over.
if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
- for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
+ for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
}
sendStep[i]++;
}
@@ -98,7 +90,6 @@ class ncclLLPrimitives {
uint32_t flag = recvFlag(i);
uint32_t data1, flag1, data2, flag2;
spins = 0;
- mismatch = 0;
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
if (checkAbort(i, 0)) break;
@@ -164,7 +155,7 @@ class ncclLLPrimitives {
}
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
- recvBuff[i] = conn->llBuff;
+ recvBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
recvStep[i] = conn->step;
if (wid == i) recvConn = conn;
nrecv++;
@@ -173,13 +164,11 @@ class ncclLLPrimitives {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConnHeadPtr = recvConn->head;
recvConnHead = recvConn->step;
- // Update opCount in case we skipped some operations
- *(recvConn->opCountLoc) = opCount;
}
}
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
- sendBuff[i] = conn->llBuff;
+ sendBuff[i] = (union ncclLLFifoLine*)conn->buffs[NCCL_PROTO_LL];
sendStep[i] = conn->step;
if (wid == i) sendConn = conn;
nsend++;
@@ -189,15 +178,13 @@ class ncclLLPrimitives {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
- sendConnFifoPtr = sendConn->fifo;
- *(sendConn->opCountLoc) = opCount;
+ sendConnFifoPtr = sendConn->sizesFifo;
}
}
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
recvConn->step = recvConnHead;
- *(recvConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
@@ -205,15 +192,14 @@ class ncclLLPrimitives {
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
sendConn->step = sendConnHead;
- *(sendConn->opCountLoc) = opCount+1;
__threadfence_block();
}
}
public:
__device__ __forceinline__
- ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
- : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), opCount(opCount) {
+ ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm)
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines) {
// Make sure step is updated before we read it.
barrier();