Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Jeaugey <sjeaugey@nvidia.com>2018-12-14 02:56:12 +0300
committerSylvain Jeaugey <sjeaugey@nvidia.com>2019-01-30 02:19:27 +0300
commit1450d42675be325cd3b7a684d4b231eedceb22fb (patch)
treedc1f88ad03d598c3bb03f20dd81d8ef671fc2bff /src/collectives
parent4861e197fd83f0ac324ac0c21051820f8866e6ea (diff)
2.4.2-1
Add tree algorithms for allreduce to improve performance at scale. Add ncclCommAbort() and ncclCommGetAsyncError() to properly handle network errors and be permit recover. Detect initial CPU affinity and no longer escape it.
Diffstat (limited to 'src/collectives')
-rw-r--r--src/collectives/all_gather.cu22
-rw-r--r--src/collectives/all_reduce.cu26
-rw-r--r--src/collectives/broadcast.cu34
-rw-r--r--src/collectives/collectives.h37
-rw-r--r--src/collectives/device/Makefile39
-rw-r--r--src/collectives/device/all_gather.cu8
-rw-r--r--src/collectives/device/all_gather.h218
-rw-r--r--src/collectives/device/all_reduce.cu14
-rw-r--r--src/collectives/device/all_reduce.h381
-rw-r--r--src/collectives/device/broadcast.cu8
-rw-r--r--src/collectives/device/broadcast.h200
-rw-r--r--src/collectives/device/common.h112
-rw-r--r--src/collectives/device/common_kernel.h186
-rw-r--r--src/collectives/device/functions.cu10
-rwxr-xr-xsrc/collectives/device/gen_rules.sh28
-rw-r--r--src/collectives/device/ll_kernel.h154
-rw-r--r--src/collectives/device/primitives.h709
-rw-r--r--src/collectives/device/reduce.cu14
-rw-r--r--src/collectives/device/reduce.h165
-rw-r--r--src/collectives/device/reduce_kernel.h94
-rw-r--r--src/collectives/device/reduce_scatter.cu14
-rw-r--r--src/collectives/device/reduce_scatter.h158
-rw-r--r--src/collectives/reduce.cu23
-rw-r--r--src/collectives/reduce_scatter.cu22
24 files changed, 1142 insertions, 1534 deletions
diff --git a/src/collectives/all_gather.cu b/src/collectives/all_gather.cu
index 8dec28e..db21dee 100644
--- a/src/collectives/all_gather.cu
+++ b/src/collectives/all_gather.cu
@@ -4,29 +4,15 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "common_coll.h"
#include "enqueue.h"
#include "collectives.h"
-ncclResult_t ncclAllGatherFunc(const void* sendbuff, void* recvbuff, size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- size_t nbytes = count*ncclTypeSize(datatype);
- INFO(NCCL_COLL,"AllGather: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
- if (comm->nRanks == 1) {
- if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
- } else {
- NCCLCHECK(transportSaveProxies(ALLGATHER_SUBSTEPS, ALLGATHER_BUFCHUNKS, comm->nRanks-1, comm->nRanks, nbytes*comm->nRanks, proxyPatternRing, comm));
- NCCLCHECK(saveKernel(ncclCollAllGather, sendbuff, recvbuff, nbytes, ncclInt8, op, root, comm, stream, nbytes*comm->nRanks, 1));
- }
- return ncclSuccess;
-}
-
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclAllGatherFunc, "AllGather", sendbuff, recvbuff, sendcount, datatype,
- ncclSum, 0, comm, stream);
+ struct ncclInfo info = { ncclCollAllGather, "AllGather",
+ sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
+ ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
+ return ncclEnqueueCheck(&info);
}
diff --git a/src/collectives/all_reduce.cu b/src/collectives/all_reduce.cu
index cc14083..1492c90 100644
--- a/src/collectives/all_reduce.cu
+++ b/src/collectives/all_reduce.cu
@@ -4,29 +4,15 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "common_coll.h"
#include "enqueue.h"
#include "collectives.h"
-ncclResult_t ncclAllReduceFunc(const void* sendbuff, void* recvbuff, size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- size_t nbytes = count*ncclTypeSize(datatype);
- INFO(NCCL_COLL,"AllReduce: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
- if (comm->nRanks == 1) {
- if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
- } else {
- NCCLCHECK(transportSaveProxies(ALLREDUCE_SUBSTEPS, ALLREDUCE_BUFCHUNKS, (comm->nRanks)*2-2, comm->nRanks, nbytes, proxyPatternRing, comm));
- NCCLCHECK(saveKernel(ncclCollAllReduce, sendbuff, recvbuff, count, datatype, op, root, comm, stream, nbytes, comm->nRanks));
- }
- return ncclSuccess;
-}
-
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclAllReduceFunc, "AllReduce", sendbuff, recvbuff, count, datatype,
- op, 0, comm, stream);
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
+ struct ncclInfo info = { ncclCollAllReduce, "AllReduce",
+ sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
+ ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
+ return ncclEnqueueCheck(&info);
}
diff --git a/src/collectives/broadcast.cu b/src/collectives/broadcast.cu
index 91ce905..6a3d0a8 100644
--- a/src/collectives/broadcast.cu
+++ b/src/collectives/broadcast.cu
@@ -4,39 +4,23 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "common_coll.h"
#include "enqueue.h"
#include "collectives.h"
-ncclResult_t ncclBroadcastFunc(const void* sendbuff, void* recvbuff, const size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- size_t nbytes = count*ncclTypeSize(datatype);
- INFO(NCCL_COLL,"Broadcast: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
- if (comm->nRanks == 1) {
- if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
- } else {
- NCCLCHECK(transportSaveProxies(BROADCAST_SUBSTEPS, BROADCAST_BUFCHUNKS, 1, 1, nbytes, proxyPatternFrom(root), comm));
- NCCLCHECK(saveKernel(ncclCollBroadcast, sendbuff, recvbuff, nbytes, ncclInt8, op, root, comm, stream, nbytes, 1));
- }
-
- return ncclSuccess;
+NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
+ ncclComm_t comm, cudaStream_t stream);
+ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
+ ncclComm_t comm, cudaStream_t stream) {
+ struct ncclInfo info = { ncclCollBroadcast, "Broadcast",
+ sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
+ BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
+ return ncclEnqueueCheck(&info);
}
-
/* Deprecated original "in place" function, similar to MPI */
NCCL_API(ncclResult_t, ncclBcast, void* buff, size_t count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclBroadcastFunc, "Bcast", buff, buff, count, datatype,
- ncclSum, root, comm, stream);
+ return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
}
-NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
- ncclComm_t comm, cudaStream_t stream);
-ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
- ncclComm_t comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclBroadcastFunc, "Broadcast", sendbuff, recvbuff, count, datatype,
- ncclSum, root, comm, stream);
-}
diff --git a/src/collectives/collectives.h b/src/collectives/collectives.h
index 4a5cb7a..e6b19cb 100644
--- a/src/collectives/collectives.h
+++ b/src/collectives/collectives.h
@@ -7,9 +7,7 @@
#ifndef NCCL_COLLECTIVES_H_
#define NCCL_COLLECTIVES_H_
-typedef enum { ncclCollBroadcast, ncclCollReduce, ncclCollAllGather, ncclCollReduceScatter, ncclCollAllReduce, ncclCollCount } ncclColl_t;
-
-#define FUNC_INDEX(coll, redop, dtype, ll) ((((coll*ncclNumOps + redop)*ncclNumTypes) + dtype)*2+ll)
+#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*2+(al))*2+(ll))
#define NCCL_COLL_NAME(coll, op, dtype) \
coll##_##op##_##dtype
@@ -18,13 +16,17 @@ typedef enum { ncclCollBroadcast, ncclCollReduce, ncclCollAllGather, ncclCollRed
coll##Kernel_##op##_##dtype
/* Declare all collective operations */
-#define DECL_COLL4(coll, op, dtype) \
+#define DECL_COLL5(coll, op, dtype) \
extern __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args); \
- extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl coll); \
+ extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl c); \
+
+#define DECL_COLL4(coll, op, dtype) \
+ DECL_COLL5(coll, op, dtype) \
+ DECL_COLL5(coll##LL, op, dtype)
#define DECL_COLL3(coll, op, dtype) \
- DECL_COLL4(coll##LL, op, dtype) \
- DECL_COLL4(coll, op, dtype)
+ DECL_COLL4(coll##Ring, op, dtype) \
+ DECL_COLL4(coll##Tree, op, dtype)
#define DECL_COLL2(coll, op) \
DECL_COLL3(coll, op, i8) \
@@ -52,15 +54,16 @@ typedef enum { ncclCollBroadcast, ncclCollReduce, ncclCollAllGather, ncclCollRed
DECL_ALL_COLLS
-#define ALLREDUCE_SUBSTEPS 2
-#define ALLREDUCE_BUFCHUNKS 2
-#define ALLGATHER_SUBSTEPS 2
-#define ALLGATHER_BUFCHUNKS 2
-#define REDUCESCATTER_SUBSTEPS 2
-#define REDUCESCATTER_BUFCHUNKS 2
-#define BROADCAST_SUBSTEPS 8
-#define BROADCAST_BUFCHUNKS 2
-#define REDUCE_SUBSTEPS 8
-#define REDUCE_BUFCHUNKS 2
+// CHUNKSIZE must be a multiple of SLICESIZE
+#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
+#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
+#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
+#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
+#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
+#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
+#define BROADCAST_SLICESTEPS 1
+#define BROADCAST_CHUNKSTEPS 1
+#define REDUCE_SLICESTEPS 1
+#define REDUCE_CHUNKSTEPS 1
#endif
diff --git a/src/collectives/device/Makefile b/src/collectives/device/Makefile
index e2bcd49..8e92596 100644
--- a/src/collectives/device/Makefile
+++ b/src/collectives/device/Makefile
@@ -12,18 +12,13 @@ OBJDIR := $(BUILDDIR)/obj/collectives/device
LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu
-LIBOBJ := $(patsubst %.cu,$(OBJDIR)/%_sum.o, $(LIBSRCFILES)) \
- $(patsubst %.cu,$(OBJDIR)/%_prod.o, $(LIBSRCFILES)) \
- $(patsubst %.cu,$(OBJDIR)/%_min.o, $(LIBSRCFILES)) \
- $(patsubst %.cu,$(OBJDIR)/%_max.o, $(LIBSRCFILES)) \
- $(OBJDIR)/functions.o
-
LIBSRCFILES += functions.cu
DEPFILES := $(patsubst %.cu, $(OBJDIR)/%.d, $(LIBSRCFILES))
-DEPENDFILES := $(DEPFILES:%.d=%.dep)
+DEPENDFILES:= $(DEPFILES:%.d=%.dep)
STATICLIB := $(OBJDIR)/colldevice.a
DEVOBJ := $(OBJDIR)/devlink.o
+RULESFILE := $(OBJDIR)/Makefile.rules
NVCUFLAGS += -I. -I.. -I$(BUILDDIR)/include -I../../include --compiler-options "-fPIC -fvisibility=hidden"
@@ -33,6 +28,16 @@ all: $(STATICLIB)
# Dummy rule so that the extra dependency (%.dep) files are preserved by make
all_deps: $(DEPENDFILES)
+# Auto-generating the rules per op/reduction/datatype/algorithm
+$(RULESFILE) :
+ @printf "Generating %-35s > %s\n" rules $@
+ @mkdir -p $(OBJDIR)
+ @./gen_rules.sh $(OBJDIR) > $@
+
+-include $(RULESFILE)
+
+LIBOBJ := $(GENOBJS) $(OBJDIR)/functions.o
+
-include $(DEPFILES)
$(STATICLIB): $(LIBOBJ) $(DEVOBJ)
@@ -58,26 +63,6 @@ $(OBJDIR)/functions.o : functions.cu $(OBJDIR)/functions.dep
mkdir -p `dirname $@`
$(NVCC) $(NVCUFLAGS) -dc $< -o $@
-$(OBJDIR)/%_sum.o : %.cu $(OBJDIR)/%.dep
- @printf "Compiling %-35s > %s\n" $< $@
- mkdir -p `dirname $@`
- $(NVCC) -DNCCL_OP=0 $(NVCUFLAGS) -dc $< -o $@
-
-$(OBJDIR)/%_prod.o : %.cu $(OBJDIR)/%.dep
- @printf "Compiling %-35s > %s\n" $< $@
- mkdir -p `dirname $@`
- $(NVCC) -DNCCL_OP=1 $(NVCUFLAGS) -dc $< -o $@
-
-$(OBJDIR)/%_min.o : %.cu $(OBJDIR)/%.dep
- @printf "Compiling %-35s > %s\n" $< $@
- mkdir -p `dirname $@`
- $(NVCC) -DNCCL_OP=2 $(NVCUFLAGS) -dc $< -o $@
-
-$(OBJDIR)/%_max.o : %.cu $(OBJDIR)/%.dep
- @printf "Compiling %-35s > %s\n" $< $@
- mkdir -p `dirname $@`
- $(NVCC) -DNCCL_OP=3 $(NVCUFLAGS) -dc $< -o $@
-
# ... and create the device-side linked object with all those.
$(DEVOBJ) : $(LIBOBJ)
$(NVCC) $(NVCUFLAGS) -dlink $^ -o $@
diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu
index 0f572ce..530bf14 100644
--- a/src/collectives/device/all_gather.cu
+++ b/src/collectives/device/all_gather.cu
@@ -4,12 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "common.h"
#include "all_gather.h"
+#include "common.h"
#include "collectives.h"
-#define UNROLL 4
-
-#if NCCL_OP == 0
-IMPL_COLL3(ncclAllGather, copy, FuncSum, i8, int8_t, ncclCollAllGather, ncclSum, ncclInt8);
-#endif
+IMPL_COLL_C(ncclAllGather, ncclCollAllGather);
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index a30e575..36809c9 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -8,72 +8,35 @@
#include "primitives.h"
#include "collectives.h"
-// Increase Step and poffset/noffset for buffer sync
-#define NEXT_STEP \
- step++; \
- poffset = noffset; \
- noffset += sliceSize; \
- if (noffset == buffSize) noffset = 0;
-
template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
+__device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = blockDim.x - 1;
const int bid = args->bid;
- __shared__ T* sharedNextOutput;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- int prevdirect = ring->recv.conn.direct;
- int nextdirect = ring->send.conn.direct;
-
- WaitFlag waitDoneFromNext(ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLGATHER_SUBSTEPS);
- PostFlag postDoneToPrev(ring->recv.conn.head, ALLGATHER_SUBSTEPS, NULL, 0);
- PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS);
-
- typedef Primitives<UNROLL, ALLGATHER_SUBSTEPS, T> Prims;
-
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
const ssize_t size = args->N;
const int nranks = comm->nRanks;
- const int buffSize = ring->buffSize / sizeof(T);
- const int sliceSize = buffSize / ALLGATHER_BUFCHUNKS;
- const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
-
- if (tid == 0) {
- // Update in case we skipped some collectives
- *ring->recv.conn.opCount = args->opCount;
- // Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
- if (prevdirect) {
- *ring->recv.conn.ptrExchange = args->ThisOutput;
- }
- if (nextdirect) {
- void* volatile* ptr = &(ring->devMemSend->ptrExchange);
- while (*ptr == nullptr);
- sharedNextOutput = (T*)*ptr;
- *ptr = nullptr;
- }
- }
- __syncthreads();
-
- uint64_t step = 0ULL;
- int poffset, noffset = 0;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- T * __restrict__ prevInput = (T*)ring->recv.conn.buff;
- T * __restrict__ nextOutput = (T*)ring->send.conn.buff;
+
+ ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings));
- ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*realChunkSize;
/////////////// begin AllGather steps ///////////////
ssize_t offset;
- int maxOffset = min(chunkSize, size-chunkOffset);
+ int nelem = min(realChunkSize, size-chunkOffset);
int rankDest;
// step 0: push data to next GPU
@@ -81,129 +44,51 @@ __device__ void ncclAllGatherKernel(struct CollectiveArgs* args) {
offset = chunkOffset + rankDest * size;
if (thisInput + chunkOffset == thisOutput + offset) { // In place
- Prims::Copy(tid, nthreads,
- thisInput + chunkOffset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
+ prims.directSend(thisInput+chunkOffset, offset, nelem);
} else {
- Prims::DoubleCopy(tid, nthreads,
- thisInput + chunkOffset,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
+ prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem);
}
- NEXT_STEP; // Increases step, poffset, noffset
-
// k-2 steps: copy to next GPU
- if (prevdirect) {
- for (int j=1; j<nranks-1; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- Prims::Copy(tid, nthreads,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
- }
- Prims::Copy(tid, nthreads,
- NULL,
- NULL,
- 0, 0,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
- } else {
- for (int j=1; j<nranks-1; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- Prims::DoubleCopy(tid, nthreads,
- prevInput + poffset,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
- }
-
- // Make final copy from buffer to dest.
- rankDest = ring->devUserRanks[1];
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
offset = chunkOffset + rankDest * size;
- // Here we need to copy from buffer to this output.
- Prims::Copy(tid, nthreads,
- prevInput + poffset,
- thisOutput + offset,
- sliceSize, maxOffset,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
}
- }
- if (tid == 0) {
- waitDoneFromNext.wait(ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS));
- *ring->send.conn.head = 0ULL;
- *ring->recv.conn.tail = 0ULL;
- __threadfence_system();
- *ring->recv.conn.opCount = args->opCount+1;
+ // Make final copy from buffer to dest.
+ rankDest = ring->devUserRanks[1];
+ offset = chunkOffset + rankDest * size;
+
+ // Final wait/copy.
+ prims.directRecv(thisOutput+offset, offset, nelem);
}
}
-#include "ll_kernel.h"
-
-#define NEXT_STEP_LL \
- poffset = noffset; \
- pflag = nflag; \
- noffset += NCCL_LL_SLICE_LINES; \
- if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \
- nflag++; \
- step++;
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclAllGatherTreeKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) {
+__device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int bid = args->bid;
- const int llNthreads = args->nThreads;
+ const int nthreads = args->nThreads;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead;
- volatile uint64_t * sendHeadPtr = ring->send.conn.llHead;
- volatile int * sizesFifo = ring->send.conn.llFifo;
- uint64_t sendHead = sendHeadPtr[0];
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
- typedef LLPrimitives<T, FUNC> LL;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
const ssize_t size = args->N;
//const int rank = comm->rank;
const int nranks = comm->nRanks;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nRings*chunkSize;
-
- uint64_t step = ring->send.conn.llStep;
- uint32_t pflag, nflag = step + 1;
- int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step);
+ const ssize_t loopSize = args->nChannels*chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff;
- union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
@@ -213,57 +98,34 @@ __device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) {
/////////////// begin AllGather steps ///////////////
ssize_t offset;
- int maxOffset = min(chunkSize, size-chunkOffset);
+ int nelem = min(chunkSize, size-chunkOffset);
int rankDest;
// step 0: push data to next GPU
rankDest = ring->devUserRanks[0];
offset = chunkOffset + rankDest * size;
- WAIT_NEXT;
if (thisInput + chunkOffset == thisOutput + offset) { // In place
- LL::ReduceCopy(
- thisInput + chunkOffset,
- nextOutput + noffset,
- maxOffset, nflag, llNthreads);
+ LLprims.send(thisInput+chunkOffset, nelem);
} else {
- LL::ReduceCopy(
- thisInput + chunkOffset,
- thisOutput + offset,
- nextOutput + noffset,
- maxOffset, nflag, llNthreads);
+ LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
}
- POST_SIZE;
-
- NEXT_STEP_LL;
// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
rankDest = ring->devUserRanks[nranks-j];
offset = chunkOffset + rankDest * size;
- WAIT_NEXT;
- LL::ReduceCopy(
- prevInput + poffset,
- thisOutput + offset,
- nextOutput + noffset,
- maxOffset, pflag, nflag, llNthreads);
- POST_SIZE;
- ACK_PREV;
-
- NEXT_STEP_LL;
+ LLprims.recvCopySend(thisOutput+offset, nelem);
}
// step k-1: final store
rankDest = ring->devUserRanks[1];
offset = chunkOffset + rankDest * size;
- LL::ReduceCopy(
- prevInput + poffset,
- thisOutput + offset,
- maxOffset, pflag, llNthreads);
- ACK_PREV;
+ LLprims.recv(thisOutput+offset, nelem);
}
-
- FIFO_CLEANING_AND_SAVE_STEP(nflag);
}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu
index caa1479..aaa96b4 100644
--- a/src/collectives/device/all_reduce.cu
+++ b/src/collectives/device/all_reduce.cu
@@ -4,18 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "common.h"
#include "all_reduce.h"
+#include "common.h"
#include "collectives.h"
-#define UNROLL 4
-
-#if NCCL_OP == 0
-IMPL_COLL2(ncclAllReduce, sum, FuncSum, ncclCollAllReduce, ncclSum);
-#elif NCCL_OP == 1
-IMPL_COLL2(ncclAllReduce, prod, FuncProd, ncclCollAllReduce, ncclProd);
-#elif NCCL_OP == 2
-IMPL_COLL2(ncclAllReduce, min, FuncMin, ncclCollAllReduce, ncclMin);
-#elif NCCL_OP == 3
-IMPL_COLL2(ncclAllReduce, max, FuncMax, ncclCollAllReduce, ncclMax);
-#endif
+IMPL_COLL_R(ncclAllReduce, ncclCollAllReduce);
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index d7abc64..ea89a71 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -8,233 +8,152 @@
#include "primitives.h"
#include "collectives.h"
-// Increase Step and poffset/noffset for buffer sync
-#define NEXT_STEP \
- step++; \
- poffset = noffset; \
- noffset += sliceSize; \
- if (noffset == buffSize) noffset = 0;
-
template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllReduceKernel(struct CollectiveArgs* args) {
+__device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = blockDim.x - 1;
const int bid = args->bid;
- __shared__ T* sharedNextOutput;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- int prevdirect = ring->recv.conn.direct;
- int nextdirect = ring->send.conn.direct;
-
- WaitFlag waitDoneFromNext(ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLREDUCE_SUBSTEPS);
- PostFlag postDoneToPrev(ring->recv.conn.head, ALLREDUCE_SUBSTEPS, NULL, 0);
- PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS);
-
- typedef Primitives<UNROLL, ALLREDUCE_SUBSTEPS, T, FUNC> Prims;
-
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
const ssize_t size = args->N;
- //const int rank = comm->rank;
const int nranks = comm->nRanks;
- const int buffSize = ring->buffSize / sizeof(T);
- const int sliceSize = buffSize / ALLREDUCE_BUFCHUNKS;
- const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
-
- if (tid == 0) {
- // Update in case we skipped some collectives
- *ring->recv.conn.opCount = args->opCount;
- // Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
- if (prevdirect) {
- *ring->recv.conn.ptrExchange = args->ThisOutput;
- }
- if (nextdirect) {
- void* volatile* ptr = &(ring->devMemSend->ptrExchange);
- while (*ptr == nullptr);
- sharedNextOutput = (T*)*ptr;
- *ptr = nullptr;
- }
- }
- __syncthreads();
-
- uint64_t step = 0ULL;
- int poffset, noffset = 0;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- T * __restrict__ prevInput = (T*)ring->recv.conn.buff;
- T * __restrict__ nextOutput = (T*)ring->send.conn.buff;
+
+ ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
- int chunkSize = min(sliceSize, DIVUP(size-gridOffset,nranks*args->nRings));
- ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*nranks*chunkSize;
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*args->nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
/////////////// begin AllReduce steps ///////////////
ssize_t offset;
- int maxOffset;
+ int nelem;
int slice;
// step 0: push data to next GPU
slice = ring->devUserRanks[nranks-1];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
+ offset = chunkOffset + slice * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
- Prims::Copy(tid, nthreads,
- thisInput + offset,
- nextOutput + noffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
-
- NEXT_STEP; // Increases step, poffset, noffset
+ prims.send(thisInput+offset, nelem);
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
slice = ring->devUserRanks[nranks-j];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- Prims::Reduce(tid, nthreads,
- prevInput + poffset,
- thisInput + offset,
- nextOutput + noffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
+ offset = chunkOffset + slice * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
+
+ prims.recvReduceSend(thisInput+offset, nelem);
}
// step k-1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
slice = ring->devUserRanks[0];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
+ offset = chunkOffset + slice * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
- Prims::ReduceCopy(tid, nthreads,
- prevInput + poffset,
- thisInput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- thisOutput + offset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
+ prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
// k-2 steps: copy to next GPU
- if (prevdirect) {
- for (int j=1; j<nranks-1; ++j) {
- slice = ring->devUserRanks[nranks - j];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- Prims::Copy(tid, nthreads,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
- }
- Prims::Copy(tid, nthreads,
- NULL,
- NULL,
- 0, 0,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
- } else {
- for (int j=1; j<nranks-1; ++j) {
- slice = ring->devUserRanks[nranks - j];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- Prims::DoubleCopy(tid, nthreads,
- prevInput + poffset,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
- }
+ for (int j=1; j<nranks-1; ++j) {
+ slice = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + slice * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
- // Make final copy from buffer to dest.
- slice = ring->devUserRanks[1];
- offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- // Here we need to copy from buffer to this output.
- Prims::Copy(tid, nthreads,
- prevInput + poffset,
- thisOutput + offset,
- sliceSize, maxOffset,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
}
- }
- if (tid == 0) {
- // Wait for next to have consumed all data before we reset the flag
- waitDoneFromNext.wait(ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS));
- *ring->send.conn.head = 0ULL;
- *ring->recv.conn.tail = 0ULL;
- __threadfence_system();
- *ring->recv.conn.opCount = args->opCount+1;
+ // Make final copy from buffer to dest.
+ slice = ring->devUserRanks[1];
+ offset = chunkOffset + slice * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
+
+ // Final wait/copy.
+ prims.directRecv(thisOutput+offset, offset, nelem);
}
}
-#include "ll_kernel.h"
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = blockDim.x - 1;
+ const int bid = args->bid;
+ struct ncclComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->tree;
+ const ssize_t size = args->N;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = args->lastChunkSize;
+ const ssize_t loopSize = args->nChannels*chunkSize;
-#define NEXT_STEP_LL \
- poffset = noffset; \
- pflag = nflag; \
- noffset += NCCL_LL_SLICE_LINES; \
- if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \
- nflag++; \
- step++;
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ do {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ } while(0);
+
+ do {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.recv(thisOutput+offset, nelem);
+ } else {
+ prims.recvCopySend(thisOutput+offset, nelem);
+ }
+ }
+ } while(0);
+}
template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
+__device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int bid = args->bid;
- const int llNthreads = args->nThreads;
+ const int nthreads = args->nThreads;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead;
- volatile uint64_t * sendHeadPtr = ring->send.conn.llHead;
- volatile int * sizesFifo = ring->send.conn.llFifo;
- uint64_t sendHead = sendHeadPtr[0];
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
- typedef LLPrimitives<T, FUNC> LL;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
const ssize_t size = args->N;
//const int rank = comm->rank;
const int nranks = comm->nRanks;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nRings*nranks*chunkSize;
-
- uint64_t step = ring->send.conn.llStep;
- uint32_t pflag, nflag = step + 1;
- int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step);
+ const ssize_t loopSize = args->nChannels*nranks*chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff;
- union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
@@ -244,89 +163,99 @@ __device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) {
/////////////// begin AllReduce steps ///////////////
ssize_t offset;
- int maxOffset;
+ int nelem;
int slice;
// step 0: push data to next GPU
slice = ring->devUserRanks[nranks-1];
offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
+ nelem = min(chunkSize, size-offset);
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- nextOutput + noffset,
- maxOffset, nflag, llNthreads);
- POST_SIZE;
-
- NEXT_STEP_LL;
+ LLprims.send(thisInput+offset, nelem);
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
slice = ring->devUserRanks[nranks-j];
offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + poffset,
- nextOutput + noffset,
- maxOffset, pflag, nflag, llNthreads);
- POST_SIZE;
- ACK_PREV;
-
- NEXT_STEP_LL;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.recvReduceSend(thisInput+offset, nelem);
}
// step k-1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
slice = ring->devUserRanks[0];
offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
+ nelem = min(chunkSize, size-offset);
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + poffset,
- thisOutput + offset,
- nextOutput + noffset,
- maxOffset, pflag, nflag, llNthreads);
- POST_SIZE;
- ACK_PREV;
-
- NEXT_STEP_LL;
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
- slice = ring->devUserRanks[nranks - j];
+ slice = ring->devUserRanks[nranks-j];
offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
-
- WAIT_NEXT;
- LL::ReduceCopy(
- prevInput + poffset,
- thisOutput + offset,
- nextOutput + noffset,
- maxOffset, pflag, nflag, llNthreads);
- POST_SIZE;
- ACK_PREV;
-
- NEXT_STEP_LL;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.recvCopySend(thisOutput+offset, nelem);
}
// Make final copy from buffer to dest.
slice = ring->devUserRanks[1];
offset = chunkOffset + slice * chunkSize;
- maxOffset = min(chunkSize, size-offset);
+ nelem = min(chunkSize, size-offset);
// Here we need to copy from buffer to this output.
- LL::ReduceCopy(
- prevInput + poffset,
- thisOutput + offset,
- maxOffset, pflag, llNthreads);
- ACK_PREV;
+ LLprims.recv(thisOutput+offset, nelem);
}
+}
- FIFO_CLEANING_AND_SAVE_STEP(nflag);
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->bid;
+ struct ncclComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->tree;
+ const ssize_t size = args->N;
+ ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ do {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ } while(0);
+
+ do {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+ }
+ } while(0);
}
diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu
index 4125de4..b83ee70 100644
--- a/src/collectives/device/broadcast.cu
+++ b/src/collectives/device/broadcast.cu
@@ -4,12 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "common.h"
#include "broadcast.h"
+#include "common.h"
#include "collectives.h"
-#define UNROLL 4
-
-#if NCCL_OP == 0
-IMPL_COLL3(ncclBroadcast, copy, FuncSum, i8, int8_t, ncclCollBroadcast, ncclSum, ncclInt8);
-#endif
+IMPL_COLL_C(ncclBroadcast, ncclCollBroadcast);
diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h
index c2f6d00..fb18312 100644
--- a/src/collectives/device/broadcast.h
+++ b/src/collectives/device/broadcast.h
@@ -8,174 +8,74 @@
#include "primitives.h"
#include "collectives.h"
-// Increase Step and boffset for buffer sync
-#define NEXT_STEP \
- step++; \
- boffset += sliceSize; \
- if (boffset == buffSize) boffset = 0;
-
template<int UNROLL, class FUNC, typename T>
-__device__ void ncclBroadcastKernel(struct CollectiveArgs* args) {
+__device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = blockDim.x - 1;
const int bid = args->bid;
- __shared__ T* sharedNextOutput;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- int prevdirect = ring->recv.conn.direct;
- int nextdirect = ring->send.conn.direct;
-
- WaitFlag waitDoneFromNext(ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0);
- PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0);
- PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, BROADCAST_BUFCHUNKS*BROADCAST_SUBSTEPS);
-
- typedef Primitives<UNROLL, BROADCAST_SUBSTEPS, T> Prims;
-
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
const ssize_t size = args->N;
- const int buffSize = ring->buffSize / sizeof(T);
- const int sliceSize = buffSize / BROADCAST_BUFCHUNKS;
- const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS;
+ const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
const int rank = ring->devUserRanks[0];
const int nextRank = ring->devUserRanks[1];
const int root = args->root;
- if (tid == 0) {
- // Update in case we skipped some collectives
- *ring->recv.conn.opCount = args->opCount;
- if (nextRank != root) {
- // Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
- }
- if (rank != root && prevdirect) {
- *ring->recv.conn.ptrExchange = args->ThisOutput;
- }
- if (nextRank != root && nextdirect) {
- void* volatile* ptr = &(ring->devMemSend->ptrExchange);
- while (*ptr == nullptr);
- sharedNextOutput = (T*)*ptr;
- *ptr = nullptr;
- }
- }
- __syncthreads();
-
- uint64_t step = 0ULL;
- int boffset = 0;
-
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- T * __restrict__ prevInput = (T*)ring->recv.conn.buff;
- T * __restrict__ nextOutput = (T*)ring->send.conn.buff;
+
+ ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings));
- ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t offset = gridOffset + bid*chunkSize;
- int maxOffset = min(chunkSize, size-offset);
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t offset = gridOffset + bid*realChunkSize;
+ int nelem = min(realChunkSize, size-offset);
if (rank == root) {
if (thisInput == thisOutput) {
- Prims::Copy(tid, nthreads,
- thisInput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
+ prims.send(thisInput+offset, nelem);
} else {
- Prims::DoubleCopy(tid, nthreads,
- thisInput + offset,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
+ prims.copySend(thisInput+offset, thisOutput+offset, nelem);
}
} else if (nextRank == root) {
- if (prevdirect) maxOffset = 0; // Only wait for signals
- Prims::Copy(tid, nthreads,
- prevInput + boffset,
- thisOutput + offset,
- sliceSize, maxOffset,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
+ prims.recv(thisOutput+offset, nelem);
} else {
- if (prevdirect) {
- Prims::Copy(tid, nthreads,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
- } else {
- Prims::DoubleCopy(tid, nthreads,
- prevInput + boffset,
- thisOutput + offset,
- nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset),
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
- }
- }
- NEXT_STEP; // Increases step, boffset
- }
-
- if (tid == 0) {
- if (nextRank != root) {
- // Wait for next to have consumed data before resetting the flag
- waitDoneFromNext.wait(BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1));
- *ring->send.conn.head = 0ULL;
+ prims.recvCopySend(thisOutput+offset, nelem);
}
- *ring->recv.conn.tail = 0ULL;
- __threadfence_system();
- *ring->recv.conn.opCount = args->opCount+1;
}
}
-#include "ll_kernel.h"
-
-#define NEXT_STEP_LL \
- boffset += NCCL_LL_SLICE_LINES; \
- if (boffset == NCCL_LL_BUFF_LINES) boffset = 0; \
- flag++; \
- step++;
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclBroadcastTreeKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) {
+__device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int bid = args->bid;
- const int llNthreads = args->nThreads;
+ const int nthreads = args->nThreads;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead;
- volatile uint64_t * sendHeadPtr = ring->send.conn.llHead;
- volatile int * sizesFifo = ring->send.conn.llFifo;
- uint64_t sendHead = sendHeadPtr[0];
- const int rank = comm->rank;
- const int nextRank = ring->devUserRanks[1];
- const int root = args->root;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
- typedef LLPrimitives<T, FUNC> LL;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
const ssize_t size = args->N;
- ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nRings*chunkSize;
+ const int rank = ring->devUserRanks[0];
+ const int nextRank = ring->devUserRanks[1];
+ const int root = args->root;
- uint64_t step = ring->send.conn.llStep;
- uint32_t flag = step + 1;
- int boffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step);
+ ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = args->nChannels*chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff;
- union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
@@ -183,46 +83,20 @@ __device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) {
}
ssize_t offset = gridOffset + bid*chunkSize;
- int maxOffset = min(chunkSize, size-offset);
+ int nelem = min(chunkSize, size-offset);
if (rank == root) {
- WAIT_NEXT;
if (thisInput == thisOutput) {
- LL::ReduceCopy(
- thisInput + offset,
- nextOutput + boffset,
- maxOffset, flag, llNthreads);
+ LLprims.send(thisInput+offset, nelem);
} else {
- LL::ReduceCopy(
- thisInput + offset,
- thisOutput + offset,
- nextOutput + boffset,
- maxOffset, flag, llNthreads);
+ LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
}
- POST_SIZE;
- NEXT_STEP_LL;
} else if (nextRank == root) {
- LL::ReduceCopy(
- prevInput + boffset,
- thisOutput + offset,
- maxOffset, flag, llNthreads);
- NEXT_STEP_LL;
- ACK_PREV;
+ LLprims.recv(thisOutput + offset, nelem);
} else {
- WAIT_NEXT;
- LL::ReduceCopy(
- prevInput + boffset,
- thisOutput + offset,
- nextOutput + boffset,
- maxOffset, flag, flag, llNthreads);
- POST_SIZE;
- NEXT_STEP_LL;
- ACK_PREV;
+ LLprims.recvCopySend(thisOutput + offset, nelem);
}
}
-
- // We need everyone to acknowledge data even if they didn't receive anything
- // so that the next collective can start right away.
- ACK_PREV;
-
- FIFO_CLEANING_AND_SAVE_STEP(flag);
}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index c988913..e4aecbd 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -11,13 +11,29 @@
#include "core.h"
#include "nccl.h"
+// Exit If Abort Barrier across CTA: make sure all threads exit consistently
+// Each thread sets a predicate to true if abort == 1
+// all CTA's threads enter the barrier and do a popc on their predicates being True
+// If any of the thread's predicate was True, all the threads call exit()
+static inline __device__ void exitIfAbortBarrier(int abort) {
+ uint32_t popc;
+ asm ("{");
+ asm volatile (" .reg .pred barr_pred;");
+ asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
+ asm volatile (" bar.red.popc.u32 %0, 13, barr_pred;" : "=r"(popc));
+ asm ("}");
+ if (popc) { asm volatile ("exit;"); }
+}
+
typedef void(*ncclKern_t)(struct CollectiveArgs* args);
extern __device__ ncclKern_t ncclFuncs[];
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
int* d = (int*)dst;
int* s = (int*)src;
- __syncthreads();
+ // When aggregation is effective, if some threads have aborted inside the LL kernel,
+ // make sure the rest of the threads abort as well
+ exitIfAbortBarrier(0);
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
__syncthreads();
}
@@ -27,12 +43,14 @@ static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* ho
}
/* Functions for aggregation case */
-#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype) \
+#define IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
__device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
- coll##Kernel<UNROLL, ncclFunc<ctype>, ctype>(args); \
+ coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(args); \
}
+
+#if NCCL_OP == 0
/* Kernels with the first operation inlined */
-#define IMPL_COLL4K(coll, op, ncclFunc, dtype, ctype, fIndex) \
+#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \
__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
int tid = threadIdx.x; \
@@ -40,25 +58,25 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
__shared__ struct ncclColl localColl; \
\
struct ncclComm* comm = firstColl.args.comm; \
- struct ncclRing* ring = comm->rings+bid; \
+ struct ncclChannel* channel = comm->channels+bid; \
struct ncclColl* c; \
if (bid == 0) { \
/* To optimize for latency, (only) the first operation is passed as argument.*/ \
c = &firstColl; \
} else { \
c = &localColl; \
- load_coll(c, ring->devCollectives+ring->collFifoHead, tid); \
+ load_coll(c, channel->devCollectives+channel->collFifoHead, tid); \
} \
while (1) { \
- if (tid < c->nThreads) { \
+ if (tid < c->args.nThreads) { \
if (c->funcIndex == fIndex) { \
- coll##Kernel<UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
+ coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
} else { \
ncclFuncs[c->funcIndex](&c->args); \
} \
} \
int nextIndex = c->nextIndex; \
- if (tid == 0) ring->collFifoHead = nextIndex; \
+ if (tid == 0) channel->collFifoHead = nextIndex; \
\
if (c->active == 2) { \
return; \
@@ -66,25 +84,75 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
\
/* Load next collective operation*/ \
c = &localColl; /* for bid 0 */ \
- load_coll(c, ring->devCollectives+nextIndex, tid); \
+ load_coll(c, channel->devCollectives+nextIndex, tid); \
} \
}
+#else
+#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
+#endif
+
+// Only generate inline kernels for LL
+#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
+ IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
+ IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
+ IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
- IMPL_COLL4(coll##LL, op, ncclFunc, dtype, ctype) \
- IMPL_COLL4K(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1)) \
- IMPL_COLL4(coll, op, ncclFunc, dtype, ctype) \
- IMPL_COLL4K(coll, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 0)) \
+ IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \
+ IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1)
+#if NCCL_TYPE == 0
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8)
+#elif NCCL_TYPE == 1
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, u8, uint8_t, ncclColl, ncclOp, ncclUint8)
+#elif NCCL_TYPE == 2
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, i32, int32_t, ncclColl, ncclOp, ncclInt32)
+#elif NCCL_TYPE == 3
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, u32, uint32_t, ncclColl, ncclOp, ncclUint32)
+#elif NCCL_TYPE == 4
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, i64, int64_t, ncclColl, ncclOp, ncclInt64)
+#elif NCCL_TYPE == 5
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64)
+#elif NCCL_TYPE == 6
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16)
+#elif NCCL_TYPE == 7
+#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
+ IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32)
+#elif NCCL_TYPE == 8
#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8) \
- IMPL_COLL3(coll, op, ncclFunc, u8, uint8_t, ncclColl, ncclOp, ncclUint8) \
- IMPL_COLL3(coll, op, ncclFunc, i32, int32_t, ncclColl, ncclOp, ncclInt32) \
- IMPL_COLL3(coll, op, ncclFunc, u32, uint32_t, ncclColl, ncclOp, ncclUint32) \
- IMPL_COLL3(coll, op, ncclFunc, i64, int64_t, ncclColl, ncclOp, ncclInt64) \
- IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64) \
- IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16) \
- IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32) \
IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64)
+#endif
+
+// Reduction define all functions
+#if NCCL_OP == 0
+#define IMPL_COLL_R(collf, colln) \
+ IMPL_COLL2(collf, sum, FuncSum, colln, ncclSum);
+#elif NCCL_OP == 1
+#define IMPL_COLL_R(collf, colln) \
+ IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
+#elif NCCL_OP == 2
+#define IMPL_COLL_R(collf, colln) \
+ IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
+#elif NCCL_OP == 3
+#define IMPL_COLL_R(collf, colln) \
+ IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
+#endif
+
+// Copy primitives only define one
+#if NCCL_OP == 0 && NCCL_TYPE == 0
+#define IMPL_COLL_C(collf, colln) \
+ IMPL_COLL3(collf, copy, FuncSum, i8, int8_t, colln, ncclSum, ncclInt8);
+#else
+#define IMPL_COLL_C(collf, colln)
+#endif
+
+#define COLL_UNROLL 4
#endif
diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h
index 0eaa061..e1fb096 100644
--- a/src/collectives/device/common_kernel.h
+++ b/src/collectives/device/common_kernel.h
@@ -192,14 +192,6 @@ struct MULTI<FUNC, int64_t> {
}
};
-#define ALIGNUP(x, a) ((((x)-1) & ~((a)-1)) + (a))
-
-template<typename T>
-__device__ inline volatile T* AlignUp(volatile T * ptr, size_t align) {
- size_t ptrval = reinterpret_cast<size_t>(ptr);
- return reinterpret_cast<volatile T*>(ALIGNUP(ptrval, align));
-}
-
template<typename T> inline __device__
T vFetch(const volatile T* ptr) {
return *ptr;
@@ -236,25 +228,6 @@ void vStore<half>(volatile half* ptr, const half val) {
}
#endif
-template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS>
-__device__ inline void ReduceCopy(
- const int tid, const int nthreads,
- const volatile T * __restrict__ const src0,
- const volatile T * __restrict__ const src1,
- volatile T * __restrict__ const dest0,
- volatile T * __restrict__ const dest1, const int N) {
- for (int idx = tid; idx < N; idx += nthreads) {
- T val = vFetch(src0+idx);
- if (TWO_INPUTS) {
- val = FUNC()(val, vFetch(src1+idx));
- }
- vStore(dest0+idx, val);
- if (TWO_OUTPUTS) {
- vStore(dest1+idx, val);
- }
- }
-}
-
typedef ulong2 Pack128;
template<class FUNC, typename T>
@@ -265,72 +238,111 @@ struct MULTI128 {
}
};
-inline __device__ void Fetch128(Pack128& v, Pack128* p) {
+inline __device__ void Fetch128(Pack128& v, const Pack128* p) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}
inline __device__ void Store128(Pack128* p, Pack128& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
+template<class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthreads,
+ int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
+ const int offset, const int N) {
+ for (int idx = offset+tid; idx < offset+N; idx += nthreads) {
+ T val = vFetch(srcs[0]+idx);
+ #pragma unroll
+ for (int i=1; i<MINSRCS; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
+ #pragma unroll 1
+ for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
+
+ #pragma unroll
+ for (int i=0; i<MINDSTS; i++) vStore(dsts[i]+idx, val);
+ #pragma unroll 1
+ for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) vStore(dsts[i]+idx, val);
+ }
+}
+
#define WARP_SIZE 32
-template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS, int UNROLL>
-__device__ inline void ReduceCopy128b( const int w, const int nw, const int t,
- Pack128 * src0, Pack128 * src1, Pack128 * dest0, Pack128 * dest1,
- const int N) {
- Pack128 t0[UNROLL];
- Pack128 t1[UNROLL];
- const Pack128* src0_end = src0 + N;
+
+template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
+ int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
+ const int elemOffset, const int Npack) {
const int inc = nw * UNROLL * WARP_SIZE;
- const int offset = w * UNROLL * WARP_SIZE + t;
- src0 += offset; if (TWO_INPUTS) src1 += offset;
- dest0 += offset; if (TWO_OUTPUTS) dest1 += offset;
-
- while (src0 < src0_end) {
-#pragma unroll
- for (int u = 0; u < UNROLL; ++u) {
- Fetch128(t0[u], src0+u*WARP_SIZE);
- if (TWO_INPUTS) Fetch128(t1[u], src1+u*WARP_SIZE);
+ int offset = w * UNROLL * WARP_SIZE + t;
+
+ const Pack128* srcs[MAXSRCS];
+ for (int i=0; i<MAXSRCS; i++) srcs[i] = ((const Pack128*)(s[i]+elemOffset))+offset;
+ Pack128* dsts[MAXDSTS];
+ for (int i=0; i<MAXDSTS; i++) dsts[i] = ((Pack128*)(d[i]+elemOffset))+offset;
+
+ while (offset < Npack) {
+ Pack128 vals[UNROLL];
+ // Load and reduce
+ for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
+
+ for (int i=1; i<MINSRCS; i++) {
+ Pack128 vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
-#pragma unroll
- for (int u = 0; u < UNROLL; ++u) {
- if (TWO_INPUTS) MULTI128<FUNC, T>()(t0[u], t1[u]);
- Store128(dest0+u*WARP_SIZE, t0[u]);
- if (TWO_OUTPUTS) Store128(dest1+u*WARP_SIZE, t0[u]);
+ #pragma unroll 1
+ for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
+ Pack128 vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
- src0 += inc; if (TWO_INPUTS) src1 += inc;
- dest0 += inc; if (TWO_OUTPUTS) dest1 += inc;
+
+ // Store
+ for (int i = 0; i < MINDSTS; i++) {
+ for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ #pragma unroll 1
+ for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) {
+ for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
+ for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
+ offset += inc;
}
}
-template<int UNROLL, class FUNC, typename T, bool HAS_DEST1, bool HAS_SRC1>
-__device__ inline void ReduceOrCopy(const int tid, const int nthreads,
- volatile T * __restrict__ dest0, volatile T * __restrict__ dest1,
- const volatile T * __restrict__ src0, const volatile T * __restrict__ src1,
+template <typename T>
+__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); }
+
+// Try to limit consecutive load/stores to 8.
+// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
+#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
+
+template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
+ int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
int N) {
int Nrem = N;
if (Nrem <= 0) return;
- int Npreamble = (Nrem<alignof(Pack128)) ? Nrem : AlignUp(dest0, alignof(Pack128)) - dest0;
+ int alignDiff = 0;
+ int align = ptrAlign128(srcs[0]);
+ #pragma unroll
+ for (int i=1; i<MINSRCS; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
+ for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
+ #pragma unroll
+ for (int i=0; i<MINDSTS; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
+ for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
- // stage 0: check if we'll be able to use the fast, 128-bit aligned path.
- // If not, we'll just use the slow preamble path for the whole operation
- bool alignable = (((AlignUp(src0, alignof(Pack128)) == src0 + Npreamble)) &&
- (!HAS_DEST1 || (AlignUp(dest1, alignof(Pack128)) == dest1 + Npreamble)) &&
- (!HAS_SRC1 || (AlignUp(src1, alignof(Pack128)) == src1 + Npreamble)));
-
- if (!alignable) {
- Npreamble = Nrem;
- }
+ int Npreamble = alignDiff ? Nrem :
+ N < alignof(Pack128) ? N :
+ (alignof(Pack128) - align) % alignof(Pack128);
// stage 1: preamble: handle any elements up to the point of everything coming
// into alignment
- ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(tid, nthreads, src0, src1, dest0, dest1, Npreamble);
-
- Nrem -= Npreamble;
- if (Nrem == 0) return;
-
- dest0 += Npreamble; if (HAS_DEST1) { dest1 += Npreamble; }
- src0 += Npreamble; if (HAS_SRC1) { src1 += Npreamble; }
+ if (Npreamble) {
+ ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, 0, Npreamble);
+ Nrem -= Npreamble;
+ if (Nrem == 0) return;
+ }
+ int offset = Npreamble;
// stage 2: fast path: use 128b loads/stores to do the bulk of the work,
// assuming the pointers we have are all 128-bit alignable.
@@ -338,35 +350,33 @@ __device__ inline void ReduceOrCopy(const int tid, const int nthreads,
int nw = nthreads / WARP_SIZE; // Number of warps
int t = tid % WARP_SIZE; // Thread (inside the warp)
- const int PackFactor = sizeof(Pack128) / sizeof(T);
+ const int packFactor = sizeof(Pack128) / sizeof(T);
// stage 2a: main loop
- int Nalign2a = (Nrem / (PackFactor * UNROLL * nthreads))
- * (UNROLL * nthreads); // round down
+ int Npack2a = (Nrem / (packFactor * AUTOUNROLL * WARP_SIZE))
+ * (AUTOUNROLL * WARP_SIZE); // round down
+ int Nelem2a = Npack2a * packFactor;
- ReduceCopy128b<FUNC, T, HAS_SRC1, HAS_DEST1, UNROLL>(w, nw, t, (Pack128*)src0, (Pack128*)src1, (Pack128*)dest0, (Pack128*)dest1, Nalign2a);
+ ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2a);
- int Ndone2a = Nalign2a * PackFactor;
- Nrem -= Ndone2a;
+ Nrem -= Nelem2a;
if (Nrem == 0) return;
- dest0 += Ndone2a; if (HAS_DEST1) { dest1 += Ndone2a; }
- src0 += Ndone2a; if (HAS_SRC1) { src1 += Ndone2a; }
+ offset += Nelem2a;
// stage 2b: slightly less optimized for section when we don't have full
- // UNROLLs
+ // unrolling
- int Nalign2b = Nrem / PackFactor;
+ int Npack2b = Nrem / packFactor;
+ int Nelem2b = Npack2b * packFactor;
- ReduceCopy128b<FUNC, T, HAS_SRC1, HAS_DEST1, 1>(w, nw, t, (Pack128*)src0, (Pack128*)src1, (Pack128*)dest0, (Pack128*)dest1, Nalign2b);
+ ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2b);
- int Ndone2b = Nalign2b * PackFactor;
- Nrem -= Ndone2b;
+ Nrem -= Nelem2b;
if (Nrem == 0) return;
- dest0 += Ndone2b; if (HAS_DEST1) { dest1 += Ndone2b; }
- src0 += Ndone2b; if (HAS_SRC1) { src1 += Ndone2b; }
+ offset += Nelem2b;
// stage 2c: tail
- ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(tid, nthreads, src0, src1, dest0, dest1, Nrem);
+ ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, offset, Nrem);
}
#endif // COMMON_KERNEL_H_
diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu
index 1fb8108..ea06b68 100644
--- a/src/collectives/device/functions.cu
+++ b/src/collectives/device/functions.cu
@@ -8,9 +8,13 @@
#include "collectives.h"
#include "common.h"
-#define NCCL_FUNC4(coll, op, dtype) \
+#define NCCL_FUNC5(coll, op, dtype) \
NCCL_COLL_NAME(coll, op, dtype), \
- NCCL_COLL_NAME(coll##LL, op, dtype) \
+ NCCL_COLL_NAME(coll##LL, op, dtype)
+
+#define NCCL_FUNC4(coll, op, dtype) \
+ NCCL_FUNC5(coll##Ring, op, dtype), \
+ NCCL_FUNC5(coll##Tree, op, dtype)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(coll, op) \
@@ -55,7 +59,7 @@
NCCL_FUNCS2A(ncclAllReduce) }
// Must be consistent with the ncclFuncSet enum
-__device__ ncclKern_t ncclFuncs[ncclCollCount*ncclNumOps*ncclNumTypes*2] = {
+__device__ ncclKern_t ncclFuncs[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = {
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
diff --git a/src/collectives/device/gen_rules.sh b/src/collectives/device/gen_rules.sh
new file mode 100755
index 0000000..3942c8c
--- /dev/null
+++ b/src/collectives/device/gen_rules.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# See LICENSE.txt for license information
+#
+
+dir=$1
+
+targets="GENOBJS := \\\\\n"
+
+for base in all_reduce all_gather broadcast reduce reduce_scatter; do
+ opn=0
+ for op in sum prod min max; do
+ dtn=0
+ for dt in i8 u8 i32 u32 i64 u64 f16 f32 f64; do
+ echo "${dir}/${base}_${op}_${dt}.o : ${base}.cu ${dir}/${base}.dep"
+ echo " @printf \"Compiling %-35s > %s\\\\n\" ${base}.cu ${dir}/${base}_${op}_${dt}.o"
+ echo " mkdir -p ${dir}"
+ echo " \${NVCC} -DNCCL_OP=${opn} -DNCCL_TYPE=${dtn} \${NVCUFLAGS} -dc ${base}.cu -o ${dir}/${base}_${op}_${dt}.o"
+ echo ""
+ targets="$targets\t${dir}/${base}_${op}_${dt}.o \\\\\n"
+ dtn=$(($dtn + 1))
+ done
+ opn=$(($opn + 1))
+ done
+done
+echo -e "$targets"
diff --git a/src/collectives/device/ll_kernel.h b/src/collectives/device/ll_kernel.h
deleted file mode 100644
index 5ec3c9a..0000000
--- a/src/collectives/device/ll_kernel.h
+++ /dev/null
@@ -1,154 +0,0 @@
-/*************************************************************************
- * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
- *
- * See LICENSE.txt for license information
- ************************************************************************/
-
-#ifndef NCCL_LL_KERNEL_H_
-#define NCCL_LL_KERNEL_H_
-
-static __device__ uint64_t readLL(union ncclLLFifoLine* src, uint32_t flag) {
- uint32_t data1, flag1, data2, flag2;
- do {
- asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
- } while ((flag1 != flag) || (flag2 != flag));
- uint64_t val64 = data1 + (((uint64_t)data2) << 32);
- return val64;
-}
-
-static __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
- asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
-}
-
-// Using memcpy handles misaligned pointers.
-static __device__ uint64_t readAL(uint64_t* src) {
- uint64_t val;
- memcpy((char*)&val, (char*)src, sizeof(uint64_t));
- return val;
-}
-static __device__ void storeAL(uint64_t* dst, uint64_t val) {
- memcpy((char*)dst, (char*)&val, sizeof(uint64_t));
-}
-
-template <typename T, class FUNC>
-class LLPrimitives {
- private:
- template <int HAS_SRC1, int HAS_SRC2, int HAS_DST1, int HAS_DST2>
- static __device__ void ReduceCopyGeneric(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- if (size <= 0) return;
- size_t size64 = size * sizeof(T) / sizeof(uint64_t);
- uint64_t* src1A = (uint64_t*)src1;
- uint64_t* dst1A = (uint64_t*)dst1;
- int offset = threadIdx.x;
- // Do multiples of 64 bits
-#pragma unroll 1
- for (; offset < size64; offset += nthreads) {
- uint64_t val;
- if (HAS_SRC1) {
- val = readAL(src1A+offset);
- if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(src2+offset, iflag), val);
- } else if (HAS_SRC2) {
- val = readLL(src2+offset, iflag);
- }
- if (HAS_DST1) storeAL(dst1A+offset, val);
- if (HAS_DST2) storeLL(dst2+offset, val, oflag);
- }
- // Finish last word
- int sizeDone = size64*(sizeof(uint64_t)/sizeof(T));
- int sizeRem = size - sizeDone;
- if (threadIdx.x == 0 && sizeRem) {
- const T* src1B = src1 + sizeDone;
- T* dst1B = dst1 + sizeDone;
-
- uint64_t lastVal;
- T* vals = (T*)&lastVal;
-
- if (HAS_SRC2) {
- uint64_t lastVal2 = readLL(src2+size64, iflag);
- T* src2B = (T*)&lastVal2;
- for (int offset = 0; offset < sizeRem; offset++) {
- vals[offset] = HAS_SRC1 ? FUNC()(src2B[offset], src1B[offset]) : src2B[offset];
- }
- } else if (HAS_SRC1) {
- for (int offset = 0; offset < sizeRem; offset++) {
- vals[offset] = src1B[offset];
- }
- }
- if (HAS_DST2) storeLL(dst2+size64, lastVal, oflag);
- if (HAS_DST1) {
- for (int offset = 0; offset < sizeRem; offset++) {
- dst1B[offset] = vals[offset];
- }
- }
- }
- }
- public:
- static __device__ void ReduceCopy(const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 0, 0, 1>(src, NULL, NULL, dst, size, 0, oflag, nthreads);
- }
-
- static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) {
- return ReduceCopyGeneric<0, 1, 1, 0>(NULL, src, dst, NULL, size, iflag, 0, nthreads);
- }
-
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 0, 1>(src1, src2, NULL, dst, size, iflag, oflag, nthreads);
- }
-
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 1, 0>(src1, src2, dst, NULL, size, iflag, 0, nthreads);
- }
-
- static __device__ void ReduceCopy(const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 0, 1, 1>(src, NULL, dst1, dst2, size, 0, oflag, nthreads);
- }
-
- static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<0, 1, 1, 1>(NULL, src, dst1, dst2, size, iflag, oflag, nthreads);
- }
-
- static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) {
- return ReduceCopyGeneric<1, 1, 1, 1>(src1, src2, dst1, dst2, size, iflag, oflag, nthreads);
- }
-};
-
-// Common macros
-
-#define STEP_TO_SLOT(step) \
- (step % NCCL_LL_CHUNKS)
-
-#define WAIT_NEXT \
- if (tid == 0) { \
- while (sendHead + NCCL_LL_CHUNKS <= step) { \
- sendHead = sendHeadPtr[0]; \
- } \
- } \
- asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads));
-
-#define POST_SIZE \
- if (tid == 0 && sizesFifo) sizesFifo[step % NCCL_LL_CHUNKS] = (maxOffset <= 0) ? -1 : (maxOffset*2*(int)sizeof(T));
-
-#define ACK_PREV \
- asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads)); \
- if (tid == 0) recvHeadPtr[0] = step;
-
-#define FIFO_CLEANING_AND_SAVE_STEP(flag) do { \
- if (step > ring->send.conn.llLastCleaning + NCCL_LL_CLEAN_FREQ) { \
- /* Reset all flags */ \
- static_assert((NCCL_LL_BUFF_SIZE % NCCL_LL_MAX_NTHREADS) == 0, "NCCL_LL_BUFF_SIZE must be a multiple of THREADS"); \
- static_assert(NCCL_LL_BUFF_SIZE/(sizeof(union ncclLLFifoLine)*NCCL_LL_MAX_NTHREADS) > 0, "NCCL_LL_BUFF_SIZE is less than 16 bytes*THREADS"); \
- const union ncclLLFifoLine resetLine = { 0, flag, 0, flag }; \
- for (int i=0; i<NCCL_LL_BUFF_SIZE/(sizeof(union ncclLLFifoLine)*llNthreads); i++) { \
- prevInput[tid+i*llNthreads].i4 = resetLine.i4; \
- } \
- __threadfence_system(); \
- /* Restart from the same slot, only make sure sender waits for data to be reset */ \
- step += NCCL_LL_CHUNKS; \
- ACK_PREV; \
- while (sendHeadPtr[0] < step); \
- if (tid == 0) ring->send.conn.llLastCleaning = step; \
- } \
- ring->send.conn.llStep = step; \
-} while (0);
-
-#endif
diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h
index e2baa4b..c5aaf54 100644
--- a/src/collectives/device/primitives.h
+++ b/src/collectives/device/primitives.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -9,218 +9,579 @@
#include <type_traits>
#include "reduce_kernel.h" // for reduction funcs
+#include "common.h"
+
+#define SPINS_BEFORE_CHECK_ABORT 1000000
+
+// Unroll unconditionally the first send/recv since nsend/nrecv should be at
+// least 1 if SEND/RECV is set.
+#define FOR_SEND(func, ...) do { \
+ if (SEND) { \
+ /* Send to far first, then close */ \
+ for (int i=1; i<NSEND && i<nsend; i++) func(i, ##__VA_ARGS__); \
+ func(0, ##__VA_ARGS__); \
+ } \
+} while (0)
+
+#define FOR_RECV(func, ...) do { \
+ if (RECV) { \
+ /* Recv from close first, then far */ \
+ func(0, ##__VA_ARGS__); \
+ for (int i=1; i<NRECV && i<nrecv; i++) func(i, ##__VA_ARGS__); \
+ } \
+} while (0)
+// Implementation of primitive types
+template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, class FUNC>
+class ncclPrimitives {
+ private:
+ const int tid;
+ const int nthreads;
+ int nrecv = 0;
+ int nsend = 0;
+ const int stepSize;
+ struct ncclConnInfo* recvConn[NRECV];
+ struct ncclConnInfo* sendConn[NSEND];
+ volatile uint64_t* waitPtr;
+ uint64_t recvStep[NRECV];
+ uint64_t sendStep[NSEND];
+ uint64_t sendConnHead[NSEND];
+ const T* recvDirectBuff[NRECV];
+ T* sendDirectBuff[NSEND];
+ const T* recvBuff[NRECV];
+ T* sendBuff[NSEND];
+ struct ncclComm* comm;
+
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
+ inline __device__ const T* recvPtr(int i) { return ((const T*)recvBuff[i])+recvOffset(i); }
+ inline __device__ T* sendPtr(int i) { return ((T*)sendBuff[i])+sendOffset(i); }
+
+ inline __device__ void barrier() {
+ asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
+ }
-/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
- *
- * In order to reduce the reptetion of template arguments, the operations
- * are bundled as static methods of the Primitives class.
- *
- * Each primitive operation copies/reduces a contiguous buffer and syncs
- * an optional set of flags against a sub-step counter. The sync value is
- * based on the step parameter. Sync flags must be of type WaitFlag or
- * PostFlag. The primitive routines wait for all WaitFlag args to attain
- * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
- * corresponding substep by previous step) before executing the transfer.
- * After each substep is transfered, all PostFlag arguments get updated to
- * the value SUBSTEPS*step+substep+1.
- */
-
-
-class WaitFlag {
- volatile uint64_t * const flag;
- const int shift;
- public:
- __device__ __forceinline__
- WaitFlag(volatile uint64_t * const flag, const int shift) : flag(flag), shift(shift) { }
- __device__ __forceinline__
- void wait(uint64_t val) { while ((*flag + shift) < val) /*SPIN*/; }
-};
+ uint32_t mismatch = 0;
+ const uint64_t opCount;
+ inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) {
+ if (mismatch) {
+ // In non-LL, we use _threadfence_system before incrementing opCount, yet we are still waiting for credits here, so there must be a size mismatch
+ *(comm->fatalDevError) = ncclDevAssertedMismatch;
+ } else if (remoteOpCount && *remoteOpCount > opCount) {
+ mismatch += 1;
+ }
+ }
+
+ uint32_t spins = 0;
+ uint32_t abort = 0;
+
+ inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) {
+ spins++;
+ if (spins == SPINS_BEFORE_CHECK_ABORT) {
+ abort = *(comm->abortFlag);
+ checkMismatch(remoteOpCount);
+ spins = 0;
+ }
+ return abort;
+ }
+
+ inline __device__ void waitRecv(int i) {
+ spins = 0;
+ mismatch = 0;
+ recvStep[i] += SLICESTEPS;
+ if (tid == i) {
+ while (*(waitPtr) < recvStep[i]) {
+ if (checkAbort(recvConn[i]->opCountRem)) break;
+ }
+ }
+ }
+
+ inline __device__ void waitSend(int i) {
+ spins = 0;
+ mismatch = 0;
+ sendStep[i] += SLICESTEPS;
+ if (tid == WARP_SIZE+i) {
+ while (sendConnHead[i] + NCCL_STEPS < sendStep[i]) {
+ sendConnHead[i] = *waitPtr;
+ if (checkAbort(sendConn[i]->opCountRem)) break;
+ }
+ }
+ }
+
+ inline __device__ void postRecv(int i) {
+ *(recvConn[i]->head) = recvStep[i] += SLICESTEPS;
+ }
+
+ inline __device__ void postSend(int i) {
+ *(sendConn[i]->tail) = sendStep[i] += SLICESTEPS;
+ }
+
+ inline __device__ void postSendSize(int i, int size) {
+ if (sendConn[i]->fifo) sendConn[i]->fifo[sendStep[i]%NCCL_STEPS] = size;
+ }
+
+ template <int DIRECTRECV>
+ inline __device__ const T* directRecvPtr(int i, int directOffset) {
+ return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i);
+ }
+
+ template <int DIRECTSEND>
+ inline __device__ T* directSendPtr(int i, int directOffset) {
+ return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i);
+ }
+
+ template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
+ inline __device__ void
+ GenericOp(const T* srcPtr, T* dstPtr, int nelem, int directOffset) {
+ int offset = 0;
+ int sliceSize = stepSize * SLICESTEPS;
+
+ const T* srcs[RECV*NRECV+SRC];
+ srcs[0] = SRC ? srcPtr : directRecvPtr<DIRECTRECV>(0, directOffset);
+ if (RECV) {
+ if (SRC) srcs[1] = recvPtr(0);
+ for (int i=1; i<NRECV && i<nrecv; i++) srcs[SRC+i] = recvPtr(i);
+ }
+
+ T* dsts[SEND*NSEND+DST];
+ dsts[0] = DST ? dstPtr : directSendPtr<DIRECTSEND>(0, directOffset);
+ if (SEND) {
+ if (DST) dsts[1] = directSendPtr<DIRECTSEND>(0, directOffset);
+ for (int i=1; i<NSEND && i<nsend; i++) dsts[DST+i] = directSendPtr<DIRECTSEND>(i, directOffset);
+ }
+
+ #pragma unroll 1
+ for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
+ int realSize = max(0, min(sliceSize, nelem-offset));
+ if (tid < nthreads) {
+ FOR_SEND(waitSend);
+ FOR_RECV(waitRecv);
+ if (realSize > 0) {
+ barrier();
+ if (DIRECTRECV && recvDirectBuff[0]) {
+ // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
+ if (SEND) {
+ ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize);
+ }
+ } else {
+ ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
+ }
+ }
+ exitIfAbortBarrier(abort);
+ } else {
+ exitIfAbortBarrier(abort);
+ FOR_SEND(postSendSize, realSize*sizeof(T));
+ if (SEND) __threadfence_system();
+ FOR_SEND(postSend);
+ FOR_RECV(postRecv);
+ }
+ for (int i=0; i<RECV*NRECV+SRC; i++) srcs[i] += sliceSize;
+ for (int i=0; i<SEND*NSEND+DST; i++) dsts[i] += sliceSize;
+ offset += sliceSize;
+ }
+ }
+
+ __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
+ recvConn[i] = conn;
+ recvBuff[i] = (const T*)recvConn[i]->buff;
+ recvStep[i] = recvConn[i]->step;
+ recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
+ // Return credits in case we rounded up.
+ if (tid == nthreads) *recvConn[i]->head = recvStep[i];
+ if (tid == i) {
+ waitPtr = recvConn[i]->tail;
+ *(recvConn[i]->opCountLoc) = opCount;
+ }
+ recvDirectBuff[i] = NULL;
+ if (directBuff && recvConn[i]->direct) {
+ recvDirectBuff[i] = directBuff;
+ if (tid == 0) *recvConn[i]->ptrExchange = directBuff;
+ }
+ nrecv++;
+ }
+
+ __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i, T* directBuff) {
+ sendConn[i] = conn;
+ sendBuff[i] = (T*)sendConn[i]->buff;
+ sendStep[i] = sendConn[i]->step;
+ sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
+ if (tid == WARP_SIZE+i) {
+ waitPtr = sendConn[i]->head;
+ sendConnHead[i] = *waitPtr;
+ *(sendConn[i]->opCountLoc) = opCount;
+ }
+ sendDirectBuff[i] = NULL;
+ if (directBuff && sendConn[i]->direct) {
+ void* volatile* ptr = sendConn[i]->ptrExchange;
+ while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
+ __syncthreads();
+ if (tid == 0) *ptr = NULL;
+ }
+ nsend++;
+ }
+
+ __device__ __forceinline__ void saveRecvConn(int i) {
+ if (tid == i) {
+ recvConn[i]->step = recvStep[i];
+ __threadfence_system();
+ *(recvConn[i]->opCountLoc) += 1;
+ }
+ }
+
+ __device__ __forceinline__ void saveSendConn(int i) {
+ if (tid == WARP_SIZE+i) {
+ sendConn[i]->step = sendStep[i];
+ __threadfence_system();
+ *(sendConn[i]->opCountLoc) += 1;
+ }
+ }
-class PostFlag {
- volatile uint64_t * const flag;
- const int shift;
- volatile int * const fifo;
- const int fifo_size;
public:
__device__ __forceinline__
- PostFlag(volatile uint64_t* const flag, const int shift, volatile int* const fifo, const int fifo_size) : flag(flag), shift(shift), fifo(fifo), fifo_size(fifo_size) { }
- __device__ __forceinline__
- void post(uint64_t val) { *flag = (val - shift); }
- __device__ __forceinline__
- void postSize(uint64_t step, int size) { if (fifo != NULL) fifo[step%fifo_size] = size; };
-};
+ ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclComm* comm, const uint64_t opCount)
+ : comm(comm), tid(tid), nthreads(nthreads), stepSize(stepSize), opCount(opCount) {
+ // Make sure step is updated before we read it
+ __syncthreads();
+ for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff);
+ for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i, directBuff);
+ }
-// Helper to check if any argument is of type T.
-// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
-template<typename T> __device__ __forceinline__
-bool AnyAre() { return false; }
+ __device__ __forceinline__ void
+ send(const T* src, int nelem) {
+ GenericOp<0, 0, 0, 1, 1, 0>(src, NULL, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ directSend(const T* src, int directOffset, int nelem) {
+ GenericOp<0, 1, 0, 1, 1, 0>(src, NULL, nelem, directOffset);
+ }
-template<typename T, typename FIRST_T, typename... TAIL_Ts>
-__device__ __forceinline__
-bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
- return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
-}
+ __device__ __forceinline__ void
+ recv(T* dst, int nelem) {
+ GenericOp<0, 0, 1, 0, 0, 1>(NULL, dst, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ directRecv(T* dst, int directOffset, int nelem) {
+ GenericOp<1, 0, 1, 0, 0, 1>(NULL, dst, nelem, directOffset);
+ }
+ __device__ __forceinline__ void
+ copySend(const T* src, T* dst, int nelem) {
+ GenericOp<0, 0, 0, 1, 1, 1>(src, dst, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ directCopySend(const T* src, T* dst, int directOffset, int nelem) {
+ GenericOp<0, 1, 0, 1, 1, 1>(src, dst, nelem, directOffset);
+ }
-// Wait on all WaitFlags, ignore PostFlags
-__device__ __forceinline__
-void WaitOnFlags(uint64_t val) { }
+ __device__ __forceinline__ void
+ recvCopySend(T* dst, int nelem) {
+ GenericOp<0, 0, 1, 1, 0, 1>(NULL, dst, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ directRecvCopySend(T* dst, int directOffset, int nelem) {
+ GenericOp<1, 1, 1, 1, 0, 1>(NULL, dst, nelem, directOffset);
+ }
-template <typename... TAIL_Ts> __device__ __forceinline__
-void WaitOnFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
- flag.wait(val);
- WaitOnFlags(val, tail...);
-}
+ __device__ __forceinline__ void
+ recvReduceCopy(const T* src, T* dst, int nelem) {
+ GenericOp<0, 0, 1, 0, 1, 1>(src, dst, nelem, 0);
+ }
-template <typename... TAIL_Ts> __device__ __forceinline__
-void WaitOnFlags(uint64_t val, PostFlag, TAIL_Ts... tail) {
- WaitOnFlags(val, tail...);
-}
+ __device__ __forceinline__ void
+ recvReduceSend(const T* src, int nelem) {
+ GenericOp<0, 0, 1, 1, 1, 0>(src, NULL, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ recvReduceCopySend(const T* src, T* dst, int nelem) {
+ GenericOp<0, 0, 1, 1, 1, 1>(src, dst, nelem, 0);
+ }
+ __device__ __forceinline__ void
+ directRecvReduceCopySend(const T* src, T* dst, int directOffset, int nelem) {
+ // Direct is only for the send part
+ GenericOp<0, 1, 1, 1, 1, 1>(src, dst, nelem, directOffset);
+ }
-// Post all PostFlags, ignore WaitFlags
-__device__ __forceinline__
-void PostToFlags(uint64_t val) { }
+ __device__ __forceinline__ ~ncclPrimitives() {
+ // Save steps for next collective. Have thread 0 do it to be compatible
+ // with the way LL works.
+ for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i);
+ for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i);
+ }
+};
-template <typename... TAIL_Ts> __device__ __forceinline__
-void PostToFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
- PostToFlags(val, tail...);
-}
+template <typename T, class FUNC, int NRECV, int NSEND>
+class ncclLLPrimitives {
+ private:
+ const int tid;
+ const int nthreads;
+ int nrecv = 0;
+ int nsend = 0;
+ struct ncclConnInfo* recvConn[NRECV];
+ struct ncclConnInfo* sendConn[NSEND];
+ volatile uint64_t* waitPtr;
+ volatile uint64_t* postPtr;
+ volatile int* fifoPtr;
+ uint64_t recvStep[NRECV];
+ uint64_t sendStep[NSEND];
+ uint64_t sendConnHead;
+ union ncclLLFifoLine* recvBuff[NRECV];
+ union ncclLLFifoLine* sendBuff[NSEND];
+ struct ncclComm* comm;
+
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
+ inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
+ inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
+ inline __device__ uint32_t recvFlag(int i) { return recvStep[i]+1; }
+ inline __device__ uint32_t sendFlag(int i) { return sendStep[i]+1; }
+
+ // Exit If Abort Barrier : make sure all threads exit consistently
+ // Each thread sets a predicate to true if val == 1
+ // all CTA's threads enter the barrier and do a popc on their predicates being True
+ // If any of the thread's predicate was True, all the threads call exit()
+ inline __device__ void exitIfAbortLocalBarrier() {
+ uint32_t popc;
+ asm ("{");
+ asm volatile (" .reg .pred barr_pred;");
+ asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
+ asm volatile (" bar.red.popc.u32 %0, 14, %1, barr_pred;" : "=r"(popc) : "r"(nthreads));
+ asm ("}");
+ if (popc) {
+ // Make sure threads not participating in the operation get the abort and all threads exit
+ exitIfAbortBarrier(1);
+ }
+ }
+
+ inline __device__ void barrier() {
+ asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
+ }
+
+ uint32_t mismatch = 0;
+ const uint64_t opCount;
+
+ inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) {
+ if (mismatch > 20) {
+ // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
+ // Note that we are not using _threadfence_system in LL so the error cannot be asserted
+ *(comm->fatalDevError) = ncclDevSuspectedMismatch;
+ } else if (remoteOpCount && *remoteOpCount > opCount) {
+ mismatch += 1;
+ }
+ }
-template <typename... TAIL_Ts> __device__ __forceinline__
-void PostToFlags(uint64_t val, PostFlag flag, TAIL_Ts... tail) {
- flag.post(val);
- PostToFlags(val, tail...);
-}
+ uint32_t spins = 0;
+ uint32_t abort = 0;
+ inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) {
+ spins++;
+ if (spins == SPINS_BEFORE_CHECK_ABORT) {
+ abort = *(comm->abortFlag);
+ checkMismatch(remoteOpCount);
+ spins = 0;
+ }
+ return abort;
+ }
-// Post sizes for PostFlags, ignore WaitFlags
-__device__ __forceinline__
-void PostSizeToFlags(uint64_t step, int size) { }
+ inline __device__ void waitSend(int i, int nbytes) {
+ spins = 0;
+ mismatch = 0;
+ if (tid == WARP_SIZE+i) {
+ while (sendConnHead + NCCL_STEPS < sendStep[i] + 1) {
+ sendConnHead = *waitPtr;
+ if (checkAbort(sendConn[i]->opCountRem)) break;
+ }
+ if (fifoPtr) fifoPtr[sendStep[i]%NCCL_STEPS] = nbytes;
+ }
+ }
-template <typename... TAIL_Ts> __device__ __forceinline__
-void PostSizeToFlags(uint64_t step, int size, WaitFlag flag, TAIL_Ts... tail) {
- PostSizeToFlags(step, size, tail...);
-}
+ inline __device__ void postRecv(int i) {
+ recvStep[i]++;
+ if (tid == i) *postPtr = recvStep[i];
+ }
-template <typename... TAIL_Ts> __device__ __forceinline__
-void PostSizeToFlags(uint64_t step, int size, PostFlag flag, TAIL_Ts... tail) {
- flag.postSize(step, size);
- PostSizeToFlags(step, size, tail...);
-}
+ inline __device__ void postSend(int i) {
+ sendStep[i]++;
+ }
+ __device__ uint64_t readLL(int i, int offset) {
+ union ncclLLFifoLine* src = recvPtr(i) + offset;
+ uint32_t flag = recvFlag(i);
+ uint32_t data1, flag1, data2, flag2;
+ spins = 0;
+ mismatch = 0;
+ do {
+ asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
+ if (checkAbort(recvConn[i]->opCountRem)) break;
+ } while ((flag1 != flag) || (flag2 != flag));
+ uint64_t val64 = data1 + (((uint64_t)data2) << 32);
+ return val64;
+ }
-// Create pointer arithmetic syntax that doesn't break for std::nullptr_t
-template <typename Tptr> __device__ __forceinline__
-Tptr ptradd(Tptr ptr, int i) {
- return ptr + i;
-}
+ __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
+ asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
+ }
-__device__ __forceinline__
-std::nullptr_t ptradd(std::nullptr_t ptr, int i) {
- return nullptr;
-}
+ // Using memcpy handles misaligned pointers.
+ __device__ uint64_t readAL(uint64_t* src) {
+ uint64_t val;
+ memcpy((char*)&val, (char*)src, sizeof(uint64_t));
+ return val;
+ }
+ __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
+ memcpy((char*)dst, (char*)&val, nbytes);
+ }
-// Implementation of primitive types
-template <int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
-class Primitives {
- private:
- template <typename SRC2_T, // either T* or std::nullptr_t
- typename DST2_T, // either T* or std::nullptr_t
- typename... SYNC_Ts> // either WaitFunc or PostFunc
- static __device__ __forceinline__ void
- GenericOp(const int tid, const int nthreads,
- const T* src1,
- const SRC2_T src2,
- T* dst1,
- DST2_T dst2,
- int len, int maxoffset, uint64_t step, SYNC_Ts... flags) {
-
- enum { noSrc2 = std::is_same<SRC2_T, std::nullptr_t>::value };
- enum { noDst2 = std::is_same<DST2_T, std::nullptr_t>::value };
- static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
- "src2 must be of type T* or std::nullptr_t");
- static_assert(noDst2 || std::is_same<DST2_T, T*>::value,
- "dst2 must be of type T* or std::nullptr_t");
-
- using OpType = typename std::conditional<noSrc2, FuncSum<T>, REDOP>::type;
-
- int sliceSize = len / SUBSTEPS;
- int sliceOffset = 0;
-
-#pragma unroll 1
- for (int sub=0; sub<SUBSTEPS; ++sub) {
- int realSize = max(0, min(sliceSize, maxoffset-sliceOffset));
- if (tid < nthreads) {
- if (AnyAre<WaitFlag>(flags...)) {
- if (tid == 0) {
- WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
- }
- asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
+ template <int RECV, int SEND, int SRC, int DST>
+ __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
+ uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
+ FOR_SEND(waitSend, nbytes*2);
+ barrier();
+ uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
+ uint64_t* srcPack = (uint64_t*)srcPtr;
+ uint64_t* dstPack = (uint64_t*)dstPtr;
+ // Do multiples of 64 bits
+ #pragma unroll 2
+ for (int offset=tid; offset<npack; offset+=nthreads) {
+ // Recv : local, then intra-node, then inter-node
+ uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
+ if (RECV) {
+ if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
+ for (int i=1; i<NRECV && i<nrecv; i++) {
+ val = MULTI<FUNC, T>()(readLL(i, offset), val);
}
- ReduceOrCopy
- <
- UNROLL,
- OpType,
- T,
- !std::is_same<DST2_T, std::nullptr_t>::value, // HAS_DEST1
- !std::is_same<SRC2_T, std::nullptr_t>::value // HAS_SRC1
- >
- (
- tid, nthreads,
- ptradd(dst1, sliceOffset),
- ptradd(dst2, sliceOffset),
- ptradd(src1, sliceOffset),
- ptradd(src2, sliceOffset),
- realSize
- );
- if (AnyAre<PostFlag>(flags...)) {
- __syncthreads();
+ }
+
+ // Send : inter-node, then intra-node, then local
+ if (SEND) {
+ for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
+ storeLL(sendPtr(0)+offset, val, sendFlag(0));
+ }
+ if (DST) {
+ if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
+ // Last incomplete word
+ storeAL(dstPack+offset, val, nbytes & 0x7);
+ } else {
+ storeAL(dstPack+offset, val, sizeof(uint64_t));
}
- } else {
- if (AnyAre<PostFlag>(flags...)) {
- __syncthreads();
- PostSizeToFlags(SUBSTEPS*step+sub, realSize*sizeof(T), flags...);
- __threadfence_system();
- PostToFlags(SUBSTEPS*step + sub + 1, flags...);
+ }
+ }
+ exitIfAbortLocalBarrier();
+ FOR_RECV(postRecv);
+ FOR_SEND(postSend);
+ }
+
+ __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
+ recvConn[i] = conn;
+ recvBuff[i] = recvConn[i]->llBuff;
+ recvStep[i] = recvConn[i]->step;
+ if (tid == i) {
+ postPtr = recvConn[i]->head;
+ *(recvConn[i]->opCountLoc) = opCount;
+ }
+ nrecv++;
+ }
+
+ __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
+ sendConn[i] = conn;
+ sendBuff[i] = sendConn[i]->llBuff;
+ sendStep[i] = sendConn[i]->step;
+ if (tid == WARP_SIZE+i) {
+ waitPtr = sendConn[i]->head;
+ fifoPtr = sendConn[i]->fifo;
+ sendConnHead = *waitPtr;
+ *(sendConn[i]->opCountLoc) = opCount;
+ }
+ nsend++;
+ }
+
+ __device__ __forceinline__ void saveRecvConn(int i) {
+ if (tid == i) {
+ recvConn[i]->step = recvStep[i];
+ *(recvConn[i]->opCountLoc) += 1;
+ __threadfence_block();
+ }
+ }
+
+ __device__ __forceinline__ void saveSendConn(int i) {
+ if (tid == WARP_SIZE+i) {
+ sendConn[i]->step = sendStep[i];
+ *(sendConn[i]->opCountLoc) += 1;
+ __threadfence_block();
+ }
+ }
+
+ __device__ __forceinline__ void llSendCleaning(int i) {
+ if (sendStep[i] > sendConn[i]->llLastCleaning + NCCL_LL_CLEAN_FREQ) {
+ /* Reset all flags */
+ static_assert((NCCL_LL_BUFF_SIZE % NCCL_LL_MAX_NTHREADS) == 0, "NCCL_LL_BUFF_SIZE must be a multiple of THREADS");
+ static_assert(NCCL_LL_BUFF_SIZE/(sizeof(union ncclLLFifoLine)*NCCL_LL_MAX_NTHREADS) > 0, "NCCL_LL_BUFF_SIZE is less than 16 bytes*THREADS");
+ for (int s=0; s<NCCL_STEPS; s++) {
+ waitSend(i, 0);
+ for (int o=tid; o<NCCL_LL_SLICE_LINES; o+=nthreads) {
+ const union ncclLLFifoLine resetLine = { 0, sendFlag(i), 0, sendFlag(i) };
+ sendPtr(i)[o].i4 = resetLine.i4;
}
}
- sliceOffset += sliceSize;
+ if (tid == 0) sendConn[i]->llLastCleaning = sendStep[i];
+ }
+ }
+
+ __device__ __forceinline__ void llRecvCleaning(int i) {
+ if (recvStep[i] > recvConn[i]->llLastCleaning + NCCL_LL_CLEAN_FREQ) {
+ recvStep[i] += NCCL_STEPS;
+ if (tid == 0) recvConn[i]->llLastCleaning = recvStep[i];
}
}
public:
- template <typename... SYNC_Ts>
- static __device__ __forceinline__ void
- Copy(const int tid, const int nthreads, const T* src, T* dst,
- int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
- GenericOp(tid, nthreads, src, nullptr, dst, nullptr, len, maxOffset, step, flags...);
+ __device__ __forceinline__
+ ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclComm* comm, const uint64_t opCount)
+ : comm(comm), tid(tid), nthreads(nthreads), opCount(opCount) {
+ // Make sure step is updated before we read it.
+ barrier();
+
+ for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
+ for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
}
- template <typename... SYNC_Ts>
- static __device__ __forceinline__ void
- DoubleCopy(const int tid, const int nthreads, const T* src, T* dst1, T* dst2,
- int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
- GenericOp(tid, nthreads, src, nullptr, dst1, dst2, len, maxOffset, step, flags...);
+ __device__ void send(const T* src, int nelem) {
+ return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
}
- template <typename... SYNC_Ts>
- static __device__ __forceinline__ void
- Reduce(const int tid, const int nthreads, const T* src1, const T* src2, T* dst,
- int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
- GenericOp(tid, nthreads, src1, src2, dst, nullptr, len, maxOffset, step, flags...);
+ __device__ void recv(T* dst, int nelem) {
+ return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
}
- template <typename... SYNC_Ts>
- static __device__ __forceinline__ void
- ReduceCopy(const int tid, const int nthreads, const T* src1, const T* src2, T* dst1, T* dst2,
- int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
- GenericOp(tid, nthreads, src1, src2, dst1, dst2, len, maxOffset, step, flags...);
+ __device__ void recvReduceSend(const T* src, int nelem) {
+ return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
+ }
+
+ __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
+ return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
}
-};
-#endif // end include guard
+ __device__ void copySend(const T* src, T* dst, int nelem) {
+ return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ void recvCopySend(T* dst, int nelem) {
+ return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
+ }
+
+ __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
+ return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ __forceinline__ ~ncclLLPrimitives() {
+ for (int i=0; i<NSEND && i<nsend; i++) llSendCleaning(i);
+ for (int i=0; i<NRECV && i<nrecv; i++) llRecvCleaning(i);
+ // Save steps for the next operation
+ for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i);
+ for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i);
+ }
+};
+#endif
diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu
index bd1d23c..1ef66d4 100644
--- a/src/collectives/device/reduce.cu
+++ b/src/collectives/device/reduce.cu
@@ -4,18 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "common.h"
#include "reduce.h"
+#include "common.h"
#include "collectives.h"
-#define UNROLL 4
-
-#if NCCL_OP == 0
-IMPL_COLL2(ncclReduce, sum, FuncSum, ncclCollReduce, ncclSum);
-#elif NCCL_OP == 1
-IMPL_COLL2(ncclReduce, prod, FuncProd, ncclCollReduce, ncclProd);
-#elif NCCL_OP == 2
-IMPL_COLL2(ncclReduce, min, FuncMin, ncclCollReduce, ncclMin);
-#elif NCCL_OP == 3
-IMPL_COLL2(ncclReduce, max, FuncMax, ncclCollReduce, ncclMax);
-#endif
+IMPL_COLL_R(ncclReduce, ncclCollReduce);
diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h
index f5694b1..302d053 100644
--- a/src/collectives/device/reduce.h
+++ b/src/collectives/device/reduce.h
@@ -8,143 +8,71 @@
#include "primitives.h"
#include "collectives.h"
-// Increase Step and boffset for buffer sync
-#define NEXT_STEP \
- step++; \
- boffset += sliceSize; \
- if (boffset == buffSize) boffset = 0;
-
template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceKernel(struct CollectiveArgs* args) {
+__device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = blockDim.x - 1;
const int bid = args->bid;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
-
- WaitFlag waitDoneFromNext(ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0);
- PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0);
- PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCE_BUFCHUNKS*REDUCE_SUBSTEPS);
-
- typedef Primitives<UNROLL, REDUCE_SUBSTEPS, T, FUNC> Prims;
-
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
const ssize_t size = args->N;
const int nranks = comm->nRanks;
- const int buffSize = ring->buffSize / sizeof(T);
- const int sliceSize = buffSize / REDUCE_BUFCHUNKS;
- const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * REDUCE_CHUNKSTEPS;
+ const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
const int rank = ring->devUserRanks[0];
const int prevRank = ring->devUserRanks[nranks-1];
const int root = args->root;
- if (tid == 0) {
- // Update in case we skipped some collectives
- *ring->recv.conn.opCount = args->opCount;
-
- if (rank != root) {
- // Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
- }
- }
- __syncthreads();
-
- uint64_t step = 0ULL;
- int boffset = 0;
-
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- T * __restrict__ prevInput = (T*)ring->recv.conn.buff;
- T * __restrict__ nextOutput = (T*)ring->send.conn.buff;
+
+ ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings));
- ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t offset = gridOffset + bid*chunkSize;
- int maxOffset = min(chunkSize, size-offset);
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t offset = gridOffset + bid*realChunkSize;
+ int nelem = min(realChunkSize, size-offset);
if (prevRank == root) {
- Prims::Copy(tid, nthreads,
- thisInput + offset,
- nextOutput + boffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
+ prims.send(thisInput+offset, nelem);
} else if (rank == root) {
- Prims::Reduce(tid, nthreads,
- prevInput + boffset,
- thisInput + offset,
- thisOutput + offset,
- sliceSize, maxOffset,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
} else {
- Prims::Reduce(tid, nthreads,
- prevInput + boffset,
- thisInput + offset,
- nextOutput + boffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
- }
- NEXT_STEP; // Increases step, boffset
- }
-
- if (tid == 0) {
- if (rank != root) {
- // Wait for next to have consumed data before resetting the flag
- waitDoneFromNext.wait(REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1));
- *ring->send.conn.head = 0ULL;
+ prims.recvReduceSend(thisInput+offset, nelem);
}
- *ring->recv.conn.tail = 0ULL;
- __threadfence_system();
- *ring->recv.conn.opCount = args->opCount+1;
}
}
-#include "ll_kernel.h"
-
-#define NEXT_STEP_LL \
- boffset += NCCL_LL_SLICE_LINES; \
- if (boffset == NCCL_LL_BUFF_LINES) boffset = 0; \
- flag++; \
- step++;
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclReduceTreeKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceLLKernel(struct CollectiveArgs* args) {
+__device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int bid = args->bid;
- const int llNthreads = args->nThreads;
+ const int nthreads = args->nThreads;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead;
- volatile uint64_t * sendHeadPtr = ring->send.conn.llHead;
- volatile int * sizesFifo = ring->send.conn.llFifo;
- uint64_t sendHead = sendHeadPtr[0];
- const int nranks = comm->nRanks;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
const int rank = comm->rank;
+ const int nranks = comm->nRanks;
const int prevRank = ring->devUserRanks[nranks-1];
const int root = args->root;
- typedef LLPrimitives<T, FUNC> LL;
-
- const ssize_t size = args->N;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nRings*chunkSize;
-
- uint64_t step = ring->send.conn.llStep;
- uint32_t flag = step + 1;
- int boffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step);
+ const ssize_t loopSize = args->nChannels*chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff;
- union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
@@ -152,39 +80,16 @@ __device__ void ncclReduceLLKernel(struct CollectiveArgs* args) {
}
ssize_t offset = gridOffset + bid*chunkSize;
- int maxOffset = min(chunkSize, size-offset);
+ int nelem = min(chunkSize, size-offset);
if (prevRank == root) {
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- nextOutput + boffset,
- maxOffset, flag, llNthreads);
- POST_SIZE;
- NEXT_STEP_LL;
+ LLprims.send(thisInput+offset, nelem);
} else if (rank == root) {
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + boffset,
- thisOutput + offset,
- maxOffset, flag, llNthreads);
- NEXT_STEP_LL;
- ACK_PREV;
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
} else {
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + boffset,
- nextOutput + boffset,
- maxOffset, flag, flag, llNthreads);
- POST_SIZE;
- NEXT_STEP_LL;
- ACK_PREV;
+ LLprims.recvReduceSend(thisInput+offset, nelem);
}
}
-
- // We need everyone to acknowledge data even if they didn't receive anything
- // so that the next collective can start right away.
- ACK_PREV;
-
- FIFO_CLEANING_AND_SAVE_STEP(flag);
}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h
index 0cb8f13..0e90793 100644
--- a/src/collectives/device/reduce_kernel.h
+++ b/src/collectives/device/reduce_kernel.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -46,30 +46,28 @@ struct FuncMin {
}
};
+#define MASK0 0x00ff00ff
+#define MASK1 0xff00ff00
+static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
+ /* This can be used both for signed and unsigned 8-bit addition */
+ const uint32_t x0 = x & MASK0;
+ const uint32_t x1 = x & MASK1;
+ const uint32_t y0 = y & MASK0;
+ const uint32_t y1 = y & MASK1;
+ const uint32_t r0 = (x0+y0);
+ const uint32_t r1 = (x1+y1);
+ return (r0 & MASK0) | (r1 & MASK1);
+}
+
template<>
struct FuncSum<int8_t> {
- union converter { uint32_t storage; char4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vadd.s32.s32.s32 %0, %1.b0, %2.b0; \n\t"
- "vadd.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vadd.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vadd.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
- converter cx, cy, cr;
- cx.storage = x;
- cy.storage = y;
- cr.a.x = cx.a.x + cy.a.x;
- cr.a.y = cx.a.y + cy.a.y;
- cr.a.z = cx.a.z + cy.a.z;
- cr.a.w = cx.a.w + cy.a.w;
- return cr.storage;
+ return addChar4(x, y);
#endif
}
__device__ int8_t operator()(const int8_t x, const int8_t y) const {
@@ -78,28 +76,13 @@ struct FuncSum<int8_t> {
};
template<>
struct FuncSum<uint8_t> {
- union converter { uint32_t storage; uchar4 a; };
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
int32_t rv, z=0;
asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vadd.u32.u32.u32 %0, %1.b0, %2.b0; \n\t"
- "vadd.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vadd.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vadd.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
- converter cx, cy, cr;
- cx.storage = x;
- cy.storage = y;
- cr.a.x = cx.a.x + cy.a.x;
- cr.a.y = cx.a.y + cy.a.y;
- cr.a.z = cx.a.z + cy.a.z;
- cr.a.w = cx.a.w + cy.a.w;
- return cr.storage;
+ return addChar4(x, y);
#endif
}
__device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
@@ -109,22 +92,6 @@ struct FuncSum<uint8_t> {
static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) {
/* This can be used both for signed and unsigned 8-bit multiplication */
-#if (__CUDA_ARCH__ >= 300)
- uint32_t rv;
- asm("{ .reg .u32 t0, t1, t2, t3;\n\t"
- " vmad.u32.u32.u32 t3, %1.b3, %2.b3, 0;\n\t"
- " vmad.u32.u32.u32 t2, %1.b2, %2.b2, 0;\n\t"
- " shl.b32 t3, t3, 16;\n\t"
- " shl.b32 t2, t2, 16;\n\t"
- " vmad.u32.u32.u32 t1, %1.b1, %2.b1, t3;\n\t"
- " shl.b32 t1, t1, 8;\n\t"
- " vmad.u32.u32.u32 t0, %1.b0, %2.b0, t2;\n\t"
- " and.b32 t1, t1, 0xff00ff00;\n\t"
- " and.b32 t0, t0, 0x00ff00ff;\n\t"
- " or.b32 %0, t0, t1;\n\t"
- "}" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
-#else
union converter { uint32_t storage; char4 a; };
converter cx, cy, cr;
cx.storage = x;
@@ -134,7 +101,6 @@ static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) {
cr.a.z = cx.a.z * cy.a.z;
cr.a.w = cx.a.w * cy.a.w;
return cr.storage;
-#endif
}
template<>
@@ -164,13 +130,6 @@ struct FuncMax<int8_t> {
int32_t rv, z=0;
asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vmax.s32.s32.s32 %0, %1.b0, %2.b0; \n\t"
- "vmax.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vmax.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vmax.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
converter cx, cy, cr;
cx.storage = x;
@@ -194,13 +153,6 @@ struct FuncMax<uint8_t> {
int32_t rv, z=0;
asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vmax.u32.u32.u32 %0, %1.b0, %2.b0; \n\t"
- "vmax.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vmax.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vmax.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
converter cx, cy, cr;
cx.storage = x;
@@ -225,13 +177,6 @@ struct FuncMin<int8_t> {
int32_t rv, z=0;
asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vmin.s32.s32.s32 %0, %1.b0, %2.b0; \n\t"
- "vmin.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vmin.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vmin.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
converter cx, cy, cr;
cx.storage = x;
@@ -255,13 +200,6 @@ struct FuncMin<uint8_t> {
int32_t rv, z=0;
asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z));
return rv;
-#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700)
- int32_t rv;
- asm("vmin.u32.u32.u32 %0, %1.b0, %2.b0; \n\t"
- "vmin.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t"
- "vmin.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t"
- "vmin.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y));
- return rv;
#else
converter cx, cy, cr;
cx.storage = x;
diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu
index b16053c..10857ed 100644
--- a/src/collectives/device/reduce_scatter.cu
+++ b/src/collectives/device/reduce_scatter.cu
@@ -4,18 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "common.h"
#include "reduce_scatter.h"
+#include "common.h"
#include "collectives.h"
-#define UNROLL 4
-
-#if NCCL_OP == 0
-IMPL_COLL2(ncclReduceScatter, sum, FuncSum, ncclCollReduceScatter, ncclSum);
-#elif NCCL_OP == 1
-IMPL_COLL2(ncclReduceScatter, prod, FuncProd, ncclCollReduceScatter, ncclProd);
-#elif NCCL_OP == 2
-IMPL_COLL2(ncclReduceScatter, min, FuncMin, ncclCollReduceScatter, ncclMin);
-#elif NCCL_OP == 3
-IMPL_COLL2(ncclReduceScatter, max, FuncMax, ncclCollReduceScatter, ncclMax);
-#endif
+IMPL_COLL_R(ncclReduceScatter, ncclCollReduceScatter);
diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h
index cad011b..c70c845 100644
--- a/src/collectives/device/reduce_scatter.h
+++ b/src/collectives/device/reduce_scatter.h
@@ -8,156 +8,82 @@
#include "primitives.h"
#include "collectives.h"
-// Increase Step and poffset/noffset for buffer sync
-#define NEXT_STEP \
- step++; \
- poffset = noffset; \
- noffset += sliceSize; \
- if (noffset == buffSize) noffset = 0;
-
template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) {
+__device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = blockDim.x - 1;
const int bid = args->bid;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
-
- WaitFlag waitDoneFromNext(ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS);
- WaitFlag waitReadyFromPrev(ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS);
- PostFlag postDoneToPrev(ring->recv.conn.head, REDUCESCATTER_SUBSTEPS, NULL, 0);
- PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS);
-
- typedef Primitives<UNROLL, REDUCESCATTER_SUBSTEPS, T, FUNC> Prims;
-
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
const ssize_t size = args->N;
const int nranks = comm->nRanks;
- const int buffSize = ring->buffSize / sizeof(T);
- const int sliceSize = buffSize / REDUCESCATTER_BUFCHUNKS;
- const ssize_t loopSize = args->nRings*(ssize_t)sliceSize;
-
- if (tid == 0) {
- // Update in case we skipped some collectives
- *ring->recv.conn.opCount = args->opCount;
- // Wait for next to be ready
- WaitFlag waitOpCountNext(ring->send.conn.opCount, 0);
- waitOpCountNext.wait(args->opCount);
- }
- __syncthreads();
-
- uint64_t step = 0ULL;
- int poffset, noffset = 0;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- T * __restrict__ prevInput = (T*)ring->recv.conn.buff;
- T * __restrict__ nextOutput = (T*)ring->send.conn.buff;
+
+ ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings));
- ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*realChunkSize;
/////////////// begin ReduceScatter steps ///////////////
ssize_t offset;
- int maxOffset = min(chunkSize, size-chunkOffset);
+ int nelem = min(realChunkSize, size-chunkOffset);
int rankDest;
// step 0: push data to next GPU
rankDest = ring->devUserRanks[nranks-1];
offset = chunkOffset + rankDest * size;
- Prims::Copy(tid, nthreads,
- thisInput + offset,
- nextOutput + noffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext,
- postReadyToNext);
-
- NEXT_STEP; // Increases step, poffset, noffset
+ prims.send(thisInput+offset, nelem);
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
rankDest = ring->devUserRanks[nranks-j];
offset = chunkOffset + rankDest * size;
- Prims::Reduce(tid, nthreads,
- prevInput + poffset,
- thisInput + offset,
- nextOutput + noffset,
- sliceSize, maxOffset,
- step,
- waitDoneFromNext, waitReadyFromPrev,
- postReadyToNext, postDoneToPrev);
-
- NEXT_STEP;
+ prims.recvReduceSend(thisInput+offset, nelem);
}
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data and push to the next GPU
+ // step k-1: reduce this buffer and data, which will produce the final result
rankDest = ring->devUserRanks[0];
offset = chunkOffset + rankDest * size;
- Prims::Reduce(tid, nthreads,
- prevInput + poffset,
- thisInput + offset,
- thisOutput + chunkOffset,
- sliceSize, maxOffset,
- step,
- waitReadyFromPrev,
- postDoneToPrev);
- }
-
- if (tid == 0) {
- waitDoneFromNext.wait(REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS));
- *ring->send.conn.head = 0ULL;
- *ring->recv.conn.tail = 0ULL;
- __threadfence_system();
- *ring->recv.conn.opCount = args->opCount+1;
+ prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
}
}
-#include "ll_kernel.h"
-
-#define NEXT_STEP_LL \
- poffset = noffset; \
- pflag = nflag; \
- noffset += NCCL_LL_SLICE_LINES; \
- if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \
- nflag++; \
- step++;
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclReduceScatterTreeKernel(struct CollectiveArgs* args) { }
template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
+__device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int bid = args->bid;
- const int llNthreads = args->nThreads;
+ const int nthreads = args->nThreads;
struct ncclComm* comm = args->comm;
- struct ncclRing* ring = comm->rings+blockIdx.x;
- volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead;
- volatile uint64_t * sendHeadPtr = ring->send.conn.llHead;
- volatile int * sizesFifo = ring->send.conn.llFifo;
- uint64_t sendHead = sendHeadPtr[0];
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
- typedef LLPrimitives<T, FUNC> LL;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
const ssize_t size = args->N;
//const int rank = comm->rank;
const int nranks = comm->nRanks;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = args->nRings*chunkSize;
-
- uint64_t step = ring->send.conn.llStep;
- uint32_t pflag, nflag = step + 1;
- int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step);
+ const ssize_t loopSize = args->nChannels*chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff;
- union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
if (size-gridOffset < loopSize) {
@@ -167,37 +93,21 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
/////////////// begin ReduceScatter steps ///////////////
ssize_t offset;
- int maxOffset = min(chunkSize, size-chunkOffset);
+ int nelem = min(chunkSize, size-chunkOffset);
int rankDest;
// step 0: push data to next GPU
rankDest = ring->devUserRanks[nranks-1];
offset = chunkOffset + rankDest * size;
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- nextOutput + noffset,
- maxOffset, nflag, llNthreads);
- POST_SIZE;
-
- NEXT_STEP_LL;
+ LLprims.send(thisInput+offset, nelem);
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
rankDest = ring->devUserRanks[nranks-j];
offset = chunkOffset + rankDest * size;
- WAIT_NEXT;
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + poffset,
- nextOutput + noffset,
- maxOffset, pflag, nflag, llNthreads);
- POST_SIZE;
- ACK_PREV;
-
- NEXT_STEP_LL;
+ LLprims.recvReduceSend(thisInput+offset, nelem);
}
// step k-1: reduce this buffer and data, which will produce the final
@@ -205,13 +115,9 @@ __device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) {
rankDest = ring->devUserRanks[0];
offset = chunkOffset + rankDest * size;
- LL::ReduceCopy(
- thisInput + offset,
- prevInput + poffset,
- thisOutput + chunkOffset,
- maxOffset, pflag, llNthreads);
- ACK_PREV;
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
}
-
- FIFO_CLEANING_AND_SAVE_STEP(nflag);
}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/reduce.cu b/src/collectives/reduce.cu
index d8fde80..302d4bc 100644
--- a/src/collectives/reduce.cu
+++ b/src/collectives/reduce.cu
@@ -4,30 +4,15 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "common_coll.h"
#include "enqueue.h"
#include "collectives.h"
-ncclResult_t ncclReduceFunc(const void* sendbuff, void* recvbuff, const size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- size_t nbytes = count*ncclTypeSize(datatype);
- INFO(NCCL_COLL,"Reduce: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
- if (comm->nRanks == 1) {
- if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
- } else {
- NCCLCHECK(transportSaveProxies(REDUCE_SUBSTEPS, REDUCE_BUFCHUNKS, 1, 1, nbytes, proxyPatternTo(root), comm));
- NCCLCHECK(saveKernel(ncclCollReduce, sendbuff, recvbuff, count, datatype, op, root, comm, stream, nbytes, 1));
- }
-
- return ncclSuccess;
-}
-
NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclReduceFunc, "Reduce", sendbuff, recvbuff, count, datatype,
- op, root, comm, stream);
+ struct ncclInfo info = { ncclCollReduce, "Reduce",
+ sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
+ REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
+ return ncclEnqueueCheck(&info);
}
diff --git a/src/collectives/reduce_scatter.cu b/src/collectives/reduce_scatter.cu
index 1447d4a..4ee77ef 100644
--- a/src/collectives/reduce_scatter.cu
+++ b/src/collectives/reduce_scatter.cu
@@ -4,29 +4,15 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "common_coll.h"
#include "enqueue.h"
#include "collectives.h"
-ncclResult_t ncclReduceScatterFunc(const void* sendbuff, void* recvbuff, size_t count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- size_t nbytes = count*ncclTypeSize(datatype);
- INFO(NCCL_COLL,"ReduceScatter: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, datatype, op, root, comm, comm->nRanks, stream);
- if (comm->nRanks == 1) {
- if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
- } else {
- NCCLCHECK(transportSaveProxies(REDUCESCATTER_SUBSTEPS, REDUCESCATTER_BUFCHUNKS, comm->nRanks-1, comm->nRanks, nbytes*comm->nRanks, proxyPatternRing, comm));
- NCCLCHECK(saveKernel(ncclCollReduceScatter, sendbuff, recvbuff, count, datatype, op, root, comm, stream, nbytes*comm->nRanks, 1));
- }
- return ncclSuccess;
-}
-
NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
- return ncclEnqueueCheck(ncclReduceScatterFunc, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype,
- op, 0, comm, stream);
+ struct ncclInfo info = { ncclCollReduceScatter, "ReduceScatter",
+ sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
+ REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
+ return ncclEnqueueCheck(&info);
}