diff options
Diffstat (limited to 'src/collectives/device/all_gather.h')
-rw-r--r-- | src/collectives/device/all_gather.h | 74 |
1 files changed, 37 insertions, 37 deletions
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 059092c..724b1aa 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -11,26 +11,27 @@ template<int UNROLL, class FUNC, typename T> __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int nthreads = args->nThreads-WARP_SIZE; - const int bid = args->bid; + const int nthreads = args->coll.nThreads-WARP_SIZE; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclRing* ring = &channel->ring; - const ssize_t size = args->N; - const int nranks = comm->nRanks; - const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS); + const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS); const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS; - const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize; + const int nranks = comm->nRanks; + const ssize_t loopSize = nChannels*(ssize_t)chunkSize; + const ssize_t size = args->coll.count; // Compute pointers - const T * __restrict__ thisInput = (const T*)args->ThisInput; - T * __restrict__ thisOutput = (T*)args->ThisOutput; + const T * __restrict__ thisInput = (const T*)args->sendbuff; + T * __restrict__ thisOutput = (T*)args->recvbuff; - ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, FUNC> - prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC> + prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels)); + int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); ssize_t chunkOffset = gridOffset + bid*realChunkSize; @@ -75,27 +76,27 @@ __device__ void ncclAllGatherCollNetKernel(struct CollectiveArgs* args) { } template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nthreads = args->nThreads; + const int nthreads = args->coll.nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclRing* ring = &channel->ring; - - ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); - - const ssize_t size = args->N; - //const int rank = comm->rank; + const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS); + ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T); const int nranks = comm->nRanks; - ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); - const ssize_t loopSize = args->nChannels*chunkSize; + const ssize_t loopSize = nChannels*chunkSize; + const ssize_t size = args->coll.count; + + ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount); // Compute pointers - const T * __restrict__ thisInput = (const T*)args->ThisInput; - T * __restrict__ thisOutput = (T*)args->ThisOutput; + const T * __restrict__ thisInput = (const T*)args->sendbuff; + T * __restrict__ thisOutput = (T*)args->recvbuff; for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { if (size-gridOffset < loopSize) { - chunkSize = args->lastChunkSize; + chunkSize = args->coll.lastChunkSize; } ssize_t chunkOffset = gridOffset + bid*chunkSize; @@ -140,29 +141,28 @@ __device__ void ncclAllGatherCollNetLLKernel(struct CollectiveArgs* args) { } template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nthreads = args->nThreads; + const int nthreads = args->coll.nThreads; + const int bid = args->coll.bid; + const int nChannels = args->coll.nChannels; struct ncclDevComm* comm = args->comm; struct ncclChannel* channel = comm->channels+blockIdx.x; struct ncclRing* ring = &channel->ring; - - ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount); - - const ssize_t size = args->N; - //const int rank = comm->rank; - const int nranks = comm->nRanks; - ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T)); + const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS); + ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T)); // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2; + const int nranks = comm->nRanks; + const ssize_t loopSize = nChannels*chunkSize; + const ssize_t size = args->coll.count; - const ssize_t loopSize = args->nChannels*chunkSize; + ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount); // Compute pointers - const T * __restrict__ thisInput = (const T*)args->ThisInput; - T * __restrict__ thisOutput = (T*)args->ThisOutput; + const T * __restrict__ thisInput = (const T*)args->sendbuff; + T * __restrict__ thisOutput = (T*)args->recvbuff; for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize); + chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize); ssize_t chunkOffset = gridOffset + bid*chunkSize; |