| author | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2019-11-20 01:57:39 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-11-20 01:57:39 +0300 |
| commit | 299c554dccf923230321ad7495946543f3e9b457 (patch) | |
| tree | 6a70b52080f0570fc87285b3b2300dbd2f2918ad /src/collectives/device | |
| parent | ccb1298148327bacb9b83452ed6ae0b29417e7e2 (diff) | |
2.5.6-1 (#255)
Add LL128 Protocol.
Rewrite the topology detection and tree/ring creation (#179). Improve
tree performance by sending/receiving from different GPUs. Add
model-based tuning to switch between the different algorithms and
protocols.
Rework P2P/SHM detection in containers (#155, #248).
Detect duplicated devices and return an error (#231).
Add tuning for GCP.
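The model-based tuning mentioned above is implemented on the host side and is not part of the device diff below. As a rough sketch of the idea only (the enum names, tables, and function here are hypothetical, not NCCL's actual tuning code): each (algorithm, protocol) pair gets an estimated completion time of base latency plus size divided by bandwidth, and the fastest pair is chosen.

#include <stddef.h>
#include <float.h>

enum { ALGO_TREE, ALGO_RING, NUM_ALGOS };
enum { PROTO_LL, PROTO_LL128, PROTO_SIMPLE, NUM_PROTOS };

/* Hypothetical selector: lat in microseconds, bw in GB/s. */
static void pickAlgoProto(size_t bytes,
                          const float lat[NUM_ALGOS][NUM_PROTOS],
                          const float bw[NUM_ALGOS][NUM_PROTOS],
                          int* algo, int* proto) {
  float best = FLT_MAX;
  for (int a = 0; a < NUM_ALGOS; a++) {
    for (int p = 0; p < NUM_PROTOS; p++) {
      if (bw[a][p] <= 0.0f) continue;              /* combination not supported */
      /* bytes / (GB/s) gives nanoseconds; divide by 1e3 to get microseconds. */
      float t = lat[a][p] + (float)bytes / bw[a][p] / 1e3f;
      if (t < best) { best = t; *algo = a; *proto = p; }
    }
  }
}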
Diffstat (limited to 'src/collectives/device')
-rw-r--r-- | src/collectives/device/Makefile | 2
-rw-r--r-- | src/collectives/device/all_gather.h | 72
-rw-r--r-- | src/collectives/device/all_reduce.h | 183
-rw-r--r-- | src/collectives/device/broadcast.h | 52
-rw-r--r-- | src/collectives/device/common.h | 31
-rw-r--r-- | src/collectives/device/common_kernel.h | 2
-rw-r--r-- | src/collectives/device/functions.cu | 15
-rw-r--r-- | src/collectives/device/op128.h | 36
-rw-r--r-- | src/collectives/device/primitives.h | 481
-rw-r--r-- | src/collectives/device/prims_ll.h | 259
-rw-r--r-- | src/collectives/device/prims_ll128.h | 410
-rw-r--r-- | src/collectives/device/reduce.h | 49
-rw-r--r-- | src/collectives/device/reduce_scatter.h | 67
13 files changed, 1269 insertions, 390 deletions
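The *LL128 kernels added in the diff below all derive their per-loop chunk size from the LL128 line geometry. A minimal sketch of that arithmetic, assuming the constant values from NCCL's device headers (128-byte lines of 16 uint64_t words, 15 of which carry data and 1 the flag; those defines are not part of this diff):

#include <stdint.h>
#include <stddef.h>

#define LL128_LINEELEMS 16          /* uint64_t words per 128-byte line (assumed value) */
#define LL128_DATAELEMS 15          /* data words per line; the last word holds the flag */
#define LL128_ELEMS_PER_THREAD 120  /* words a thread owns per slice (assumed value) */

/* Mirrors chunkSize = (ELEMS_PER_THREAD*nthreads*DATAELEMS*sizeof(uint64_t))
 *                     / (LINEELEMS*sizeof(T))
 * from the new ncclAllGatherRingLL128Kernel / ncclAllReduceRingLL128Kernel. */
static size_t ll128ChunkElems(int nthreads, size_t sizeofT) {
  return ((size_t)LL128_ELEMS_PER_THREAD * nthreads * LL128_DATAELEMS * sizeof(uint64_t))
         / ((size_t)LL128_LINEELEMS * sizeofT);
}
/* With 15 of every 16 words carrying payload, LL128 keeps roughly 94% of the
 * link bandwidth, versus 50% for the 4B-data/4B-flag LL protocol. */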
diff --git a/src/collectives/device/Makefile b/src/collectives/device/Makefile index 0ee587b..001059c 100644 --- a/src/collectives/device/Makefile +++ b/src/collectives/device/Makefile @@ -68,4 +68,4 @@ $(DEVOBJ) : $(LIBOBJ) $(NVCC) $(NVCUFLAGS) -dlink $^ -o $@ clean: - rm -f $(LIBOBJ) $(DEVOBJ) $(DEPFILES) $(DEPENDFILES) $(STATICLIB) test + rm -f $(LIBOBJ) $(DEVOBJ) $(DEPFILES) $(DEPENDFILES) $(RULESFILE) $(STATICLIB) diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 8e78730..0ad5ba9 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -11,7 +11,7 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; @@ -19,15 +19,15 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) { const ssize_t size = args->N; const int nranks = comm->nRanks; const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS; + const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS; const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize; // Compute pointers const T * __restrict__ thisInput = (const T*)args->ThisInput; T * __restrict__ thisOutput = (T*)args->ThisOutput; - ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, FUNC> + prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels)); @@ -129,3 +129,67 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { } + +#include "prims_ll128.h" +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nthreads = args->nThreads; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclRing* ring = &channel->ring; + + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; + + const ssize_t loopSize = args->nChannels*chunkSize; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize); + + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin AllGather steps /////////////// + ssize_t offset; + int nelem = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[0]; + offset = chunkOffset + rankDest * size; + + if (thisInput + chunkOffset == thisOutput + offset) { // In place + LLprims.send(thisInput+chunkOffset, nelem); + } else { + LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem); + } + + // k-2 steps: copy to next GPU + for (int j=1; j<nranks-1; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + LLprims.recvCopySend(thisOutput+offset, nelem); + } + + // step k-1: final store + rankDest = ring->devUserRanks[1]; + offset = chunkOffset + rankDest * size; + + LLprims.recv(thisOutput+offset, nelem); + } +} + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllGatherTreeLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 9b058cc..2449c2b 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -11,7 +11,7 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; @@ -27,7 +27,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { T * __restrict__ thisOutput = (T*)args->ThisOutput; ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); + prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*args->nChannels)); @@ -85,23 +85,28 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { template<int UNROLL, class FUNC, typename T> __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclTree* tree = &channel->tree; const ssize_t size = args->N; const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS); - const int chunkSize = args->lastChunkSize; + int chunkSize = args->lastChunkSize; + const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T); const ssize_t loopSize = args->nChannels*chunkSize; + if (loopSize > size) { + chunkSize = 
DIVUP(size, args->nChannels*minChunkSize)*minChunkSize; + } + // Compute pointers const T * __restrict__ thisInput = (const T*)args->ThisInput; T * __restrict__ thisOutput = (T*)args->ThisOutput; do { + struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -117,8 +122,9 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { } while(0); do { + struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize; @@ -149,6 +155,8 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { //const int rank = comm->rank; const int nranks = comm->nRanks; ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T); + const ssize_t loopSize = args->nChannels*nranks*chunkSize; // Compute pointers @@ -156,10 +164,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { T * __restrict__ thisOutput = (T*)args->ThisOutput; for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - if (size-gridOffset < loopSize) { - chunkSize = args->lastChunkSize; - } - ssize_t chunkOffset = gridOffset + bid*nranks*chunkSize; + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize); /////////////// begin AllReduce steps /////////////// ssize_t offset; @@ -168,7 +173,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { // step 0: push data to next GPU slice = ring->devUserRanks[nranks-1]; - offset = chunkOffset + slice * chunkSize; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; nelem = min(chunkSize, size-offset); LLprims.send(thisInput+offset, nelem); @@ -176,7 +181,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { // k-2 steps: reduce and copy to next GPU for (int j=2; j<nranks; ++j) { slice = ring->devUserRanks[nranks-j]; - offset = chunkOffset + slice * chunkSize; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; nelem = min(chunkSize, size-offset); LLprims.recvReduceSend(thisInput+offset, nelem); @@ -185,7 +190,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { // step k-1: reduce this buffer and data, which will produce the final // result that we store in this data and push to the next GPU slice = ring->devUserRanks[0]; - offset = chunkOffset + slice * chunkSize; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; nelem = min(chunkSize, size-offset); 
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); @@ -193,7 +198,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { // k-2 steps: copy to next GPU for (int j=1; j<nranks-1; ++j) { slice = ring->devUserRanks[nranks-j]; - offset = chunkOffset + slice * chunkSize; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; nelem = min(chunkSize, size-offset); LLprims.recvCopySend(thisOutput+offset, nelem); @@ -201,7 +206,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { // Make final copy from buffer to dest. slice = ring->devUserRanks[1]; - offset = chunkOffset + slice * chunkSize; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; nelem = min(chunkSize, size-offset); // Here we need to copy from buffer to this output. @@ -216,16 +221,21 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; - struct ncclTree* tree = &channel->tree; const ssize_t size = args->N; ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T); const ssize_t loopSize = args->nChannels*chunkSize; + if (loopSize > size) { + chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize; + } + // Compute pointers const T * __restrict__ thisInput = (const T*)args->ThisInput; T * __restrict__ thisOutput = (T*)args->ThisOutput; do { + struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { @@ -243,6 +253,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { } while(0); do { + struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { @@ -259,3 +270,141 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { } } while(0); } + +#include "prims_ll128.h" +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nthreads = args->nThreads; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclRing* ring = &channel->ring; + + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; + + const ssize_t loopSize = args->nChannels*nranks*chunkSize; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize); + + /////////////// begin AllReduce steps /////////////// + ssize_t offset; + int nelem; + int slice; + + // step 0: push data to next GPU + slice = ring->devUserRanks[nranks-1]; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; + nelem = min(chunkSize, size-offset); + + LLprims.send(thisInput+offset, nelem); + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + slice = ring->devUserRanks[nranks-j]; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; + nelem = min(chunkSize, size-offset); + + LLprims.recvReduceSend(thisInput+offset, nelem); + } + + // step k-1: reduce this buffer and data, which will produce the final + // result that we store in this data and push to the next GPU + slice = ring->devUserRanks[0]; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; + nelem = min(chunkSize, size-offset); + + LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); + + // k-2 steps: copy to next GPU + for (int j=1; j<nranks-1; ++j) { + slice = ring->devUserRanks[nranks-j]; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; + nelem = min(chunkSize, size-offset); + + LLprims.recvCopySend(thisOutput+offset, nelem); + } + + // Make final copy from buffer to dest. + slice = ring->devUserRanks[1]; + offset = gridOffset + (slice*args->nChannels+bid) * chunkSize; + nelem = min(chunkSize, size-offset); + + // Here we need to copy from buffer to this output. 
+ LLprims.recv(thisOutput+offset, nelem); + } +} + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->bid; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclTree* treeUp = &channel->treeUp; + struct ncclTree* treeDn = &channel->treeDn; + const ssize_t size = args->N; + ssize_t chunkSize = args->lastChunkSize; + const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8; + const ssize_t loopSize = args->nChannels*chunkSize; + int nthreadsSplit = NCCL_LL128_SPLIT(nthreads); + + if (loopSize > size) { + chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize; + } + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + if (treeUp->up == -1) { + // ReduceAndBroadcast : max number of recv is 3, max number of send is 3 + ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem); + } + } else { + if (tid < nthreadsSplit) { + // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) + ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Up + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (treeUp->down[0] == -1) { + LLprims.send(thisInput+offset, nelem); + } else { + LLprims.recvReduceSend(thisInput+offset, nelem); + } + } + } else { + // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) + ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Down + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (treeDn->down[0] == -1) { + LLprims.recv(thisOutput+offset, nelem); + } else { + LLprims.recvCopySend(thisOutput+offset, nelem); + } + } + } + } +} diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index ae8667f..de8b989 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -11,7 +11,7 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; @@ -29,7 +29,7 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) { T * __restrict__ thisOutput = (T*)args->ThisOutput; ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, nthreads, 
&ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); + prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels)); @@ -100,3 +100,51 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { } + +#include "prims_ll128.h" +template<int UNUSED, class FUNC, typename T> +__device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nthreads = args->nThreads; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclRing* ring = &channel->ring; + + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); + + const ssize_t size = args->N; + const int rank = ring->devUserRanks[0]; + const int nextRank = ring->devUserRanks[1]; + const int root = args->root; + + ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + + const ssize_t loopSize = args->nChannels*chunkSize; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize); + ssize_t offset = gridOffset + bid*chunkSize; + + int nelem = min(chunkSize, size-offset); + if (rank == root) { + if (thisInput == thisOutput) { + LLprims.send(thisInput+offset, nelem); + } else { + LLprims.copySend(thisInput + offset, thisOutput + offset, nelem); + } + } else if (nextRank == root) { + LLprims.recv(thisOutput + offset, nelem); + } else { + LLprims.recvCopySend(thisOutput + offset, nelem); + } + } +} + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclBroadcastTreeLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 8c336bf..46eb9f5 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -7,9 +7,8 @@ #ifndef NCCL_DEVICE_COMMON_H_ #define NCCL_DEVICE_COMMON_H_ -#include "../collectives.h" +#include "collectives.h" #include "devcomm.h" -#include "nccl.h" // Exit If Abort Barrier across CTA: make sure all threads exit consistently // Each thread sets a predicate to true if abort == 1 @@ -31,17 +30,19 @@ extern __device__ ncclKern_t ncclFuncs[]; static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) { int* d = (int*)dst; int* s = (int*)src; - // When aggregation is effective, if some threads have aborted inside the LL kernel, - // make sure the rest of the threads abort as well - exitIfAbortBarrier(0); for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o]; - __syncthreads(); } -static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid) { +static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid, struct 
ncclDevComm* comm) { + // Check whether the last operation was aborted and make sure all threads exit + int abort = tid == 0 ? *(comm->abortFlag) : 0; + exitIfAbortBarrier(abort); load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid); + __syncthreads(); if (tid == 0) hostColl->active = 0; } +extern __device__ volatile uint64_t* ncclShmem; + /* Functions for aggregation case */ #define IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \ __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \ @@ -51,10 +52,11 @@ __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \ #if NCCL_OP == 0 /* Kernels with the first operation inlined */ #define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \ -__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ int tid = threadIdx.x; \ int bid = blockIdx.x; \ + __shared__ volatile uint64_t shmem[NCCL_LL128_SHMEM_SIZE]; \ + ncclShmem = shmem; \ __shared__ struct ncclColl localColl; \ \ struct ncclDevComm* comm = firstColl.args.comm; \ @@ -65,7 +67,7 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ c = &firstColl; \ } else { \ c = &localColl; \ - load_coll(c, channel->devCollectives+channel->collFifoHead, tid); \ + load_coll(c, channel->devCollectives+channel->collFifoHead, tid, comm); \ } \ while (1) { \ if (tid < c->args.nThreads) { \ @@ -84,7 +86,7 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ \ /* Load next collective operation*/ \ c = &localColl; /* for bid 0 */ \ - load_coll(c, channel->devCollectives+nextIndex, tid); \ + load_coll(c, channel->devCollectives+nextIndex, tid, comm); \ } \ } #else @@ -93,13 +95,14 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ // Only generate inline kernels for LL #define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \ - IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \ IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \ - IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \ + IMPL_COLL_FUNC(coll##LL128, op, ncclFunc, dtype, ctype) \ + IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \ + IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, al, NCCL_PROTO_LL)) \ #define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \ - IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \ - IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1) + IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \ + IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING) #if NCCL_TYPE == 0 #define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \ diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index 435a598..aa1e936 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -263,8 +263,6 @@ __device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthread } } -#define WARP_SIZE 32 - template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS> __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t, int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS], diff --git 
a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 010c454..034fe96 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -8,13 +8,16 @@ #include "collectives.h" #include "common.h" +__device__ volatile uint64_t* ncclShmem; + #define NCCL_FUNC5(coll, op, dtype) \ - NCCL_COLL_NAME(coll, op, dtype), \ - NCCL_COLL_NAME(coll##LL, op, dtype) + NCCL_COLL_NAME(coll##LL, op, dtype), \ + NCCL_COLL_NAME(coll##LL128, op, dtype), \ + NCCL_COLL_NAME(coll, op, dtype) #define NCCL_FUNC4(coll, op, dtype) \ - NCCL_FUNC5(coll##Ring, op, dtype), \ - NCCL_FUNC5(coll##Tree, op, dtype) + NCCL_FUNC5(coll##Tree, op, dtype), \ + NCCL_FUNC5(coll##Ring, op, dtype) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(coll, op) \ @@ -50,7 +53,7 @@ NCCL_FUNCS3B(coll, copy), \ NCCL_FUNCS3B(coll, copy) -// Must be consistent with ncclColl_t +// Must be consistent with ncclFunc_t #define NCCL_FUNCS() { \ NCCL_FUNCS2B(ncclBroadcast), \ NCCL_FUNCS2A(ncclReduce), \ @@ -59,7 +62,7 @@ NCCL_FUNCS2A(ncclAllReduce) } // Must be consistent with the ncclFuncSet enum -__device__ ncclKern_t ncclFuncs[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = { +__device__ ncclKern_t ncclFuncs[NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { // Don't try to initialize the host shadow copy of this device-side global // variable. There is no host pointer to a device-side function, which // confuses clang. This will be fixed in the next clang release. diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h new file mode 100644 index 0000000..9405dc2 --- /dev/null +++ b/src/collectives/device/op128.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef OP128_H_ +#define OP128_H_ + +inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { + asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" + : "=l"(v0), "=l"(v1) : "l"(ptr)); +} + +inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) { + asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" + :: "l"(v0), "l"(v1), "l"(ptr)); +} + +inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) { + uint64_t* shmemAsmPtr; + asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(shmemAsmPtr) : "l"(shmemGenericPtr)); + return shmemAsmPtr; +} + +inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) { + asm volatile("ld.volatile.shared.v2.u64 {%0,%1}, [%2];" + : "=l"(v0), "=l"(v1) : "l"(shmemAsmPtr)); +} + +inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) { + asm volatile("st.volatile.shared.v2.u64 [%2], {%0,%1};" + :: "l"(v0), "l"(v1), "l"(shmemAsmPtr)); +} + +#endif diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 7beeaf4..aa3d20d 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -37,15 +37,27 @@ class ncclPrimitives { private: const int tid; const int nthreads; + const int wid; + const int stepSize; int nrecv = 0; int nsend = 0; - const int stepSize; - struct ncclConnInfo* recvConn[NRECV]; - struct ncclConnInfo* sendConn[NSEND]; - volatile uint64_t* waitPtr; + struct ncclConnInfo* recvConn = NULL; + volatile uint64_t* recvConnHeadPtr = NULL; + uint64_t recvConnHead; + volatile uint64_t* recvConnTailPtr = NULL; + uint64_t recvConnTail; + uint64_t recvConnTailCache; // Cache last seen value + + struct ncclConnInfo* sendConn = NULL; + volatile int* sendConnFifoPtr = NULL; + volatile uint64_t* sendConnTailPtr = NULL; + uint64_t sendConnTail; + volatile uint64_t* sendConnHeadPtr = NULL; + uint64_t sendConnHead; + uint64_t sendConnHeadCache; // Cache last seen value + uint64_t recvStep[NRECV]; uint64_t sendStep[NSEND]; - uint64_t sendConnHead[NSEND]; const T* recvDirectBuff[NRECV]; T* sendDirectBuff[NSEND]; const T* recvBuff[NRECV]; @@ -60,15 +72,18 @@ class ncclPrimitives { inline __device__ void barrier() { asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); } + inline __device__ void subBarrier() { + asm volatile ("bar.sync 2, %0;" :: "r"(nthreads-WARP_SIZE)); + } uint32_t mismatch = 0; const uint64_t opCount; - inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) { + inline __device__ void checkMismatch(struct ncclConnInfo* conn) { if (mismatch) { // In non-LL, we use _threadfence_system before incrementing opCount, yet we are still waiting for credits here, so there must be a size mismatch *(comm->fatalDevError) = ncclDevAssertedMismatch; - } else if (remoteOpCount && *remoteOpCount > opCount) { + } else if (conn && *conn->opCountRem > opCount) { mismatch += 1; } } @@ -76,49 +91,55 @@ class ncclPrimitives { uint32_t spins = 0; uint32_t abort = 0; - inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) { + inline __device__ int checkAbort(int i, int send) { spins++; - if (spins == SPINS_BEFORE_CHECK_ABORT) { + if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { abort = *(comm->abortFlag); - checkMismatch(remoteOpCount); + if (wid == i) checkMismatch(send ? 
sendConn : recvConn); spins = 0; } return abort; } - inline __device__ void waitRecv(int i) { + inline __device__ void waitSend(int nbytes) { spins = 0; mismatch = 0; - recvStep[i] += SLICESTEPS; - if (tid == i) { - while (*(waitPtr) < recvStep[i]) { - if (checkAbort(recvConn[i]->opCountRem)) break; + if (sendConnHeadPtr) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) { + sendConnHeadCache = *sendConnHeadPtr; + if (checkAbort(wid, 1)) break; + } + if (sendConnFifoPtr) { + sendConnFifoPtr[sendConnHead%NCCL_STEPS] = nbytes; } + sendConnHead += SLICESTEPS; } } - inline __device__ void waitSend(int i) { + inline __device__ void waitRecv() { spins = 0; mismatch = 0; - sendStep[i] += SLICESTEPS; - if (tid == WARP_SIZE+i) { - while (sendConnHead[i] + NCCL_STEPS < sendStep[i]) { - sendConnHead[i] = *waitPtr; - if (checkAbort(sendConn[i]->opCountRem)) break; + if (recvConnTailPtr) { + while (recvConnTailCache < recvConnTail + SLICESTEPS) { + recvConnTailCache = *recvConnTailPtr; + if (checkAbort(wid, 0)) break; } + recvConnTail += SLICESTEPS; } } - inline __device__ void postRecv(int i) { - *(recvConn[i]->head) = recvStep[i] += SLICESTEPS; + inline __device__ void incRecv(int i) { + recvStep[i] += SLICESTEPS; } - - inline __device__ void postSend(int i) { - *(sendConn[i]->tail) = sendStep[i] += SLICESTEPS; + inline __device__ void postRecv() { + if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += SLICESTEPS; } - inline __device__ void postSendSize(int i, int size) { - if (sendConn[i]->fifo) sendConn[i]->fifo[sendStep[i]%NCCL_STEPS] = size; + inline __device__ void incSend(int i) { + sendStep[i] += SLICESTEPS; + } + inline __device__ void postSend() { + if (sendConnTailPtr) *sendConnTailPtr = sendConnTail += SLICESTEPS; } template <int DIRECTRECV> @@ -131,11 +152,22 @@ class ncclPrimitives { return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i); } + template <int DIRECTRECV> + inline __device__ int directRecvInc(int i, int directInc, int sliceInc) { + return DIRECTRECV && recvDirectBuff[i] ? directInc : sliceInc; + } + + template <int DIRECTSEND> + inline __device__ int directSendInc(int i, int directInc, int sliceInc) { + return DIRECTSEND && sendDirectBuff[i] ? directInc : sliceInc; + } + template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST> inline __device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem, int directOffset) { int offset = 0; - int sliceSize = stepSize * SLICESTEPS; + int sliceSize = stepSize*SLICESTEPS; + int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32); const T* srcs[RECV*NRECV+SRC]; srcs[0] = SRC ? srcPtr : directRecvPtr<DIRECTRECV>(0, directOffset); @@ -151,101 +183,126 @@ class ncclPrimitives { for (int i=1; i<NSEND && i<nsend; i++) dsts[DST+i] = directSendPtr<DIRECTSEND>(i, directOffset); } - #pragma unroll 1 + bool syncThread = tid >= nthreads-WARP_SIZE; + + #pragma unroll for (int slice=0; slice<SLICESPERCHUNK; ++slice) { - int realSize = max(0, min(sliceSize, nelem-offset)); - if (tid < nthreads) { - FOR_SEND(waitSend); - FOR_RECV(waitRecv); + int realSize = max(0, min(dataSize, nelem-offset)); + if (!syncThread) { + if (SEND) waitSend(realSize*sizeof(T)); + if (RECV) waitRecv(); if (realSize > 0) { - barrier(); + subBarrier(); if (DIRECTRECV && recvDirectBuff[0]) { // We can only have one direct receive. 
Since srcs[0] == dstPtr+offset, skip one copy if (SEND) { - ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize); + ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads-WARP_SIZE, 1, srcs, nsend, dsts+1, realSize); } } else { - ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize); + ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads-WARP_SIZE, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize); } } - exitIfAbortBarrier(abort); - } else { - exitIfAbortBarrier(abort); - FOR_SEND(postSendSize, realSize*sizeof(T)); - if (SEND) __threadfence_system(); - FOR_SEND(postSend); - FOR_RECV(postRecv); } - for (int i=0; i<RECV*NRECV+SRC; i++) srcs[i] += sliceSize; - for (int i=0; i<SEND*NSEND+DST; i++) dsts[i] += sliceSize; - offset += sliceSize; + barrier(); + FOR_SEND(incSend); + FOR_RECV(incRecv); + if (syncThread) { + if (SEND) { + if (realSize > 0 && wid == 0) __threadfence_system(); + __syncwarp(); + postSend(); + } + if (RECV) postRecv(); + } + srcs[0] += SRC ? realSize : directRecvInc<DIRECTRECV>(0, realSize, sliceSize); + for (int i=1-SRC; i<RECV*NRECV; i++) srcs[SRC+i] += sliceSize; + dsts[0] += DST ? realSize : directSendInc<DIRECTSEND>(0, realSize, sliceSize); + for (int i=1-DST; i<SEND*NSEND; i++) dsts[DST+i] += directSendInc<DIRECTSEND>(i, realSize, sliceSize); + offset += realSize; } } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) { - recvConn[i] = conn; - recvBuff[i] = (const T*)recvConn[i]->buff; - recvStep[i] = recvConn[i]->step; + recvBuff[i] = (const T*)conn->buff; + recvStep[i] = conn->step; recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS); - // Return credits in case we rounded up. - if (tid == nthreads) *recvConn[i]->head = recvStep[i]; - if (tid == i) { - waitPtr = recvConn[i]->tail; - *(recvConn[i]->opCountLoc) = opCount; - } recvDirectBuff[i] = NULL; - if (directBuff && recvConn[i]->direct) { + if (directBuff && conn->direct) { recvDirectBuff[i] = directBuff; - if (tid == 0) *recvConn[i]->ptrExchange = directBuff; + if (tid == 0) *conn->ptrExchange = directBuff; } + if (wid == i) recvConn = conn; + if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up nrecv++; } + __device__ __forceinline__ void loadRecvSync() { + if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && wid<nrecv) { + recvConnTailPtr = recvConn->tail; + recvConnTailCache = *recvConnTailPtr; + } + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConnHeadPtr = recvConn->head; + // Return credits in case we rounded up. 
+ *recvConnHeadPtr = recvConnHead; + // Update opCount in case we skipped some operations + *(recvConn->opCountLoc) = opCount; + } + } __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i, T* directBuff) { - sendConn[i] = conn; - sendBuff[i] = (T*)sendConn[i]->buff; - sendStep[i] = sendConn[i]->step; + sendBuff[i] = (T*)conn->buff; + sendStep[i] = conn->step; sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS); - if (tid == WARP_SIZE+i) { - waitPtr = sendConn[i]->head; - sendConnHead[i] = *waitPtr; - *(sendConn[i]->opCountLoc) = opCount; - } sendDirectBuff[i] = NULL; - if (directBuff && sendConn[i]->direct) { - void* volatile* ptr = sendConn[i]->ptrExchange; + if (directBuff && conn->direct) { + void* volatile* ptr = conn->ptrExchange; while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL); - __syncthreads(); + barrier(); if (tid == 0) *ptr = NULL; } + if (wid == i) sendConn = conn; + if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up nsend++; } + __device__ __forceinline__ void loadSendSync() { + if (tid < nsend) { + sendConnHeadPtr = sendConn->head; + sendConnHeadCache = *sendConnHeadPtr; + sendConnFifoPtr = sendConn->fifo; + *(sendConn->opCountLoc) = opCount; + } + if (tid >= nthreads-WARP_SIZE && wid<nsend) { + sendConnTailPtr = sendConn->tail; + } + } - __device__ __forceinline__ void saveRecvConn(int i) { - if (tid == i) { - recvConn[i]->step = recvStep[i]; + __device__ __forceinline__ void saveRecvSync() { + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConn->step = recvConnHead; + *(recvConn->opCountLoc) = opCount+1; __threadfence_system(); - *(recvConn[i]->opCountLoc) += 1; } } - __device__ __forceinline__ void saveSendConn(int i) { - if (tid == WARP_SIZE+i) { - sendConn[i]->step = sendStep[i]; + __device__ __forceinline__ void saveSendSync() { + if (tid < nsend) { + sendConn->step = sendConnHead; + *(sendConn->opCountLoc) = opCount+1; __threadfence_system(); - *(sendConn[i]->opCountLoc) += 1; } } public: __device__ __forceinline__ ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) - : comm(comm), tid(tid), nthreads(nthreads), stepSize(stepSize), opCount(opCount) { - // Make sure step is updated before we read it - __syncthreads(); + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) { + // Make sure step is updated before we read it. + barrier(); for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff); for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i, directBuff); + loadRecvSync(); + loadSendSync(); } __device__ __forceinline__ void @@ -305,267 +362,13 @@ class ncclPrimitives { } __device__ __forceinline__ ~ncclPrimitives() { - // Save steps for next collective. Have thread 0 do it to be compatible - // with the way LL works. 
- for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i); - for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i); + // Save steps for the next operation + saveRecvSync(); + saveSendSync(); } }; -template <typename T, class FUNC, int NRECV, int NSEND> -class ncclLLPrimitives { - private: - const int tid; - const int nthreads; - int nrecv = 0; - int nsend = 0; - struct ncclConnInfo* recvConn[NRECV]; - struct ncclConnInfo* sendConn[NSEND]; - volatile uint64_t* waitPtr; - volatile uint64_t* postPtr; - volatile int* fifoPtr; - uint64_t recvStep[NRECV]; - uint64_t sendStep[NSEND]; - uint64_t sendConnHead; - union ncclLLFifoLine* recvBuff[NRECV]; - union ncclLLFifoLine* sendBuff[NSEND]; - struct ncclDevComm* comm; - - inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } - inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } - inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } - inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } - inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); } - inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); } - - // Exit If Abort Barrier : make sure all threads exit consistently - // Each thread sets a predicate to true if val == 1 - // all CTA's threads enter the barrier and do a popc on their predicates being True - // If any of the thread's predicate was True, all the threads call exit() - inline __device__ void exitIfAbortLocalBarrier() { - uint32_t popc; - asm ("{"); - asm volatile (" .reg .pred barr_pred;"); - asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort)); - asm volatile (" bar.red.popc.u32 %0, 14, %1, barr_pred;" : "=r"(popc) : "r"(nthreads)); - asm ("}"); - if (popc) { - // Make sure threads not participating in the operation get the abort and all threads exit - exitIfAbortBarrier(1); - } - } - - inline __device__ void barrier() { - asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); - } - - uint32_t mismatch = 0; - const uint64_t opCount; +#include "prims_ll.h" +//#include "prims_ll128.h" - inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) { - if (mismatch > 20) { - // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch - // Note that we are not using _threadfence_system in LL so the error cannot be asserted - *(comm->fatalDevError) = ncclDevSuspectedMismatch; - } else if (remoteOpCount && *remoteOpCount > opCount) { - mismatch += 1; - } - } - - uint32_t spins = 0; - uint32_t abort = 0; - - inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) { - spins++; - if (spins == SPINS_BEFORE_CHECK_ABORT) { - abort = *(comm->abortFlag); - checkMismatch(remoteOpCount); - spins = 0; - } - return abort; - } - - inline __device__ void waitSend(int i, int nbytes) { - spins = 0; - mismatch = 0; - if (tid == WARP_SIZE+i) { - while (sendConnHead + NCCL_STEPS < sendStep[i] + 1) { - sendConnHead = *waitPtr; - if (checkAbort(sendConn[i]->opCountRem)) break; - } - if (fifoPtr) { - int size = ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? 
NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes; - fifoPtr[sendStep[i]%NCCL_STEPS] = size; - } - } - } - - inline __device__ void postRecv(int i) { - recvStep[i]++; - if (tid == i) *postPtr = recvStep[i]; - } - - inline __device__ void postSend(int i, int offset) { - // LL Cleanup : write all flags in the slice to make sure we don't have - // data corruption when flag loops over. - if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { - for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i)); - } - sendStep[i]++; - } - - __device__ uint64_t readLL(int i, int offset) { - union ncclLLFifoLine* src = recvPtr(i) + offset; - uint32_t flag = recvFlag(i); - uint32_t data1, flag1, data2, flag2; - spins = 0; - mismatch = 0; - do { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); - if (checkAbort(recvConn[i]->opCountRem)) break; - } while ((flag1 != flag) || (flag2 != flag)); - uint64_t val64 = data1 + (((uint64_t)data2) << 32); - return val64; - } - - __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) { - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag)); - } - - // Using memcpy handles misaligned pointers. - __device__ uint64_t readAL(uint64_t* src) { - uint64_t val; - memcpy((char*)&val, (char*)src, sizeof(uint64_t)); - return val; - } - - __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) { - memcpy((char*)dst, (char*)&val, nbytes); - } - - template <int RECV, int SEND, int SRC, int DST> - __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) { - uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T); - FOR_SEND(waitSend, nbytes*2); - barrier(); - uint32_t npack = DIVUP(nbytes, sizeof(uint64_t)); - uint64_t* srcPack = (uint64_t*)srcPtr; - uint64_t* dstPack = (uint64_t*)dstPtr; - int offset = tid; - // Do multiples of 64 bits - #pragma unroll 2 - for (; offset<npack; offset+=nthreads) { - // Recv : local, then intra-node, then inter-node - uint64_t val = SRC ? 
readAL(srcPack+offset) : readLL(0, offset); - if (RECV) { - if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val); - for (int i=1; i<NRECV && i<nrecv; i++) { - val = MULTI<FUNC, T>()(readLL(i, offset), val); - } - } - - // Send : inter-node, then intra-node, then local - if (SEND) { - for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i)); - storeLL(sendPtr(0)+offset, val, sendFlag(0)); - } - if (DST) { - if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) { - // Last incomplete word - storeAL(dstPack+offset, val, nbytes & 0x7); - } else { - storeAL(dstPack+offset, val, sizeof(uint64_t)); - } - } - } - exitIfAbortLocalBarrier(); - FOR_RECV(postRecv); - FOR_SEND(postSend, offset); - } - - __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { - recvConn[i] = conn; - recvBuff[i] = recvConn[i]->llBuff; - recvStep[i] = recvConn[i]->step; - if (tid == i) { - postPtr = recvConn[i]->head; - *(recvConn[i]->opCountLoc) = opCount; - } - nrecv++; - } - - __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { - sendConn[i] = conn; - sendBuff[i] = sendConn[i]->llBuff; - sendStep[i] = sendConn[i]->step; - if (tid == WARP_SIZE+i) { - waitPtr = sendConn[i]->head; - fifoPtr = sendConn[i]->fifo; - sendConnHead = *waitPtr; - *(sendConn[i]->opCountLoc) = opCount; - } - nsend++; - } - - __device__ __forceinline__ void saveRecvConn(int i) { - if (tid == i) { - recvConn[i]->step = recvStep[i]; - *(recvConn[i]->opCountLoc) += 1; - __threadfence_block(); - } - } - - __device__ __forceinline__ void saveSendConn(int i) { - if (tid == WARP_SIZE+i) { - sendConn[i]->step = sendStep[i]; - *(sendConn[i]->opCountLoc) += 1; - __threadfence_block(); - } - } - - public: - __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) - : comm(comm), tid(tid), nthreads(nthreads), opCount(opCount) { - // Make sure step is updated before we read it. 
- barrier(); - - for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i); - for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i); - } - - __device__ void send(const T* src, int nelem) { - return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recv(T* dst, int nelem) { - return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceSend(const T* src, int nelem) { - return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem); - } - - __device__ void recvReduceCopy(const T* src, T* dst, int nelem) { - return LLGenericOp<1, 0, 1, 1>(src, dst, nelem); - } - - __device__ void copySend(const T* src, T* dst, int nelem) { - return LLGenericOp<0, 1, 1, 1>(src, dst, nelem); - } - - __device__ void recvCopySend(T* dst, int nelem) { - return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem); - } - - __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) { - return LLGenericOp<1, 1, 1, 1>(src, dst, nelem); - } - - __device__ __forceinline__ ~ncclLLPrimitives() { - // Save steps for the next operation - for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i); - for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i); - } -}; #endif diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h new file mode 100644 index 0000000..f919493 --- /dev/null +++ b/src/collectives/device/prims_ll.h @@ -0,0 +1,259 @@ +template <typename T, class FUNC, int NRECV, int NSEND> +class ncclLLPrimitives { + private: + const int tid; + const int nthreads; + const int wid; + int nrecv = 0; + int nsend = 0; + struct ncclConnInfo* recvConn = NULL; + volatile uint64_t* recvConnHeadPtr = NULL; + uint64_t recvConnHead; + + struct ncclConnInfo* sendConn = NULL; + volatile int* sendConnFifoPtr = NULL; + volatile uint64_t* sendConnHeadPtr = NULL; + uint64_t sendConnHead; + uint64_t sendConnHeadCache; // Cache last seen value + + uint64_t recvStep[NRECV]; + uint64_t sendStep[NSEND]; + union ncclLLFifoLine* recvBuff[NRECV]; + union ncclLLFifoLine* sendBuff[NSEND]; + struct ncclDevComm* comm; + + inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } + inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; } + inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } + inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } + inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); } + inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); } + + inline __device__ void barrier() { + asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); + } + + uint32_t mismatch = 0; + const uint64_t opCount; + + inline __device__ void checkMismatch(struct ncclConnInfo* conn) { + if (mismatch > 20) { + // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch + // Note that we are not using _threadfence_system in LL so the error cannot be asserted + *(comm->fatalDevError) = ncclDevSuspectedMismatch; + } else if (conn && *conn->opCountRem > opCount) { + mismatch += 1; + } + } + + uint32_t spins = 0; + uint32_t abort = 0; + + inline __device__ int checkAbort(int i, int send) { + spins++; + if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { + abort = *(comm->abortFlag); + if (wid == i) checkMismatch(send ? 
sendConn : recvConn); + spins = 0; + } + return abort; + } + + inline __device__ void waitSend(int nbytes) { + spins = 0; + mismatch = 0; + if (sendConnHeadPtr) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { + sendConnHeadCache = *sendConnHeadPtr; + if (checkAbort(wid, 1)) break; + } + if (sendConnFifoPtr) { + int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes; + sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size; + } + sendConnHead += 1; + } + barrier(); + } + + inline __device__ void incRecv(int i) { + recvStep[i] += 1; + } + inline __device__ void postRecv() { + barrier(); + if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; + } + + inline __device__ void incSend(int i, int offset) { + // LL Cleanup : write all flags in the slice to make sure we don't have + // data corruption when flag loops over. + if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { + for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i)); + } + sendStep[i]++; + } + + __device__ uint64_t readLL(int i, int offset) { + union ncclLLFifoLine* src = recvPtr(i) + offset; + uint32_t flag = recvFlag(i); + uint32_t data1, flag1, data2, flag2; + spins = 0; + mismatch = 0; + do { + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); + if (checkAbort(i, 0)) break; + } while ((flag1 != flag) || (flag2 != flag)); + uint64_t val64 = data1 + (((uint64_t)data2) << 32); + return val64; + } + + __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) { + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag)); + } + + // Using memcpy handles misaligned pointers. + __device__ uint64_t readAL(uint64_t* src) { + uint64_t val; + memcpy((char*)&val, (char*)src, sizeof(uint64_t)); + return val; + } + + __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) { + memcpy((char*)dst, (char*)&val, nbytes); + } + + template <int RECV, int SEND, int SRC, int DST> + __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) { + uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T); + uint32_t npack = DIVUP(nbytes, sizeof(uint64_t)); + uint64_t* srcPack = (uint64_t*)srcPtr; + uint64_t* dstPack = (uint64_t*)dstPtr; + int offset = tid; + + // Always waitSend in case of cleanup + if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine)); + + // Do multiples of 64 bits + #pragma unroll 2 + for (; offset<npack; offset+=nthreads) { + // Recv : local, then intra-node, then inter-node + uint64_t val = SRC ? 
readAL(srcPack+offset) : readLL(0, offset); + if (RECV) { + if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val); + for (int i=1; i<NRECV && i<nrecv; i++) { + val = MULTI<FUNC, T>()(readLL(i, offset), val); + } + } + + // Send : inter-node, then intra-node, then local + if (SEND) { + for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i)); + storeLL(sendPtr(0)+offset, val, sendFlag(0)); + } + if (DST) { + if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) { + // Last incomplete word + storeAL(dstPack+offset, val, nbytes & 0x7); + } else { + storeAL(dstPack+offset, val, sizeof(uint64_t)); + } + } + } + FOR_RECV(incRecv); if (RECV) postRecv(); + FOR_SEND(incSend, offset); + } + + __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { + recvBuff[i] = conn->llBuff; + recvStep[i] = conn->step; + if (wid == i) recvConn = conn; + nrecv++; + } + __device__ __forceinline__ void loadRecvSync() { + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConnHeadPtr = recvConn->head; + recvConnHead = recvConn->step; + // Update opCount in case we skipped some operations + *(recvConn->opCountLoc) = opCount; + } + } + + __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { + sendBuff[i] = conn->llBuff; + sendStep[i] = conn->step; + if (wid == i) sendConn = conn; + nsend++; + } + __device__ __forceinline__ void loadSendSync() { + if (tid < nsend) { + sendConnHeadPtr = sendConn->head; + sendConnHeadCache = *sendConnHeadPtr; + sendConnHead = sendConn->step; + sendConnFifoPtr = sendConn->fifo; + *(sendConn->opCountLoc) = opCount; + } + } + + __device__ __forceinline__ void saveRecvSync() { + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConn->step = recvConnHead; + *(recvConn->opCountLoc) = opCount+1; + __threadfence_block(); + } + } + + __device__ __forceinline__ void saveSendSync() { + if (tid < nsend) { + sendConn->step = sendConnHead; + *(sendConn->opCountLoc) = opCount+1; + __threadfence_block(); + } + } + + public: + __device__ __forceinline__ + ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), opCount(opCount) { + // Make sure step is updated before we read it. 
+ barrier(); + + for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i); + for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i); + loadRecvSync(); + loadSendSync(); + } + + __device__ void send(const T* src, int nelem) { + return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem); + } + + __device__ void recv(T* dst, int nelem) { + return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem); + } + + __device__ void recvReduceSend(const T* src, int nelem) { + return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem); + } + + __device__ void recvReduceCopy(const T* src, T* dst, int nelem) { + return LLGenericOp<1, 0, 1, 1>(src, dst, nelem); + } + + __device__ void copySend(const T* src, T* dst, int nelem) { + return LLGenericOp<0, 1, 1, 1>(src, dst, nelem); + } + + __device__ void recvCopySend(T* dst, int nelem) { + return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem); + } + + __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) { + return LLGenericOp<1, 1, 1, 1>(src, dst, nelem); + } + + __device__ __forceinline__ ~ncclLLPrimitives() { + // Save steps for the next operation + saveRecvSync(); + saveSendSync(); + } +}; diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h new file mode 100644 index 0000000..40a8cff --- /dev/null +++ b/src/collectives/device/prims_ll128.h @@ -0,0 +1,410 @@ +/************************************************************************* + * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "op128.h" + +#define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1) + +template <typename T, class FUNC, int NRECV, int NSEND> +class ncclLL128Primitives { + private: + const int tid; + const int nthreads; + const int wid; + const int warp; + const bool flagThread; + int nrecv = 0; + int nsend = 0; + struct ncclConnInfo* recvConn = NULL; + volatile uint64_t* recvConnHeadPtr = NULL; + uint64_t recvConnHead; + + struct ncclConnInfo* sendConn = NULL; + volatile int* sendConnFifoPtr = NULL; + volatile uint64_t* sendConnTailPtr = NULL; + uint64_t sendConnTail; + volatile uint64_t* sendConnHeadPtr = NULL; + uint64_t sendConnHead; + uint64_t sendConnHeadCache; // Cache last seen value + + uint64_t recvStep[NRECV]; + uint64_t sendStep[NSEND]; + uint64_t* recvBuff[NRECV]; + uint64_t* sendBuff[NSEND]; + struct ncclDevComm* comm; + + volatile uint64_t* shmem; + + inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; } + inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; } + inline __device__ uint64_t* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } + inline __device__ uint64_t* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } + inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; } + inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; } + + inline __device__ void barrier() { + if (NSEND>NRECV) { + asm volatile ("bar.sync 2, %0;" :: "r"(nthreads)); + } else { + asm volatile ("bar.sync 3, %0;" :: "r"(nthreads)); + } + } + + uint32_t mismatch = 0; + const uint64_t opCount; + + inline __device__ void checkMismatch(struct ncclConnInfo* conn) { + if (mismatch > 20) { + // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most 
likely_ a mismatch + // Note that we are not using _threadfence_system in LL so the error cannot be asserted + *(comm->fatalDevError) = ncclDevSuspectedMismatch; + } else if (conn && *conn->opCountRem > opCount) { + mismatch += 1; + } + } + + uint32_t spins = 0; + uint32_t abort = 0; + + inline __device__ int checkAbort(int i, int send) { + spins++; + if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { + abort = *(comm->abortFlag); + if (wid == i) checkMismatch(send ? sendConn : recvConn); + spins = 0; + } + return abort; + } + + inline __device__ void waitSend(int nbytes) { + spins = 0; + mismatch = 0; + if (sendConnHeadPtr) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { + sendConnHeadCache = *sendConnHeadPtr; + if (checkAbort(wid, 1)) break; + } + if (sendConnFifoPtr) { + sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes; + } + sendConnHead += 1; + } + } + + inline __device__ void incRecv(int i) { + recvStep[i] += 1; + } + inline __device__ void postRecv() { + if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; + } + + inline __device__ void incSend(int i) { + sendStep[i] += 1; + } + inline __device__ void postSend() { + if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; } + } + + template <int ELEMS_PER_THREAD> + inline __device__ void loadSrcToShmem128(int maxOffset, const uint64_t* src64Ptr) { +#if 0 + uint64_t v[ELEMS_PER_THREAD]; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + if (u*WARP_SIZE < maxOffset) load128(src64Ptr+u*WARP_SIZE, v[u], v[u+1]); + } + uint64_t* shmemAsmPtr = shmemCvtPtr(shmem); + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + storeShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]); + } +#else + uint64_t* shmemAsmPtr = shmemCvtPtr(shmem); + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + if (u*WARP_SIZE < maxOffset) { + uint64_t v0, v1; + load128(src64Ptr+u*WARP_SIZE, v0, v1); + storeShmem128(shmemAsmPtr+u*WARP_SIZE, v0, v1); + } + } +#endif + } + + inline __device__ void loadSrcToShmem(int start, int end, const T* srcPtr) { + T* shmemPtr = (T*)(shmem-2*wid); + for (int offset = start+wid; offset < end; offset += WARP_SIZE) { + shmemPtr[offset] = srcPtr[offset]; + } + } + + template <int ELEMS_PER_THREAD> + inline __device__ void storeShmemToDst128(int maxOffset, uint64_t* dst64Ptr) { + uint64_t v[ELEMS_PER_THREAD]; + uint64_t* shmemAsmPtr = shmemCvtPtr(shmem); + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + loadShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]); + } + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + if (u*WARP_SIZE < maxOffset) store128(dst64Ptr+u*WARP_SIZE, v[u], v[u+1]); + } + } + + inline __device__ void storeShmemToDst(int start, int end, T* dstPtr) { + T* shmemPtr = (T*)(shmem-2*wid); + for (int offset = start+wid; offset < end; offset += WARP_SIZE) { + dstPtr[offset] = shmemPtr[offset]; + } + } + + #define WARP_MASK 0xffffffff + + template <int ELEMS_PER_THREAD, int RECV, int SEND, int SRC, int DST> + __device__ __forceinline__ void recvReduceSendCopy(int ll128Offset) { + uint64_t v[ELEMS_PER_THREAD]; + + /************* Data Loading : SHMEM -> REG **************/ + if (SRC) { + volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + v[u] = shmem64Ptr[u*(WARP_SIZE-2)]; + if (!flagThread) v[u+1] = shmem64Ptr[u*(WARP_SIZE-2)+1]; + } + } + /*********** End Data Loading : SHMEM -> REG ************/ + + /************************ Recv 
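(poll each receive buffer until its lines carry the expected flag, then combine the payload into v[])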
**************************/ + if (RECV) { + uint64_t flag = recvFlag(0); + uint64_t* ptr = recvPtr(0)+ll128Offset; + bool needReload; + uint64_t v0, v1; + do { + needReload = false; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + load128(ptr+u*WARP_SIZE, v0, v1); + needReload |= flagThread && (v1 != flag); + } + } while (__any_sync(WARP_MASK, needReload) && checkAbort(0, 0) == 0); + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + load128(ptr+u*WARP_SIZE, v0, v1); + v[u] = SRC ? MULTI<FUNC, T>()(v0, v[u]) : v0; + v[u+1] = SRC ? MULTI<FUNC, T>()(v1, v[u+1]) : v1; + } + + for (int i=1; i<NRECV && i<nrecv; i++) { + uint64_t flag = recvFlag(i); + uint64_t* ptr = recvPtr(i)+ll128Offset; + uint64_t v0, v1; + do { + needReload = false; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + load128(ptr+u*WARP_SIZE, v0, v1); + needReload |= flagThread && (v1 != flag); + } + } while (__any_sync(WARP_MASK, needReload) && checkAbort(i, 0) == 0); + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + load128(ptr+u*WARP_SIZE, v0, v1); + v[u] = MULTI<FUNC, T>()(v0, v[u]); + v[u+1] = MULTI<FUNC, T>()(v1, v[u+1]); + } + } + } + /********************** End Recv ************************/ + + /************************ Send **************************/ + if (SEND) { + for (int i=1; i<NSEND && i<nsend; i++) { + int flag = sendFlag(i); + uint64_t* ptr = sendPtr(i)+ll128Offset; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]); + } + } + int flag = sendFlag(0); + uint64_t* ptr = sendPtr(0)+ll128Offset; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]); + } + } + /********************** End Send ************************/ + + /************* Data Storing : REG -> SHMEM **************/ + if (DST) { + volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS; + #pragma unroll + for (int u=0; u<ELEMS_PER_THREAD; u+=2) { + shmem64Ptr[u*(WARP_SIZE-2)] = v[u]; + if (!flagThread) shmem64Ptr[u*(WARP_SIZE-2)+1] = v[u+1]; + } + } + /*********** End data Storing : REG -> SHMEM ************/ + } + + #define LL128INC (WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD) + #define ELEMINC (LL128INC-(LL128INC/NCCL_LL128_LINEELEMS)) + + template <int RECV, int SEND, int SRC, int DST> + __device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem) { + if (nelem <= 0) { + // Don't move any data but still increase steps and sync with prev/next + if (SEND) waitSend(0); + FOR_SEND(incSend); if (SEND) postSend(); + FOR_RECV(incRecv); if (RECV) postRecv(); + return; + } + const int nelem64 = ((nelem*sizeof(T))/(2*sizeof(uint64_t)))*2; + const uint64_t* src64Ptr = ((uint64_t*)srcPtr); + uint64_t* dst64Ptr = ((uint64_t*)dstPtr); + + int ll128Offset = LL128INC*warp+2*wid; + int elemOffset = ELEMINC*warp; + const int nwarps = nthreads/WARP_SIZE; + + if (SEND) waitSend(DIVUP(nelem*sizeof(T), ELEMINC*sizeof(uint64_t))*LL128INC*sizeof(uint64_t)); + barrier(); + + while (elemOffset*(sizeof(uint64_t)/sizeof(T)) < nelem) { + const int maxOffset128 = min(nelem64-elemOffset, (int)ELEMINC); + const int maxOffset = min(nelem-(elemOffset*((int)(sizeof(uint64_t)/sizeof(T)))), (int)(ELEMINC*(sizeof(uint64_t)/sizeof(T)))); + if (SRC) { + int done = 0; + if ((((uint64_t)srcPtr)&0xf) == 0) { + loadSrcToShmem128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, src64Ptr+elemOffset+2*wid); + done = maxOffset128*(sizeof(uint64_t)/sizeof(T)); + } + 
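// Copy whatever the 128-bit path did not cover (an unaligned source pointer, or the tail elements) one T at a time.
+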
loadSrcToShmem(done, maxOffset, (T*)(src64Ptr+elemOffset)); + } + __syncwarp(); + recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SRC, DST>(ll128Offset); + __syncwarp(); + if (DST) { + int done = 0; + if ((((uint64_t)dstPtr)&0xf) == 0) { + storeShmemToDst128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, dst64Ptr+elemOffset+2*wid); + done = maxOffset128*(sizeof(uint64_t)/sizeof(T)); + } + storeShmemToDst(done, maxOffset, (T*)(dst64Ptr+elemOffset)); + } + __syncwarp(); + ll128Offset += LL128INC*nwarps; + elemOffset += ELEMINC*nwarps; + } + + barrier(); + FOR_SEND(incSend); if (SEND) postSend(); + FOR_RECV(incRecv); if (RECV) postRecv(); + } + + __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) { + recvBuff[i] = conn->ll128Buff; + recvStep[i] = conn->step; + if (wid == i) recvConn = conn; + nrecv++; + } + __device__ __forceinline__ void loadRecvSync() { + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConnHeadPtr = recvConn->head; + recvConnHead = recvConn->step; + // Update opCount in case we skipped some operations + *(recvConn->opCountLoc) = opCount; + } + } + + __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { + sendBuff[i] = conn->ll128Buff; + sendStep[i] = conn->step; + if (wid == i) sendConn = conn; + nsend++; + } + __device__ __forceinline__ void loadSendSync() { + if (tid < nsend) { + sendConnHeadPtr = sendConn->head; + sendConnHeadCache = *sendConnHeadPtr; + sendConnHead = sendConn->step; + sendConnFifoPtr = sendConn->fifo; + *(sendConn->opCountLoc) = opCount; + } + if (tid >= nthreads-WARP_SIZE && wid<nsend) { + if (sendConn->fifo) { + sendConnTailPtr = sendConn->tail; + sendConnTail = sendConn->step; + } + } + } + + __device__ __forceinline__ void saveRecvSync() { + if (tid >= nthreads-WARP_SIZE && wid < nrecv) { + recvConn->step = recvConnHead; + *(recvConn->opCountLoc) = opCount+1; + __threadfence_block(); + } + } + + __device__ __forceinline__ void saveSendSync() { + if (tid < nsend) { + sendConn->step = sendConnHead; + *(sendConn->opCountLoc) = opCount+1; + __threadfence_block(); + } + } + + public: + __device__ __forceinline__ + ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) { + // Make sure step is updated before we read it. 
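+ // Same as in ncclLLPrimitives: sync so the previous operation's saveRecvSync/saveSendSync are visible before reading conn->step.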
+ barrier(); + + for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i); + for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i); + loadRecvSync(); + loadSendSync(); + } + + __device__ void send(const T* src, int nelem) { + return GenericOp<0, 1, 1, 0>(src, NULL, nelem); + } + + __device__ void recv(T* dst, int nelem) { + return GenericOp<1, 0, 0, 1>(NULL, dst, nelem); + } + + __device__ void recvReduceSend(const T* src, int nelem) { + return GenericOp<1, 1, 1, 0>(src, NULL, nelem); + } + + __device__ void recvReduceCopy(const T* src, T* dst, int nelem) { + return GenericOp<1, 0, 1, 1>(src, dst, nelem); + } + + __device__ void copySend(const T* src, T* dst, int nelem) { + return GenericOp<0, 1, 1, 1>(src, dst, nelem); + } + + __device__ void recvCopySend(T* dst, int nelem) { + return GenericOp<1, 1, 0, 1>(NULL, dst, nelem); + } + + __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) { + return GenericOp<1, 1, 1, 1>(src, dst, nelem); + } + + __device__ __forceinline__ ~ncclLL128Primitives() { + // Save steps for the next operation + saveRecvSync(); + saveSendSync(); + } +}; diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index d2d5d3b..0680abe 100644 --- a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -11,7 +11,7 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; @@ -30,7 +30,7 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) { T * __restrict__ thisOutput = (T*)args->ThisOutput; ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); + prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels)); @@ -93,3 +93,48 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { } + +#include "prims_ll128.h" +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nthreads = args->nThreads; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclRing* ring = &channel->ring; + + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); + + const ssize_t size = args->N; + const int rank = comm->rank; + const int nranks = comm->nRanks; + const int prevRank = ring->devUserRanks[nranks-1]; + const int root = args->root; + + ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + + const ssize_t 
loopSize = args->nChannels*chunkSize; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize); + ssize_t offset = gridOffset + bid*chunkSize; + + int nelem = min(chunkSize, size-offset); + if (prevRank == root) { + LLprims.send(thisInput+offset, nelem); + } else if (rank == root) { + LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); + } else { + LLprims.recvReduceSend(thisInput+offset, nelem); + } + } +} + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceTreeLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 09ba56e..1985148 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -11,7 +11,7 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = blockDim.x - 1; + const int nthreads = args->nThreads-WARP_SIZE; const int bid = args->bid; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; @@ -19,7 +19,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) { const ssize_t size = args->N; const int nranks = comm->nRanks; const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS); - const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS; + const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS; const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize; // Compute pointers @@ -27,7 +27,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) { T * __restrict__ thisOutput = (T*)args->ThisOutput; ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); + prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels)); @@ -121,3 +121,64 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { } + +#include "prims_ll128.h" +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nthreads = args->nThreads; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + struct ncclRing* ring = &channel->ring; + + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. 
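+ // minChunkSize is the smallest amount one channel processes per iteration; the loop below rounds each chunk up to a multiple of it (capped at chunkSize).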
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; + + const ssize_t loopSize = args->nChannels*chunkSize; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize); + + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin ReduceScatter steps /////////////// + ssize_t offset; + int nelem = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[nranks-1]; + offset = chunkOffset + rankDest * size; + + LLprims.send(thisInput+offset, nelem); + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + LLprims.recvReduceSend(thisInput+offset, nelem); + } + + // step k-1: reduce this buffer and data, which will produce the final + // result that we store in this data + rankDest = ring->devUserRanks[0]; + offset = chunkOffset + rankDest * size; + + LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem); + } +} + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceScatterTreeLL128Kernel(struct CollectiveArgs* args) { } |
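The LL128 ring kernels above all size their per-iteration chunk the same way: shrink the chunk as the remaining data gets small, while keeping it a multiple of minChunkSize and spreading it across all channels. A minimal host-side sketch of that arithmetic, using placeholder sizes rather than the real NCCL_LL128_* constants:

    #include <cstdio>
    #include <cstdint>

    // Ceiling division, mirroring NCCL's DIVUP macro.
    static int64_t divUp(int64_t x, int64_t y) { return (x + y - 1) / y; }

    // Mirrors the loop-body computation in the LL128 ring kernels:
    //   chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
    static int64_t ll128Chunk(int64_t remaining, int nChannels, int64_t minChunk, int64_t maxChunk) {
      int64_t c = divUp(remaining, (int64_t)nChannels * minChunk) * minChunk;
      return c < maxChunk ? c : maxChunk;
    }

    int main() {
      // Hypothetical values, chosen only to show the rounding near the end of a buffer.
      const int nChannels = 2;
      const int64_t minChunk = 480, maxChunk = 4800;
      const int64_t samples[] = {20000, 4000, 700};
      for (int64_t remaining : samples) {
        printf("remaining=%lld -> chunk=%lld\n",
               (long long)remaining, (long long)ll128Chunk(remaining, nChannels, minChunk, maxChunk));
      }
      return 0;
    }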