github.com/marian-nmt/nccl.git
Diffstat (limited to 'src/collectives/device/all_reduce.h')
-rw-r--r--  src/collectives/device/all_reduce.h | 223
1 file changed, 115 insertions, 108 deletions
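
The hunks below repeatedly replace the per-channel buffer size (channel->buffSize) with the per-protocol table comm->buffSizes[] when deriving step and chunk sizes. As a rough sketch of that arithmetic only, the following host-side C program mirrors the formulas used in the patched kernels; the buffer sizes, NCCL_STEPS, ALLREDUCE_CHUNKSTEPS, the element type and the 16-byte LL FIFO line size are all assumptions chosen for illustration and are not taken from this patch.

/* Illustrative sketch only: mirrors the step/chunk-size arithmetic of the
 * patched kernels. All constants below are assumed example values. */
#include <stdio.h>
#include <stdint.h>

#define NCCL_STEPS 8                 /* assumed pipeline depth */
#define ALLREDUCE_CHUNKSTEPS 4       /* assumed steps per chunk */
typedef float T;                     /* assumed element type */

int main(void) {
  size_t buffSimple = 4u << 20;      /* hypothetical NCCL_PROTO_SIMPLE buffer size */
  size_t buffLL     = 1u << 20;      /* hypothetical NCCL_PROTO_LL buffer size */
  size_t llLineSize = 16;            /* assumed size of one LL FIFO line: 2x(data,flag) words */

  /* Simple protocol: step size counted in elements of T. */
  int stepSize  = (int)(buffSimple / (sizeof(T) * NCCL_STEPS));
  int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;

  /* LL protocol: steps counted in FIFO lines; each line carries 8 data bytes. */
  int  stepLines = (int)(buffLL / (llLineSize * NCCL_STEPS));
  long llChunk   = (long)stepLines * (long)sizeof(uint64_t) / (long)sizeof(T);

  printf("simple: stepSize=%d chunkSize=%d (elements)\n", stepSize, chunkSize);
  printf("LL:     stepLines=%d chunkSize=%ld (elements)\n", stepLines, llChunk);
  return 0;
}
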
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index 4e04f88..6891ac0 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,26 +11,27 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads-WARP_SIZE;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
- const ssize_t size = args->N;
- const int nranks = comm->nRanks;
- const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
- const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*args->nChannels));
+ ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
@@ -85,28 +86,29 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads-WARP_SIZE;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- const ssize_t size = args->N;
- const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
- int chunkSize = args->lastChunkSize;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ int chunkSize = args->coll.lastChunkSize;
const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
if (loopSize > size) {
- chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
}
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -124,17 +126,17 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
if (tree->up == -1) {
- prims.send(thisOutput+offset, nelem);
+ prims.directSend(thisOutput+offset, offset, nelem);
} else if (tree->down[0] == -1) {
- prims.recv(thisOutput+offset, nelem);
+ prims.directRecv(thisOutput+offset, offset, nelem);
} else {
- prims.recvCopySend(thisOutput+offset, nelem);
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
}
}
} while(0);
@@ -143,27 +145,28 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads-WARP_SIZE;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- const ssize_t size = args->N;
- const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
- int chunkSize = args->lastChunkSize;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ int chunkSize = args->coll.lastChunkSize;
const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
if (loopSize > size) {
- chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
}
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+ if (blockIdx.x < nChannels) { // first half of the channels do reduce
struct ncclTree* tree = &channel->collTreeUp;
- ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -178,9 +181,9 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
}
}
- if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+ if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
struct ncclTree* tree = &channel->collTreeDn;
- ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -199,28 +202,27 @@ __device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int bid = args->bid;
- const int nthreads = args->nThreads;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
- const ssize_t size = args->N;
- //const int rank = comm->rank;
- const int nranks = comm->nRanks;
- ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*nranks*chunkSize;
+ const ssize_t size = args->coll.count;
- const ssize_t loopSize = args->nChannels*nranks*chunkSize;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm, args->opCount);
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
/////////////// begin AllReduce steps ///////////////
ssize_t offset;
@@ -229,7 +231,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// step 0: push data to next GPU
chunk = ring->devUserRanks[nranks-1];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.send(thisInput+offset, nelem);
@@ -237,7 +239,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
chunk = ring->devUserRanks[nranks-j];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceSend(thisInput+offset, nelem);
@@ -246,7 +248,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// step k-1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
chunk = ring->devUserRanks[0];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
@@ -254,7 +256,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
chunk = ring->devUserRanks[nranks-j];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvCopySend(thisOutput+offset, nelem);
@@ -262,7 +264,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// Make final copy from buffer to dest.
chunk = ring->devUserRanks[1];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
// Here we need to copy from buffer to this output.
@@ -273,27 +275,29 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- const ssize_t size = args->N;
- ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
if (loopSize > size) {
- chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
}
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+ ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -311,7 +315,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+ ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -330,26 +334,28 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- const ssize_t size = args->N;
- ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
if (loopSize > size) {
- chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
}
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+ if (blockIdx.x < nChannels) { // first half of the channels do reduce
struct ncclTree* tree = &channel->collTreeUp;
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -364,9 +370,9 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
}
}
- if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+ if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
struct ncclTree* tree = &channel->collTreeDn;
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -386,29 +392,28 @@ __device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int bid = args->bid;
- const int nthreads = args->nThreads;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclRing* ring = &channel->ring;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
-
- const ssize_t size = args->N;
- //const int rank = comm->rank;
- const int nranks = comm->nRanks;
- ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*nranks*chunkSize;
+ const ssize_t size = args->coll.count;
- const ssize_t loopSize = args->nChannels*nranks*chunkSize;
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm, args->opCount);
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
/////////////// begin AllReduce steps ///////////////
ssize_t offset;
@@ -417,7 +422,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
// step 0: push data to next GPU
chunk = ring->devUserRanks[nranks-1];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.send(thisInput+offset, nelem);
@@ -425,7 +430,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
chunk = ring->devUserRanks[nranks-j];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceSend(thisInput+offset, nelem);
@@ -434,7 +439,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
// step k-1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
chunk = ring->devUserRanks[0];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
@@ -442,7 +447,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
chunk = ring->devUserRanks[nranks-j];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvCopySend(thisOutput+offset, nelem);
@@ -450,7 +455,7 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
// Make final copy from buffer to dest.
chunk = ring->devUserRanks[1];
- offset = gridOffset + (chunk*args->nChannels+bid) * chunkSize;
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
// Here we need to copy from buffer to this output.
@@ -461,29 +466,31 @@ __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = args->nThreads;
- const int bid = args->bid;
+ const int nthreads = args->coll.nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
struct ncclTree* treeUp = &channel->treeUp;
struct ncclTree* treeDn = &channel->treeDn;
- const ssize_t size = args->N;
- ssize_t chunkSize = args->lastChunkSize;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = args->coll.lastChunkSize;
const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
- const ssize_t loopSize = args->nChannels*chunkSize;
+ const ssize_t loopSize = nChannels*chunkSize;
int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
+ const ssize_t size = args->coll.count;
if (loopSize > size) {
- chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
}
// Compute pointers
- const T * __restrict__ thisInput = (const T*)args->ThisInput;
- T * __restrict__ thisOutput = (T*)args->ThisOutput;
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
if (treeUp->up == -1) {
// ReduceAndBroadcast : max number of recv is 3, max number of send is 3
- ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, channel, comm, args->opCount);
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
@@ -492,7 +499,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
} else {
if (tid < nthreadsSplit) {
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, channel, comm, args->opCount);
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -505,7 +512,7 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
}
} else {
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, channel, comm, args->opCount);
+ ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
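
The ring kernels in this file all derive their per-step offsets the same way: offset = gridOffset + (chunk*nChannels + bid) * chunkSize, with chunk taken from ring->devUserRanks, following the send / recvReduceSend / recvReduceCopySend / recvCopySend / recv sequence shown in the hunks above. A minimal standalone sketch of that schedule follows; the ring order, channel count and chunk size are invented for illustration and do not come from this patch.

/* Sketch of the ring all-reduce step schedule used by the kernels above.
 * devUserRanks, nChannels, bid and chunkSize are made-up example values. */
#include <stdio.h>

int main(void) {
  const int  nranks = 4;
  const int  devUserRanks[4] = {2, 3, 0, 1};  /* hypothetical ring order, this rank first */
  const int  nChannels = 2, bid = 0;
  const long chunkSize = 1024, gridOffset = 0;

  /* step 0: push own chunk to the next GPU */
  long offset = gridOffset + ((long)devUserRanks[nranks-1]*nChannels + bid) * chunkSize;
  printf("send               offset=%ld\n", offset);

  /* k-2 steps: receive, reduce, and forward */
  for (int j = 2; j < nranks; ++j) {
    offset = gridOffset + ((long)devUserRanks[nranks-j]*nChannels + bid) * chunkSize;
    printf("recvReduceSend     offset=%ld\n", offset);
  }

  /* step k-1: reduce into the output and start the broadcast phase */
  offset = gridOffset + ((long)devUserRanks[0]*nChannels + bid) * chunkSize;
  printf("recvReduceCopySend offset=%ld\n", offset);

  /* k-2 steps: forward the finished chunk */
  for (int j = 1; j < nranks-1; ++j) {
    offset = gridOffset + ((long)devUserRanks[nranks-j]*nChannels + bid) * chunkSize;
    printf("recvCopySend       offset=%ld\n", offset);
  }

  /* final step: copy the last chunk into the output */
  offset = gridOffset + ((long)devUserRanks[1]*nChannels + bid) * chunkSize;
  printf("recv               offset=%ld\n", offset);
  return 0;
}
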