github.com/marian-nmt/nccl.git
Diffstat (limited to 'src/collectives/device/all_gather.h')
-rw-r--r-- src/collectives/device/all_gather.h | 74
1 file changed, 37 insertions(+), 37 deletions(-)
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index 059092c..724b1aa 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,26 +11,27 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads-WARP_SIZE;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
- const ssize_t size = args->N;
- const int nranks = comm->nRanks;
- const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
- const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
ssize_t chunkOffset = gridOffset + bid*realChunkSize;
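
Note on the hunk above: the kernel arguments move from flat CollectiveArgs fields (args->bid, args->nChannels, args->N, args->ThisInput/ThisOutput) into the nested args->coll struct plus args->sendbuff/recvbuff, and the step size is now derived from the per-protocol table comm->buffSizes[NCCL_PROTO_SIMPLE] instead of the single channel->buffSize. Below is a minimal host-side sketch of the new sizing arithmetic; NCCL_STEPS, ALLGATHER_CHUNKSTEPS, the buffer size, and the channel count are illustrative assumptions, not values taken from this diff.

// Sketch: post-patch chunk sizing for the simple protocol (assumed constants).
#include <cstdio>
#include <sys/types.h>

#define NCCL_STEPS 8             // assumption: buffer slots per FIFO cycle
#define ALLGATHER_CHUNKSTEPS 4   // assumption: steps per all-gather chunk

int main() {
  using T = float;
  size_t buffSizeSimple = 4 << 20;   // hypothetical NCCL_PROTO_SIMPLE buffer (4 MiB)
  int nChannels = 2;                 // hypothetical channel count

  // Mirrors the new code: elements per step from the per-protocol buffer size.
  int stepSize  = buffSizeSimple / (sizeof(T) * NCCL_STEPS);
  int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
  ssize_t loopSize = nChannels * (ssize_t)chunkSize;

  printf("stepSize=%d chunkSize=%d loopSize=%zd (elements of T)\n",
         stepSize, chunkSize, loopSize);
  return 0;
}

Reading the buffer size from the communicator's per-protocol table rather than from the channel suggests each protocol can now size its FIFO independently. Note also that ncclPrimitives gains an extra template argument and is constructed with nthreads (i.e. nThreads-WARP_SIZE) instead of args->nThreads.
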
@@ -75,27 +76,27 @@ __device__ void ncclAllGatherCollNetKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int bid = args->bid;
- const int nthreads = args->nThreads;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
- const ssize_t size = args->N;
- //const int rank = comm->rank;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
const int nranks = comm->nRanks;
- ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
- chunkSize = args->lastChunkSize;
+ chunkSize = args->coll.lastChunkSize;
}
ssize_t chunkOffset = gridOffset + bid*chunkSize;
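
Note on the LL hunk: the per-step line count now comes from comm->buffSizes[NCCL_PROTO_LL] instead of the compile-time NCCL_LL_SLICE_LINES constant, and the resulting stepLines is threaded through to the ncclLLPrimitives constructor. A sketch of the arithmetic, assuming the usual 16-byte ncclLLFifoLine carrying 8 bytes of payload; the buffer size and element type are illustrative:

// Sketch: post-patch LL chunk sizing (assumed buffer size and line layout).
#include <cstdio>
#include <cstdint>
#include <sys/types.h>

#define NCCL_STEPS 8  // assumption: buffer slots per FIFO cycle

union ncclLLFifoLine {   // 16 bytes: two 4B data words interleaved with flags
  struct { uint32_t data1, flag1, data2, flag2; } s;
  uint64_t v[2];
};

int main() {
  using T = float;
  size_t buffSizeLL = 1 << 20;  // hypothetical NCCL_PROTO_LL buffer (1 MiB)

  // Mirrors the new code: lines per step from the per-protocol buffer size.
  int stepLines = buffSizeLL / (sizeof(union ncclLLFifoLine) * NCCL_STEPS);
  // Each line moves sizeof(uint64_t) bytes of payload, hence the conversion
  // from lines to elements of T.
  ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);

  printf("stepLines=%d chunkSize=%zd (elements of T)\n", stepLines, chunkSize);
  return 0;
}
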
@@ -140,29 +141,28 @@ __device__ void ncclAllGatherCollNetLLKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int bid = args->bid;
- const int nthreads = args->nThreads;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
- const ssize_t size = args->N;
- //const int rank = comm->rank;
- const int nranks = comm->nRanks;
- ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
- const ssize_t loopSize = args->nChannels*chunkSize;
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize);
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
ssize_t chunkOffset = gridOffset + bid*chunkSize;
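
Note on the LL128 hunk: as in the other kernels, the step size is now read from comm->buffSizes[NCCL_PROTO_LL128], chunkSize is derived from it rather than from NCCL_LL128_ELEMS_PER_THREAD*nthreads, and stepSize is passed to the ncclLL128Primitives constructor. A sketch of the new chunk arithmetic, assuming the usual LL128 line layout of 16 uint64 slots per 128-byte line with one slot reserved for flags (15 data slots); the buffer size and channel count are illustrative:

// Sketch: post-patch LL128 chunk sizing (assumed constants).
#include <cstdio>
#include <cstdint>
#include <sys/types.h>

#define NCCL_STEPS 8              // assumption: buffer slots per FIFO cycle
#define NCCL_LL128_LINEELEMS 16   // assumption: uint64 slots per 128B line
#define NCCL_LL128_DATAELEMS 15   // assumption: data slots per line

int main() {
  using T = float;
  size_t buffSizeLL128 = 8 << 20;  // hypothetical NCCL_PROTO_LL128 buffer (8 MiB)
  int nChannels = 2;               // hypothetical channel count

  // uint64 slots per step, from the per-protocol buffer size.
  int stepSize = buffSizeLL128 / (sizeof(uint64_t) * NCCL_STEPS);
  // Scale by the 15/16 data-to-line ratio, then convert slots to elements of T.
  ssize_t chunkSize = stepSize * NCCL_LL128_DATAELEMS * sizeof(uint64_t)
                    / (NCCL_LL128_LINEELEMS * sizeof(T));
  ssize_t loopSize = nChannels * chunkSize;

  printf("stepSize=%d chunkSize=%zd loopSize=%zd\n", stepSize, chunkSize, loopSize);
  return 0;
}

Inside the loop, chunkSize is still clamped to a multiple of minChunkSize (DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize), so the final iterations shrink gracefully; only the source of the upper bound changed.
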