Diffstat (limited to 'src/collectives/device/all_reduce.h')
-rw-r--r-- | src/collectives/device/all_reduce.h | 223
1 files changed, 115 insertions, 108 deletions
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index 4e04f88..6891ac0 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
  *
  * See LICENSE.txt for license information
  ************************************************************************/
@@ -11,26 +11,27 @@
 template<int UNROLL, class FUNC, typename T>
 __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads-WARP_SIZE;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads-WARP_SIZE;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
   struct ncclRing* ring = &channel->ring;
-  const ssize_t size = args->N;
-  const int nranks = comm->nRanks;
-  const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+  const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
   const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
-  const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
+  const int nranks = comm->nRanks;
+  const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+  const ssize_t size = args->coll.count;

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

-  ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
-    prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
+  ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
+    prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
-    int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*args->nChannels));
+    ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
     ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
     ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;

@@ -85,28 +86,29 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
 template<int UNROLL, class FUNC, typename T>
 __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads-WARP_SIZE;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads-WARP_SIZE;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
-  const ssize_t size = args->N;
-  const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
-  int chunkSize = args->lastChunkSize;
+  const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+  int chunkSize = args->coll.lastChunkSize;
   const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
-  const ssize_t loopSize = args->nChannels*chunkSize;
+  const ssize_t loopSize = nChannels*chunkSize;
+  const ssize_t size = args->coll.count;

   if (loopSize > size) {
-    chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+    chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
   }

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

   do {
     struct ncclTree* tree = &channel->treeUp;
     // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
-    ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+    ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Up
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -124,17 +126,17 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
   do {
     struct ncclTree* tree = &channel->treeDn;
     // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
-    ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+    ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Down
       ssize_t offset = gridOffset + bid*chunkSize;
       int nelem = min(chunkSize, size-offset);
       if (tree->up == -1) {
-        prims.send(thisOutput+offset, nelem);
+        prims.directSend(thisOutput+offset, offset, nelem);
       } else if (tree->down[0] == -1) {
-        prims.recv(thisOutput+offset, nelem);
+        prims.directRecv(thisOutput+offset, offset, nelem);
       } else {
-        prims.recvCopySend(thisOutput+offset, nelem);
+        prims.directRecvCopySend(thisOutput+offset, offset, nelem);
       }
     }
   } while(0);
@@ -143,27 +145,28 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
 template<int UNROLL, class FUNC, typename T>
 __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads-WARP_SIZE;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads-WARP_SIZE;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
-  const ssize_t size = args->N;
-  const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
-  int chunkSize = args->lastChunkSize;
+  const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+  int chunkSize = args->coll.lastChunkSize;
   const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
-  const ssize_t loopSize = args->nChannels*chunkSize;
+  const ssize_t loopSize = nChannels*chunkSize;
+  const ssize_t size = args->coll.count;

   if (loopSize > size) {
-    chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+    chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
   }

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

-  if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+  if (blockIdx.x < nChannels) { // first half of the channels do reduce
     struct ncclTree* tree = &channel->collTreeUp;
-    ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+    ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Up
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -178,9 +181,9 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
     }
   }

-  if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+  if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
     struct ncclTree* tree = &channel->collTreeDn;
-    ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+    ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Down
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -199,28 +202,27 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
 template<int UNUSED, class FUNC, typename T>
 __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int bid = args->bid;
-  const int nthreads = args->nThreads;
+  const int nthreads = args->coll.nThreads;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
   struct ncclRing* ring = &channel->ring;
-
-  ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
-  const ssize_t size = args->N;
-  //const int rank = comm->rank;
-  const int nranks = comm->nRanks;
-  ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+  const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+  ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
   const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
+  const int nranks = comm->nRanks;
+  const ssize_t loopSize = nChannels*nranks*chunkSize;
+  const ssize_t size = args->coll.count;

-  const ssize_t loopSize = args->nChannels*nranks*chunkSize;
+  ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
-    chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+    chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);

     /////////////// begin AllReduce steps ///////////////
     ssize_t offset;
@@ -229,7 +231,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {

     // step 0: push data to next GPU
     chunk = ring->devUserRanks[nranks-1];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     LLprims.send(thisInput+offset, nelem);
@@ -237,7 +239,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
     // k-2 steps: reduce and copy to next GPU
     for (int j=2; j<nranks; ++j) {
       chunk = ring->devUserRanks[nranks-j];
-      offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+      offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
       nelem = min(chunkSize, size-offset);

       LLprims.recvReduceSend(thisInput+offset, nelem);
@@ -246,7 +248,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
     // step k-1: reduce this buffer and data, which will produce the final
     // result that we store in this data and push to the next GPU
     chunk = ring->devUserRanks[0];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
@@ -254,7 +256,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
     // k-2 steps: copy to next GPU
     for (int j=1; j<nranks-1; ++j) {
       chunk = ring->devUserRanks[nranks-j];
-      offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+      offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
       nelem = min(chunkSize, size-offset);

       LLprims.recvCopySend(thisOutput+offset, nelem);
@@ -262,7 +264,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {

     // Make final copy from buffer to dest.
     chunk = ring->devUserRanks[1];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     // Here we need to copy from buffer to this output.
@@ -273,27 +275,29 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
 template<int UNUSED, class FUNC, typename T>
 __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
-  const ssize_t size = args->N;
-  ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+  const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+  ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
   const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
-  const ssize_t loopSize = args->nChannels*chunkSize;
+  const ssize_t loopSize = nChannels*chunkSize;
+  const ssize_t size = args->coll.count;

   if (loopSize > size) {
-    chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+    chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
   }

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

   do {
     struct ncclTree* tree = &channel->treeUp;
     // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
-    ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+    ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Up
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -311,7 +315,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
   do {
     struct ncclTree* tree = &channel->treeDn;
     // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
-    ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+    ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Down
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -330,26 +334,28 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
 template<int UNUSED, class FUNC, typename T>
 __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
-  const ssize_t size = args->N;
-  ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+  const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+  ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
   const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
-  const ssize_t loopSize = args->nChannels*chunkSize;
+  const ssize_t loopSize = nChannels*chunkSize;
+  const ssize_t size = args->coll.count;

   if (loopSize > size) {
-    chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+    chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
   }

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

-  if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+  if (blockIdx.x < nChannels) { // first half of the channels do reduce
     struct ncclTree* tree = &channel->collTreeUp;
-    ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+    ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Up
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -364,9 +370,9 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
     }
   }

-  if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+  if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
     struct ncclTree* tree = &channel->collTreeDn;
-    ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+    ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       // Down
       ssize_t offset = gridOffset + bid*chunkSize;
@@ -386,29 +392,28 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
 template<int UNUSED, class FUNC, typename T>
 __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int bid = args->bid;
-  const int nthreads = args->nThreads;
+  const int nthreads = args->coll.nThreads;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
   struct ncclRing* ring = &channel->ring;
-
-  ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
-  const ssize_t size = args->N;
-  //const int rank = comm->rank;
-  const int nranks = comm->nRanks;
-  ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+  const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+  ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
   // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
   const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+  const int nranks = comm->nRanks;
+  const ssize_t loopSize = nChannels*nranks*chunkSize;
+  const ssize_t size = args->coll.count;

-  const ssize_t loopSize = args->nChannels*nranks*chunkSize;
+  ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

   for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
-    chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+    chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);

     /////////////// begin AllReduce steps ///////////////
     ssize_t offset;
@@ -417,7 +422,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {

     // step 0: push data to next GPU
     chunk = ring->devUserRanks[nranks-1];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     LLprims.send(thisInput+offset, nelem);
@@ -425,7 +430,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
     // k-2 steps: reduce and copy to next GPU
     for (int j=2; j<nranks; ++j) {
       chunk = ring->devUserRanks[nranks-j];
-      offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+      offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
       nelem = min(chunkSize, size-offset);

       LLprims.recvReduceSend(thisInput+offset, nelem);
@@ -434,7 +439,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
     // step k-1: reduce this buffer and data, which will produce the final
     // result that we store in this data and push to the next GPU
     chunk = ring->devUserRanks[0];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
@@ -442,7 +447,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
     // k-2 steps: copy to next GPU
     for (int j=1; j<nranks-1; ++j) {
       chunk = ring->devUserRanks[nranks-j];
-      offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+      offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
       nelem = min(chunkSize, size-offset);

       LLprims.recvCopySend(thisOutput+offset, nelem);
@@ -450,7 +455,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {

     // Make final copy from buffer to dest.
     chunk = ring->devUserRanks[1];
-    offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+    offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
     nelem = min(chunkSize, size-offset);

     // Here we need to copy from buffer to this output.
@@ -461,29 +466,31 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
 template<int UNUSED, class FUNC, typename T>
 __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
   const int tid = threadIdx.x;
-  const int nthreads = args->nThreads;
-  const int bid = args->bid;
+  const int nthreads = args->coll.nThreads;
+  const int bid = args->coll.bid;
+  const int nChannels = args->coll.nChannels;
   struct ncclDevComm* comm = args->comm;
   struct ncclChannel* channel = comm->channels+blockIdx.x;
   struct ncclTree* treeUp = &channel->treeUp;
   struct ncclTree* treeDn = &channel->treeDn;
-  const ssize_t size = args->N;
-  ssize_t chunkSize = args->lastChunkSize;
+  const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+  ssize_t chunkSize = args->coll.lastChunkSize;
   const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
-  const ssize_t loopSize = args->nChannels*chunkSize;
+  const ssize_t loopSize = nChannels*chunkSize;
   int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
+  const ssize_t size = args->coll.count;

   if (loopSize > size) {
-    chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+    chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
   }

   // Compute pointers
-  const T * __restrict__ thisInput = (const T*)args->ThisInput;
-  T * __restrict__ thisOutput = (T*)args->ThisOutput;
+  const T * __restrict__ thisInput = (const T*)args->sendbuff;
+  T * __restrict__ thisOutput = (T*)args->recvbuff;

   if (treeUp->up == -1) {
     // ReduceAndBroadcast : max number of recv is 3, max number of send is 3
-    ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, channel, comm, args->opCount);
+    ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm, args->opCount);
     for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
       ssize_t offset = gridOffset + bid*chunkSize;
       int nelem = min(chunkSize, size-offset);
@@ -492,7 +499,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
   } else {
     if (tid < nthreadsSplit) {
       // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
-      ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, channel, comm, args->opCount);
+      ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm, args->opCount);
      for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         // Up
         ssize_t offset = gridOffset + bid*chunkSize;
@@ -505,7 +512,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
       }
     } else {
       // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
-      ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, channel, comm, args->opCount);
+      ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm, args->opCount);
       for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
         // Down
         ssize_t offset = gridOffset + bid*chunkSize;
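The recurring change across the kernels above is that per-protocol step sizes are now derived from comm->buffSizes[] (indexed by NCCL_PROTO_SIMPLE, NCCL_PROTO_LL, NCCL_PROTO_LL128) instead of channel->buffSize or NCCL_LL_SLICE_LINES, and that bid/nThreads/nChannels/count are read from args->coll.*. The snippet below is only a host-side sketch of the step/chunk arithmetic the ring kernel performs; the buffer size, step counts, thread count and element count are made-up values, and DIVUP plus the alignment step stand in for NCCL's own macros rather than reproducing them.

// Sketch of the SIMPLE-protocol chunk sizing used in ncclAllReduceRingKernel.
// All numeric constants here are assumptions for illustration only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

#define DIVUP(x, y) (((x) + (y) - 1) / (y))

int main() {
  using T = float;
  const int  NCCL_STEPS = 8;            // slots per buffer (assumed)
  const int  ALLREDUCE_CHUNKSTEPS = 4;  // steps per chunk (assumed)
  const long buffSize = 4 << 20;        // bytes of the SIMPLE-protocol buffer (assumed)
  const int  nthreads = 256, nranks = 8, nChannels = 2;
  const long size = 1 << 20;            // total element count of the allreduce (assumed)

  const int stepSize  = buffSize / (sizeof(T) * NCCL_STEPS);  // elements moved per step
  const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;      // elements moved per chunk

  for (long gridOffset = 0; gridOffset < size; gridOffset += (long)nranks * nChannels * chunkSize) {
    // Shrink the last chunk to what remains, then round up to a full per-thread load,
    // mirroring min(chunkSize, DIVUP(...)) followed by ALIGN_SIZE in the kernel.
    long realChunkSize = std::min<long>(chunkSize, DIVUP(size - gridOffset, (long)nranks * nChannels));
    const long align = (long)nthreads * sizeof(uint64_t) / sizeof(T);
    realChunkSize = DIVUP(realChunkSize, align) * align;
    printf("gridOffset=%ld realChunkSize=%ld\n", gridOffset, realChunkSize);
  }
  return 0;
}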