Diffstat (limited to 'src/enqueue.cc')
-rw-r--r-- | src/enqueue.cc | 393
1 file changed, 240 insertions, 153 deletions
diff --git a/src/enqueue.cc b/src/enqueue.cc
index 40e8f57..a427bd9 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -9,58 +9,58 @@
 #include "coll_net.h"
 
 // Only generate inline kernels for LL
-#define NCCL_FUNC5(coll, op, dtype) \
-  (void*)NCCL_KERN_NAME(coll##LL, op, dtype), \
-  (void*)NCCL_KERN_NAME(coll##LL, op, dtype), \
-  (void*)NCCL_KERN_NAME(coll##LL, op, dtype)
+#define NCCL_FUNC5(func, algo, redop, dtype) \
+  (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
+  (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
+  (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype)
 
-#define NCCL_FUNC4(coll, op, dtype) \
-  (void*)NCCL_FUNC5(coll##Tree, op, dtype), \
-  (void*)NCCL_FUNC5(coll##Ring, op, dtype), \
-  (void*)NCCL_FUNC5(coll##CollNet, op, dtype)
+#define NCCL_FUNC4(func, redop, type) \
+  (void*)NCCL_FUNC5(func, TREE, redop, type), \
+  (void*)NCCL_FUNC5(func, RING, redop, type), \
+  (void*)NCCL_FUNC5(func, COLLNET, redop, type)
 
 // Must be consistent with ncclDataType_t
-#define NCCL_FUNCS3A(coll, op) \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, u8), \
-  (void*)NCCL_FUNC4(coll, op, i32), \
-  (void*)NCCL_FUNC4(coll, op, u32), \
-  (void*)NCCL_FUNC4(coll, op, i64), \
-  (void*)NCCL_FUNC4(coll, op, u64), \
-  (void*)NCCL_FUNC4(coll, op, f16), \
-  (void*)NCCL_FUNC4(coll, op, f32), \
-  (void*)NCCL_FUNC4(coll, op, f64)
-#define NCCL_FUNCS3B(coll, op) \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8), \
-  (void*)NCCL_FUNC4(coll, op, i8)
+#define NCCL_FUNCS3A(func, redop) \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, uint8_t), \
+  (void*)NCCL_FUNC4(func, redop, int32_t), \
+  (void*)NCCL_FUNC4(func, redop, uint32_t), \
+  (void*)NCCL_FUNC4(func, redop, int64_t), \
+  (void*)NCCL_FUNC4(func, redop, uint64_t), \
+  (void*)NCCL_FUNC4(func, redop, half), \
+  (void*)NCCL_FUNC4(func, redop, float), \
+  (void*)NCCL_FUNC4(func, redop, double)
+#define NCCL_FUNCS3B(func, redop) \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t), \
+  (void*)NCCL_FUNC4(func, redop, int8_t)
 
 // Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
-#define NCCL_FUNCS2A(coll) \
-  NCCL_FUNCS3A(coll, sum), \
-  NCCL_FUNCS3A(coll, sum), \
-  NCCL_FUNCS3A(coll, sum), \
-  NCCL_FUNCS3A(coll, sum)
-#define NCCL_FUNCS2B(coll) \
-  NCCL_FUNCS3B(coll, copy), \
-  NCCL_FUNCS3B(coll, copy), \
-  NCCL_FUNCS3B(coll, copy), \
-  NCCL_FUNCS3B(coll, copy)
+#define NCCL_FUNCS2A(func) \
+  NCCL_FUNCS3A(func, Sum), \
+  NCCL_FUNCS3A(func, Sum), \
+  NCCL_FUNCS3A(func, Sum), \
+  NCCL_FUNCS3A(func, Sum)
+#define NCCL_FUNCS2B(func) \
+  NCCL_FUNCS3B(func, Sum), \
+  NCCL_FUNCS3B(func, Sum), \
+  NCCL_FUNCS3B(func, Sum), \
+  NCCL_FUNCS3B(func, Sum)
 
 // Must be consistent with the ncclFuncSet enum
 static void* const ncclKerns[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
-  (void*)NCCL_KERN_NAME(ncclSendRecv, copy, i8),
-  NCCL_FUNCS2B(ncclBroadcast),
-  NCCL_FUNCS2A(ncclReduce),
-  NCCL_FUNCS2B(ncclAllGather),
-  NCCL_FUNCS2A(ncclReduceScatter),
-  NCCL_FUNCS2A(ncclAllReduce)
+  (void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
+  NCCL_FUNCS2B(Broadcast),
+  NCCL_FUNCS2A(Reduce),
+  NCCL_FUNCS2B(AllGather),
+  NCCL_FUNCS2A(ReduceScatter),
+  NCCL_FUNCS2A(AllReduce)
 };
 
 /*****************************************************************************/
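Note on indexing: ncclKerns is a flat table over (function, reduction op, data type, algorithm, protocol), with slot 0 reserved for the standalone SendRecv kernel. The FUNC_INDEX / FUNC_INDEX_P2P macros used later in this file compute offsets into exactly this layout; a minimal sketch of the arithmetic (the helper name below is hypothetical — the real macro lives in the device-side headers, which are not part of this diff):

    // Row-major flattening of (func, redop, dtype, algo, proto) into ncclKerns,
    // matching the macro nesting above. funcIndex() is a hypothetical stand-in
    // for NCCL's FUNC_INDEX macro.
    static int funcIndex(int func, int redop, int dtype, int algo, int proto) {
      return 1 + ((((func*ncclNumOps + redop)*ncclNumTypes + dtype)
                  *NCCL_NUM_ALGORITHMS + algo)*NCCL_NUM_PROTOCOLS + proto);
      // The leading 1+ skips slot 0, which holds the SendRecv kernel.
    }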
@@ -87,41 +87,57 @@ ncclResult_t ncclLaunchCooperativeKernelMultiDevice(struct cudaLaunchParams *par
   return ncclSuccess;
 }
 
-ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams* params) {
+static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** work, struct ncclWorkElem* base) {
+  if (channel->workCount == NCCL_MAX_OPS) {
+    WARN("Too many aggregated operations on channel %d (%d max)", channel->id, NCCL_MAX_OPS);
+    return ncclInvalidUsage;
+  }
+  int opIndex = channel->workFifoTail%NCCL_MAX_OPS;
+  struct ncclWork* w = channel->workFifo+opIndex;
+  struct ncclWorkElem* e = w->elems;
+  volatile uint8_t* activePtr = (volatile uint8_t*)&e->active;
+  while (activePtr[0] != 0) sched_yield();
+  memset(w, 0, sizeof(struct ncclWork));
+  // Initialize with work elem if provided
+  if (base) memcpy(e, base, sizeof(struct ncclWorkElem));
+  e->active = 1;
+  e->index = opIndex;
+  channel->workFifoTail++;
+  channel->workCount++;
+  if (work) *work = w;
+  return ncclSuccess;
+}
+
+static ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams* params) {
   // Only launch blocks where we have work to do.
   for (int c=0; c<comm->p2pnChannels; c++) {
-    if (comm->channels[c].collCount) params->gridDim.x = c+1;
+    if (comm->channels[c].workCount) params->gridDim.x = c+1;
   }
 
   // Set active = 2 for the last operation and add a no-op on empty channels (p2p case).
   for (int c=0; c<params->gridDim.x; c++) {
     struct ncclChannel* channel = comm->channels+c;
-    if (channel->collCount == 0) {
-      int opIndex = channel->collFifoTail;
-      struct ncclColl* c = channel->collectives+opIndex;
-      volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
-      while (activePtr[0] != 0) sched_yield();
-
-      c->args.p2p.delta = -1; // no-op
-      c->funcIndex = FUNC_INDEX_P2P;
-      c->args.comm = comm->devComm;
-      c->active = 1;
-      opIndex = (opIndex+1)%NCCL_MAX_OPS;
-      c->nextIndex = opIndex;
-      channel->collFifoTail = opIndex;
-      channel->collCount++;
+    if (channel->workCount == 0) {
+      struct ncclWork* w;
+      NCCLCHECK(getNextOp(channel, &w, NULL));
+      struct ncclWorkElem* e = w->elems;
+      e->comm = comm->devComm;
+      e->funcIndex = FUNC_INDEX_P2P;
+      e->p2p.nThreads = 0;
     }
-    channel->collectives[(channel->collStart+channel->collCount-1)%NCCL_MAX_OPS].active = 2;
+    channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active = 2;
   }
 
   // Find the first operation, choose the kernel accordingly and pass it
   // as the first argument.
-  struct ncclColl* coll = comm->channels[0].collectives+comm->channels[0].collStart;
-  memcpy(&comm->args, coll, sizeof(struct ncclColl));
-  // As we pass that coll directly, we can free it immediately.
-  coll->active = 0;
-
-  params->func = ncclKerns[coll->funcIndex];
+  struct ncclChannel* c0 = comm->channels;
+  struct ncclWork* work = c0->workFifo+((c0->workFifoTail-c0->workCount)%NCCL_MAX_OPS);
+  struct ncclWorkElem* elem = work->elems;
+  memcpy(&comm->args, elem, sizeof(struct ncclWorkElem));
+  // As we inline the first coll directly, we can free it immediately.
+  if (elem->funcIndex != FUNC_INDEX_P2P) elem->active = 0;
+
+  params->func = ncclKerns[elem->funcIndex];
   return ncclSuccess;
 }
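getNextOp is the producer side of a host-to-device handshake over each channel's workFifo: the host spins with sched_yield() until the persistent kernel has cleared a slot's active flag, fills the slot, then publishes it by setting active = 1 (setupLaunch later marks the last element of a launch with active = 2). A self-contained sketch of the same protocol with simplified types:

    #include <sched.h>
    #include <cstdint>

    #define MAX_OPS 2048                      // stand-in for NCCL_MAX_OPS

    struct Slot { volatile uint8_t active;    /* ... payload ... */ };
    struct Fifo { Slot slots[MAX_OPS]; uint64_t tail; };

    // Producer: claim the next FIFO slot, as getNextOp() does per channel.
    Slot* claimSlot(Fifo* f) {
      Slot* s = &f->slots[f->tail % MAX_OPS];
      while (s->active != 0) sched_yield();   // wait for the consumer to drain it
      /* ... fill payload here ... */
      s->active = 1;                          // publish; the consumer resets it to 0
      f->tail++;
      return s;
    }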
@@ -131,7 +147,7 @@ ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast) {
   bool done = false;
   while (done == false) {
     if (val >= comm->intraRanks) {
-      WARN("Trying to launch too many collectives");
+      WARN("Trying to launch too many work elements, max is %d", NCCL_MAX_OPS);
       return ncclInvalidUsage;
     }
     if (val+1 == comm->intraRanks) {
@@ -151,7 +167,7 @@ ncclResult_t ncclCpuBarrierLast(struct ncclComm* comm) {
   volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
   int val = *ptr;
   if (__sync_bool_compare_and_swap(ptr, val, val+1) != true) {
-    WARN("Trying to launch too many collectives");
+    WARN("Trying to launch too many work elements, max is %d", NCCL_MAX_OPS);
     return ncclInternalError;
   }
   return ncclSuccess;
@@ -222,13 +238,18 @@ ncclResult_t ncclBarrierEnqueueWait(ncclComm_t comm) {
   // launch and the ncclProxyStart call could cause a deadlock.
   // Also, starting the proxies after the CUDA launch seems to be better for
   // performance (latency).
+  uint64_t max = 0ULL;
   for (int r=0; r<params->gridDim.x; r++) {
     struct ncclChannel* channel = comm->channels+r;
-    channel->collStart = channel->collFifoTail;
-    channel->collCount = 0;
+    max = std::max(max, channel->workFifoTail);
+    channel->workCount = 0;
+  }
+  for (int r=0; r<comm->p2pnChannels; r++) {
+    struct ncclChannel* channel = comm->channels+r;
+    channel->workFifoTail = max;
   }
   params->gridDim.x = params->blockDim.x = 0;
-  comm->lastOpCount = comm->opCount;
+  comm->lastOpCount = max;
   NCCLCHECK(ncclProxyStart(comm));
   return ncclSuccess;
 }
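Note that the tails are equalized across all p2p channels rather than reset per channel: setupLaunch() locates the first pending work as workFifoTail - workCount, so every channel must agree on the starting offset of the next launch even if it enqueued nothing in this one. A condensed sketch of that invariant (simplified types; illustrative only):

    #include <algorithm>
    #include <cstdint>

    struct Chan { uint64_t workFifoTail; int workCount; };

    // Bring all channels to the same FIFO position after a launch, so the next
    // round's (workFifoTail - workCount) computation starts from a common base.
    void resetTails(Chan* chan, int nChannels) {
      uint64_t max = 0;
      for (int c = 0; c < nChannels; c++) max = std::max(max, chan[c].workFifoTail);
      for (int c = 0; c < nChannels; c++) { chan[c].workFifoTail = max; chan[c].workCount = 0; }
    }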
@@ -280,7 +301,8 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
   //if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
   TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
 
-  int nc = (info->algorithm == NCCL_ALGO_COLLNET) ? comm->nChannels/2 : comm->nChannels; // CollNet uses one channel for up and one channel for down
+  int nc = (info->nChannels > 0) ? info->nChannels :
+    (info->algorithm == NCCL_ALGO_COLLNET) ? comm->nChannels/2 : comm->nChannels; // CollNet uses one channel for up and one channel for down
   int nt = comm->maxThreads[info->algorithm][info->protocol];
   int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol];
   while (info->nBytes < nc*nt*threadThreshold) {
@@ -289,6 +311,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
     else break;
   }
   if (info->protocol == NCCL_PROTO_SIMPLE) nt += WARP_SIZE; // Extra warp for sync
+  if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE;
   info->nChannels = nc;
   info->nThreads = nt;
   return ncclSuccess;
@@ -296,14 +319,14 @@
 
 static ncclResult_t getPatternInfo(struct ncclInfo* info) {
   switch (info->coll) {
-    case ncclCollBroadcast:
+    case ncclFuncBroadcast:
       info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeDown : ncclPatternPipelineFrom; break;
-    case ncclCollReduce:
+    case ncclFuncReduce:
      info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break;
-    case ncclCollReduceScatter:
-    case ncclCollAllGather:
+    case ncclFuncReduceScatter:
+    case ncclFuncAllGather:
       info->pattern = ncclPatternRing; break;
-    case ncclCollAllReduce:
+    case ncclFuncAllReduce:
       info->pattern = info->algorithm == NCCL_ALGO_COLLNET ? ncclPatternCollTreeUp : info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : ncclPatternRingTwice; break;
     default:
      WARN("Unknown pattern for collective %d algorithm %d", info->coll, info->algorithm);
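The while loop sheds channels and threads until each thread has at least threadThreshold bytes to work on; its body falls between the two hunks above and is not part of this diff. A standalone sketch, assuming the usual shape of that body in this era's source (drop a channel first, then halve the thread count in multiples of 128):

    #include <cstddef>

    // Hypothetical standalone form of getAlgoInfo()'s shrinking loop; the
    // actual loop body is elided between the hunks above.
    void shrink(size_t nBytes, size_t threadThreshold, int& nc, int& nt) {
      while (nBytes < (size_t)nc*nt*threadThreshold) {
        if (nc >= 2) nc--;
        else if ((nt % 128) == 0) nt /= 2;
        else break;
      }
    }
    // e.g. a 64 KiB op with nc=16, nt=512, threshold=64 stops at nc=2, nt=512,
    // since 2*512*64 bytes == 64 KiB no longer exceeds the message size.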
@@ -333,30 +356,22 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
   return ncclSuccess;
 }
 
-static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclColl* coll, struct ncclProxyArgs* proxyArgs /* output */) {
-  coll->args.sendbuff = info->sendbuff;
-  coll->args.recvbuff = info->recvbuff;
-  coll->args.comm = info->comm->devComm;
-
-  if (info->coll == ncclCollSendRecv) {
-    coll->args.p2p.sendCount = info->sendbytes;
-    coll->args.p2p.recvCount = info->recvbytes;
-    coll->args.p2p.delta = info->delta;
-    coll->funcIndex = FUNC_INDEX_P2P;
-    coll->args.p2p.nThreads = info->nThreads = info->comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]+2*WARP_SIZE;
-    return ncclSuccess;
-  }
+static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) {
+  work->comm = info->comm->devComm;
+
   // Set nstepsPerLoop and nchunksPerLoop
   NCCLCHECK(getAlgoInfo(info));
   NCCLCHECK(getPatternInfo(info));
   NCCLCHECK(getLoopInfo(info));
 
-  coll->args.coll.root = info->root;
-  coll->args.coll.count = info->count;
-  coll->args.coll.nChannels = info->nChannels;
-  coll->args.coll.nThreads = info->nThreads;
+  work->sendbuff = info->sendbuff;
+  work->recvbuff = info->recvbuff;
+  work->coll.root = info->root;
+  work->coll.count = info->count;
+  work->coll.nChannels = info->nChannels;
+  work->nThreads = info->nThreads;
 
-  coll->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
+  work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
 
   int stepSize = info->comm->buffSizes[info->protocol]/NCCL_STEPS;
   int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1;
@@ -367,25 +382,25 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
   if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_SIMPLE) {
     if (info->pattern == ncclPatternTreeUpDown) {
       // Optimize chunkSize / nSteps
-      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth*8 && chunkSize > 131072) chunkSize /= 2;
-      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth*4 && chunkSize > 65536) chunkSize /= 2;
-      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth && chunkSize > 32768) chunkSize /= 2;
+      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*8 && chunkSize > 131072) chunkSize /= 2;
+      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*4 && chunkSize > 65536) chunkSize /= 2;
+      while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth && chunkSize > 32768) chunkSize /= 2;
     }
     // Use lastChunkSize as chunkSize
-    coll->args.coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
+    work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
   } else if (info->algorithm == NCCL_ALGO_COLLNET && info->protocol == NCCL_PROTO_SIMPLE) {
     // Optimize chunkSize / nSteps
-    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth*16 && chunkSize > 131072) chunkSize /= 2;
-    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth*4 && chunkSize > 65536) chunkSize /= 2;
-    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth && chunkSize > 32768) chunkSize /= 2;
+    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth*16 && chunkSize > 131072) chunkSize /= 2;
+    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth*4 && chunkSize > 65536) chunkSize /= 2;
+    while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth && chunkSize > 32768) chunkSize /= 2;
    // Use lastChunkSize as chunkSize
-    coll->args.coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
+    work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
   } else if (info->protocol == NCCL_PROTO_LL) {
     const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
     const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
-    coll->args.coll.lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
-    ALIGN_SIZE(coll->args.coll.lastChunkSize, info->nThreads*sizeof(uint64_t));
-    coll->args.coll.lastChunkSize /= ncclTypeSize(info->datatype);
+    work->coll.lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
+    ALIGN_SIZE(work->coll.lastChunkSize, info->nThreads*sizeof(uint64_t));
+    work->coll.lastChunkSize /= ncclTypeSize(info->datatype);
   } else if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_LL128) {
     int nNodes = info->comm->nNodes;
     float ppn = info->comm->nRanks / (float)nNodes;
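These loops keep halving chunkSize (down to protocol-specific floors) until each channel has enough chunks in flight to cover the tree pipeline. Worked example with illustrative numbers: nBytes = 8 MiB, nChannels = 2, tree depth = 4 and a starting chunkSize of 1 MiB gives 8/(2*1) = 4 chunks per channel, less than depth*8 = 32, so chunkSize halves down to the 128 KiB floor, where 32 chunks per channel cover the pipeline. The same heuristic in isolation:

    #include <cstddef>

    // Standalone form of the pipeline-depth heuristic above (TREE/SIMPLE case).
    size_t tuneChunkSize(size_t nBytes, int nChannels, int depth, size_t chunkSize) {
      while (nBytes / (nChannels*chunkSize) < (size_t)depth*8 && chunkSize > 131072) chunkSize /= 2;
      while (nBytes / (nChannels*chunkSize) < (size_t)depth*4 && chunkSize > 65536)  chunkSize /= 2;
      while (nBytes / (nChannels*chunkSize) < (size_t)depth   && chunkSize > 32768)  chunkSize /= 2;
      return chunkSize;
    }
    // tuneChunkSize(8<<20, 2, 4, 1<<20) == 131072, i.e. 128 KiB.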
@@ -393,7 +408,7 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
     while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*64/ppn && chunkSize > 131072) chunkSize /= 2;
     while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*16/ppn && chunkSize > 32768) chunkSize /= 2;
     // Use lastChunkSize as chunkSize
-    coll->args.coll.lastChunkSize = chunkSize*NCCL_LL128_DATAELEMS/(NCCL_LL128_LINEELEMS*ncclTypeSize(info->datatype));
+    work->coll.lastChunkSize = chunkSize*NCCL_LL128_DATAELEMS/(NCCL_LL128_LINEELEMS*ncclTypeSize(info->datatype));
   }
 
   // Compute nSteps for proxies
@@ -406,9 +421,13 @@
   proxyArgs->sliceSteps = sliceSteps;
   proxyArgs->chunkSteps = chunkSteps;
   proxyArgs->protocol = info->protocol;
-  proxyArgs->opCount = info->comm->opCount;
   proxyArgs->dtype = info->datatype;
   proxyArgs->redOp = info->op;
+  // This is used by P2P to reduce the receive buffer size. We don't use it in collectives
+  // because some protocols need to transmit more than the total size, plus they sometimes
+  // round up
+  proxyArgs->recvbytes = stepSize*proxyArgs->sliceSteps;
+
   TRACE(NCCL_NET,"opCount %lx slicesteps %d spl %d cpl %d nbytes %zi -> protocol %d nchannels %d nthreads %d, nloops %d nsteps %d comm %p",
       proxyArgs->opCount, proxyArgs->sliceSteps, info->nstepsPerLoop, info->nchunksPerLoop, info->nBytes, info->protocol, info->nChannels, info->nThreads,
       nLoops, proxyArgs->nsteps, info->comm);
@@ -427,32 +446,26 @@ static ncclResult_t checkSetStream(struct ncclInfo* info) {
 }
 
 ncclResult_t ncclSaveKernel(struct ncclInfo* info) {
-  if (info->comm->nRanks == 1 && info->coll != ncclCollSendRecv) {
+  if (info->comm->nRanks == 1) {
     if (info->sendbuff != info->recvbuff)
       CUDACHECK(cudaMemcpyAsync(info->recvbuff, info->sendbuff, info->nBytes, cudaMemcpyDeviceToDevice, info->stream));
     return ncclSuccess;
   }
 
-  struct ncclColl coll;
+  struct ncclWorkElem work;
   struct ncclProxyArgs proxyArgs;
   memset(&proxyArgs, 0, sizeof(struct ncclProxyArgs));
-  NCCLCHECK(computeColl(info, &coll, &proxyArgs));
+  NCCLCHECK(computeColl(info, &work, &proxyArgs));
 
   info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
 
-  int nChannels = info->coll == ncclCollSendRecv ? 1 : coll.args.coll.nChannels;
+  int nChannels = work.coll.nChannels;
   int nSubChannels = (info->pattern == ncclPatternCollTreeUp || info->pattern == ncclPatternCollTreeDown) ? 2 : 1;
 
   for (int bid=0; bid<nChannels*nSubChannels; bid++) {
-    int channelId = (info->coll == ncclCollSendRecv) ? info->channelId :
-      info->comm->myParams->gridDim.x % info->comm->nChannels;
+    int channelId = info->comm->myParams->gridDim.x % info->comm->nChannels;
     struct ncclChannel* channel = info->comm->channels+channelId;
 
-    if (channel->collCount == NCCL_MAX_OPS) {
-      WARN("Too many aggregated operations on channel %d (%d max)", channel->id, NCCL_MAX_OPS);
-      return ncclInvalidUsage;
-    }
-
     // Proxy
     proxyArgs.channel = channel;
     // Adjust pattern for CollNet based on channel index
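The lastChunkSize math in the hunks above leans on NCCL's rounding helpers. Their semantics, shown as a hedged sketch since the defining header is not part of this diff:

    // Assumed definitions of the rounding helpers used above (align.h-style);
    // shown for reference only, not taken from this diff.
    #define DIVUP(x, y) (((x)+(y)-1)/(y))                               // ceiling division
    #define ALIGN_SIZE(size, align) size = DIVUP(size, align)*(align);  // round up in place
    // e.g. DIVUP(10, 4) == 3; after ALIGN_SIZE(n, 4) with n = 10, n == 12.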
@@ -460,67 +473,141 @@ ncclResult_t ncclSaveKernel(struct ncclInfo* info) {
       info->pattern = (channelId < info->comm->nChannels/nSubChannels) ? ncclPatternCollTreeUp : ncclPatternCollTreeDown;
     }
-    if (info->coll == ncclCollSendRecv) {
-      info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
-      NCCLCHECK(ncclProxySaveP2p(info, channel));
-    } else {
-      NCCLCHECK(ncclProxySaveColl(&proxyArgs, info->pattern, info->root, info->comm->nRanks));
-    }
-
-    int opIndex = channel->collFifoTail;
-    struct ncclColl* c = channel->collectives+opIndex;
-    volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
-    while (activePtr[0] != 0) sched_yield();
-
-    memcpy(c, &coll, sizeof(struct ncclColl));
-    if (info->coll != ncclCollSendRecv) c->args.coll.bid = bid % coll.args.coll.nChannels;
-
-    c->active = 1;
-    opIndex = (opIndex+1)%NCCL_MAX_OPS;
-    c->nextIndex = opIndex;
-    channel->collFifoTail = opIndex;
-    channel->collCount++;
+    if (proxyArgs.nsteps) NCCLCHECK(ncclProxySaveColl(&proxyArgs, info->pattern, info->root, info->comm->nRanks));
+
+    info->comm->myParams->gridDim.x++;
+    work.coll.bid = bid % nChannels;
+    NCCLCHECK(getNextOp(channel, NULL, &work));
+  }
+  return ncclSuccess;
+}
+
+#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
+#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
+
+ncclResult_t ncclSaveCommKernels(ncclComm_t comm) {
+  if (comm->asyncOpCount == 0) {
+    return ncclSuccess;
+  } else if (comm->asyncOpCount == 1) {
+    // No aggregation
+    struct ncclInfo* info = comm->asyncOps;
+    info->nChannels = 0;
+    NCCLCHECK(ncclSaveKernel(info));
+  } else {
+    // Aggregation
+    size_t channelSize = NCCL_AGG_CHANNEL_SIZE * comm->nRanks; // scale channel size based on nranks as latency increases
+    // Reduce the per-channel size if we cannot fully utilize the channels
+    while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
+    for (int c = 0; c < comm->asyncOpCount; c++) {
+      struct ncclInfo* info = comm->asyncOps+c;
+      info->nChannels = std::min((int)DIVUP(info->nBytes, channelSize), comm->nChannels); // assign number of channels
+      NCCLCHECK(ncclSaveKernel(info));
+    }
+  }
+  // Reset counters
+  comm->asyncOpCount = 0;
+  comm->asyncTotalSize = 0;
+  return ncclSuccess;
+}
+
+static ncclResult_t ncclSaveAsyncColl(struct ncclInfo* info) {
+  ncclComm_t comm = info->comm;
+  if (comm->asyncOpCount >= NCCL_MAX_OPS) {
+    WARN("Too many async operations in progress, max is %d", NCCL_MAX_OPS);
+    return ncclInvalidUsage;
   }
-  info->comm->opCount++;
+  memcpy(comm->asyncOps+comm->asyncOpCount, info, sizeof(struct ncclInfo));
+  comm->asyncOpCount++;
+  comm->asyncTotalSize += info->nBytes;
   return ncclSuccess;
 }
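ncclSaveCommKernels sizes channels for aggregated launches: start from 2 MiB per channel scaled by rank count, then halve until the group's total payload can keep all channels busy. Worked example with illustrative numbers: 8 ranks and 4 channels give a 16 MiB starting channelSize; if asyncTotalSize is 4 MiB, the loop halves 16 -> 8 -> 4 -> 2 -> 1 MiB (stopping once 4 MiB >= channelSize*4), and a 3 MiB operation is then assigned min(DIVUP(3 MiB, 1 MiB), 4) = 3 channels. The same walk-through in code:

    #include <algorithm>
    #include <cstddef>
    #define DIVUP(x, y) (((x)+(y)-1)/(y))

    // Illustrative re-run of the aggregation sizing above: 8 ranks, 4 channels,
    // 4 MiB queued in the group, one 3 MiB operation among the queued ops.
    // The 512-byte floor stands in for NCCL_MIN_CHANNEL_SIZE.
    int channelsForOp() {
      size_t channelSize = (size_t)(1LL << 21) * 8;   // NCCL_AGG_CHANNEL_SIZE * nRanks
      const size_t totalSize = (size_t)4 << 20;       // comm->asyncTotalSize
      const int nChannels = 4;
      while (totalSize < channelSize * nChannels && channelSize > 512)
        channelSize /= 2;                             // 16 -> 8 -> 4 -> 2 -> 1 MiB
      return std::min((int)DIVUP((size_t)3 << 20, channelSize), nChannels);  // == 3
    }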
 
-// Save p2p operations in comm->p2plist. Operations will be posted to channels
+// Save p2p operations in comm->p2pSends and p2pRecvs. Operations will be posted to channels
 // during ncclGroupEnd()
-ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
+static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
   struct ncclComm* comm = info->comm;
-  struct ncclP2Plist* p2plist = &comm->p2plist;
   int peer = info->root;
-  p2plist->count++;
   ssize_t nBytes = info->count*ncclTypeSize(info->datatype);
-  if (info->recvbuff == NULL) {
+  if (info->opName[0] == 'S') { // Send
     if (peer != comm->rank) {
       int delta = (comm->nRanks - (comm->rank-peer)) % comm->nRanks;
       for (int c=0; c<comm->p2pnChannelsPerPeer; c++) {
         int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels;
         if (comm->channels[channelId].peers[peer].send.connected == 0) {
-          p2plist->connect.send[channelId*comm->nRanks+p2plist->connect.nsend[channelId]++] = peer;
+          comm->connectSend[peer] |= (1<<channelId);
+          comm->connect = 1;
         }
       }
     }
-    p2plist->peerlist[info->root].sendbytes = nBytes;
-    p2plist->peerlist[info->root].sendbuff = info->sendbuff;
+    NCCLCHECK(enqueueP2pInfo(comm->p2pSends+info->root, (void*)info->sendbuff, nBytes));
+    comm->p2pSendCount++;
   } else {
     if (peer != comm->rank) {
       int delta = (comm->nRanks + (comm->rank-peer)) % comm->nRanks;
       for (int c=0; c<comm->p2pnChannelsPerPeer; c++) {
         int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels;
         if (comm->channels[channelId].peers[peer].recv.connected == 0) {
-          p2plist->connect.recv[channelId*comm->nRanks+p2plist->connect.nrecv[channelId]++] = peer;
+          comm->connectRecv[peer] |= (1<<channelId);
+          comm->connect = 1;
         }
       }
     }
-    p2plist->peerlist[info->root].recvbytes = nBytes;
-    p2plist->peerlist[info->root].recvbuff = info->recvbuff;
+    NCCLCHECK(enqueueP2pInfo(comm->p2pRecvs+info->root, info->recvbuff, nBytes));
+    comm->p2pRecvCount++;
   }
   return ncclSuccess;
 }
 
+static int getSegment(struct ncclInfo* info, struct ncclWork* work) {
+  for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != info->delta; s++) {
+    if (work->elems[s].p2p.nThreads == 0) return s;
+  }
+  return -1;
+}
+
+static ncclResult_t saveP2pOp(struct ncclInfo* info /* input */, struct ncclWork* work, int s) {
+  struct ncclWorkElem* elem = work->elems+s;
+  elem->comm = info->comm->devComm;
+  elem->funcIndex = FUNC_INDEX_P2P;
+  elem->nThreads = info->nThreads = NCCL_MAX_NTHREADS;
+  elem->sendbuff = info->sendbuff;
+  elem->recvbuff = info->recvbuff;
+  elem->p2p.sendCount = info->sendbytes;
+  elem->p2p.recvCount = info->recvbytes;
+  elem->p2p.delta = info->delta;
+  const int nsegments = s+1;
+  int nThreads = 512;
+  while (nsegments*nThreads > 512) nThreads /= 2;
+  if (nThreads >= 128) nThreads += WARP_SIZE;
+  for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
+  return ncclSuccess;
+}
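saveP2pOp splits a fixed 512-thread budget across however many p2p segments share one ncclWork, then grants blocks of at least 128 threads an extra warp for synchronization. Tabulating the loop above:

    #include <cstdio>

    // Per-segment thread counts produced by the budget loop in saveP2pOp()
    // (WARP_SIZE == 32 assumed, as on current NVIDIA GPUs).
    static int segThreads(int nsegments) {
      int nThreads = 512;
      while (nsegments*nThreads > 512) nThreads /= 2;
      if (nThreads >= 128) nThreads += 32;
      return nThreads;
    }

    int main() {
      for (int s = 1; s <= 8; s++)
        printf("%d segment(s) -> %d threads each\n", s, segThreads(s));
      // 1 -> 544, 2 -> 288, 3-4 -> 160, 5-8 -> 64
      return 0;
    }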
 
+ncclResult_t ncclSaveP2pKernel(struct ncclInfo* info) {
+  int channelId = info->channelId;
+  struct ncclChannel* channel = info->comm->channels+channelId;
+
+  // Try to reuse last p2p operation if not full yet
+  int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
+  struct ncclWork* w = channel->workFifo+opIndex;
+  int segment = -1;
+  if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].p2p.nThreads == 0) {
+    // Try to pack more segments into a single operation
+    segment = getSegment(info, w);
+  }
+  if (segment == -1) {
+    NCCLCHECK(getNextOp(channel, &w, NULL));
+    segment = 0;
+  }
+
+  NCCLCHECK(ncclProxySaveP2p(info, channel, segment));
+  NCCLCHECK(saveP2pOp(info, w, segment));
+  info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
+  info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
+
+  return ncclSuccess;
+}
+
 ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
   // Launch asynchronously if needed
   if (ncclAsyncMode()) {
@@ -542,10 +629,10 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
         info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count,
         info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
 
-    if (info->coll == ncclCollSendRecv) { //p2p stored separately
+    if (info->coll == ncclFuncSendRecv) { //p2p stored separately
       NCCLCHECKGOTO(ncclSaveP2p(info), ret, end);
     } else {
-      NCCLCHECKGOTO(ncclSaveKernel(info), ret, end);
+      NCCLCHECKGOTO(ncclSaveAsyncColl(info), ret, end);
     }
 end:
     if (savedDev != -1) CUDACHECK(cudaSetDevice(savedDev));
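Taken together, ncclEnqueueCheck now routes sends and receives through ncclSaveP2p and every collective through ncclSaveAsyncColl, deferring the real work to group end. A caller-side sketch using the public NCCL 2.7+ API (comm, stream and the device buffers are assumed to be set up already; error handling elided):

    // Everything inside the group is recorded, aggregated at ncclGroupEnd(),
    // and launched as a single kernel per device.
    ncclGroupStart();
    ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream);
    ncclReduceScatter(sendbuf2, recvbuf2, rcount, ncclFloat, ncclSum, comm, stream);
    ncclSend(sbuf, n, ncclFloat, peer, comm, stream);   // -> ncclSaveP2p
    ncclRecv(rbuf, n, ncclFloat, peer, comm, stream);   // -> ncclSaveP2p
    ncclGroupEnd();   // -> ncclSaveCommKernels + p2p scheduling, then one launch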