github.com/marian-nmt/nccl.git
author     Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-09-22 21:57:56 +0300
committer  Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-09-22 21:57:56 +0300
commit     cabd6848e4c07e73f6db2cf74e3db0c1b7191fa9 (patch)
tree       e06cb9e849628b14970ecd6291e9614e6d0c1711
parent     e3dbc6110ebefdf5792de0c60fda1d81822d1454 (diff)
Heavy code refactoring to remove a lot of code in collectives (~1000 lines).
Have all collectives use the same args, the same ring, and the same primitives for synchronization between threads, so that they all follow the same pattern.
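
Each refactored kernel below double-buffers its communication buffer in NUM_BUFCHUNKS chunks and advances through it with a small NEXT_STEP macro. As a minimal host-side sketch (illustration only, not part of the commit), the following program just prints how step, poffset and noffset rotate through a toy buffer, which is the bookkeeping shared by the collectives:

#include <cstdio>

int main() {
  // Toy sizes; in the kernels buffSize comes from args.buffSize / sizeof(T).
  const int NUM_BUFCHUNKS = 2;
  const int buffSize  = 8;                       // elements
  const int sliceSize = buffSize / NUM_BUFCHUNKS;

  // In the kernels poffset is only meaningful after the first NEXT_STEP;
  // it is initialized here just to keep the printout defined.
  int step = 0, poffset = 0, noffset = 0;
  for (int i = 0; i < 6; ++i) {
    printf("step %d: consume previous slice at %d, produce next slice at %d\n",
           step, poffset, noffset);
    // NEXT_STEP, as defined identically in all_gather.cu and all_reduce.cu below
    // (broadcast.cu keeps a single boffset instead of poffset/noffset):
    step++;
    poffset = noffset;
    noffset += sliceSize;
    if (noffset == buffSize) noffset = 0;
  }
  return 0;
}
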
-rw-r--r--  Makefile               |   22
-rw-r--r--  src/all_gather.cu      |  600
-rw-r--r--  src/all_reduce.cu      |  602
-rw-r--r--  src/broadcast.cu       |  477
-rw-r--r--  src/common_kernel.h    |   75
-rw-r--r--  src/copy_kernel.h      |    2
-rw-r--r--  src/core.cu            |  469
-rw-r--r--  src/core.h             |  103
-rw-r--r--  src/enqueue.h          |  124
-rw-r--r--  src/libwrap.cu         |   28
-rw-r--r--  src/libwrap.h          |   12
-rw-r--r--  src/primitives.h       |  206
-rw-r--r--  src/reduce.cu          |  459
-rw-r--r--  src/reduce_kernel.h    |   37
-rw-r--r--  src/reduce_scatter.cu  |  576
15 files changed, 1480 insertions(+), 2312 deletions(-)
diff --git a/Makefile b/Makefile
index 7822d29..35e5eef 100644
--- a/Makefile
+++ b/Makefile
@@ -7,7 +7,9 @@
CUDA_HOME ?= /usr/local/cuda
PREFIX ?= /usr/local
VERBOSE ?= 0
+KEEP ?= 0
DEBUG ?= 0
+PROFAPI ?= 0
BUILDDIR ?= build
CUDA_LIB ?= $(CUDA_HOME)/lib64
@@ -19,7 +21,7 @@ NVCC_GENCODE ?= -gencode=arch=compute_35,code=sm_35 \
-gencode=arch=compute_52,code=sm_52 \
-gencode=arch=compute_52,code=compute_52
-CXXFLAGS := -I$(CUDA_INC) -fPIC -fvisibility=hidden
+CXXFLAGS := -I$(CUDA_INC) -fPIC -fvisibility=hidden
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -lineinfo -std=c++11 -maxrregcount 96
# Use addprefix so that we can specify more than one path
LDFLAGS := $(addprefix -L,${CUDA_LIB}) -lcudart -lrt
@@ -39,10 +41,17 @@ else
.SILENT:
endif
+ifneq ($(KEEP), 0)
+NVCUFLAGS += -keep
+endif
+
+ifneq ($(PROFAPI), 0)
+CXXFLAGS += -DPROFAPI
+endif
NCCL_MAJOR := 1
-NCCL_MINOR := 2
-NCCL_PATCH := 3
+NCCL_MINOR := 3
+NCCL_PATCH := 0
CXXFLAGS += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
@@ -50,7 +59,7 @@ CUDA_MAJOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 1)
CUDA_MINOR = $(shell echo $(CUDA_VERSION) | cut -d "." -f 2)
CXXFLAGS += -DCUDA_MAJOR=$(CUDA_MAJOR) -DCUDA_MINOR=$(CUDA_MINOR)
-.PHONY : lib clean debclean test mpitest install
+.PHONY : lib clean test mpitest install deb debian debclean
.DEFAULT : lib
INCEXPORTS := nccl.h
@@ -103,6 +112,7 @@ install : lib
cp -P -v $(BUILDDIR)/lib/* $(PREFIX)/lib/
cp -v $(BUILDDIR)/include/* $(PREFIX)/include/
+
#### TESTS ####
TEST_ONLY ?= 0
@@ -132,7 +142,7 @@ MPITESTBINS:= $(patsubst %, $(MPITSTDIR)/%, $(MPITESTS))
test : $(TESTBINS)
-$(TSTDIR)/% : test/single/%.cu $(TSTDEP)
+$(TSTDIR)/% : test/single/%.cu test/include/*.h $(TSTDEP)
@printf "Building %-25s > %-24s\n" $< $@
mkdir -p $(TSTDIR)
$(NVCC) $(TSTINC) $(NVCUFLAGS) --compiler-options "$(CXXFLAGS)" -o $@ $< $(TSTLIB) -lcuda -lcurand -lnvToolsExt
@@ -144,7 +154,7 @@ $(TSTDIR)/% : test/single/%.cu $(TSTDEP)
mpitest : $(MPITESTBINS)
-$(MPITSTDIR)/% : test/mpi/%.cu $(TSTDEP)
+$(MPITSTDIR)/% : test/mpi/%.cu $(TSTDEP)
@printf "Building %-25s > %-24s\n" $< $@
mkdir -p $(MPITSTDIR)
$(NVCC) $(MPIFLAGS) $(TSTINC) $(NVCUFLAGS) --compiler-options "$(CXXFLAGS)" -o $@ $< $(TSTLIB) -lcurand
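
(Not part of the diff: with the two new knobs above, a build would presumably be invoked as "make KEEP=1" to pass -keep to nvcc and retain intermediate files, or "make PROFAPI=1" to compile with -DPROFAPI; both default to 0, so the default build is unchanged.)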
diff --git a/src/all_gather.cu b/src/all_gather.cu
index 2a6e5da..2dd6246 100644
--- a/src/all_gather.cu
+++ b/src/all_gather.cu
@@ -1,479 +1,203 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
-#include <algorithm>
-#include <cassert>
-
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
+#include "primitives.h"
+
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
+
+// Increase Step and poffset/noffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ poffset = noffset; \
+ noffset += sliceSize; \
+ if (noffset == buffSize) noffset = 0;
+
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
+
+template<int THREADS, int UNROLL, class FUNC, typename T>
+__launch_bounds__(THREADS+WARP_SIZE, 1)
+__global__ void AllGatherKernel(const KernelArgs<T> args) {
+ const int tid = threadIdx.x;
+ __shared__ T* sharedNextOutput;
+ __shared__ DevRing<T> ring;
+ bool pushrecv = args.pushrecv;
+
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
- * split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
- * single-data-element operations inside an UNROLL. As the name suggests, the
- * UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 2
-
-// If this is called with STEP, it means that we just finished processing the
-// data for step STEP on this GPU, which is the data required on the next GPU
-// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
-// is available. This is called by one particular consumer warp and so we select
-// the first thread in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
- do { \
- __threadfence_system(); \
- args.NextNewDataAvailableFlag[0] = \
- NUM_SUBCHUNKS*((chunk) * (args.NumGPUs - 1) + (step)) + subchunk+1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*((chunk) * (args.NumGPUs - 1) + (step)) \
- + subchunk + 1 - NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- __threadfence_system(); \
- args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + (subchunk) + 1; \
- } while (0)
-
-#define WAIT_FOR_PREV_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int*)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1-NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-__device__ inline void getSliceSizeAndChunkSize(int *sliceSize, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *sliceSize = bigSliceN;
- } else {
- *sliceSize = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- }
-}
-
-template<typename T>
-struct AllGatherKernelArgs {
- // general parameters
- int ThisId;
- int NumGPUs;
- int N;
- int * UserFromRing;
-
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
-
- int BufferSliceStride;
- int BufferMisalignedN;
-
- T ** ThisPtrToNextOutput;
- T ** PrevPtrToThisOutput;
-
- // local and remote input, output, and buffer
- const T * __restrict__ ThisInput;
- volatile T * __restrict__ ThisOutput;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
-
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__device__ inline int GetBlock(const int index, const int step,
- const int * const userFromRing, const int numGPUs) {
- return userFromRing[(numGPUs + index - step) % numGPUs];
-}
-
-__shared__ volatile void * nextOutput;
-
-template<int THREADS, int UNROLL, bool PUSHRECV, typename T>
-__global__ void AllGatherKernel(const AllGatherKernelArgs<T> args) {
- if (args.N == 0) return;
- int tid = threadIdx.x;
-
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr;
- });
-
- *((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput;
-
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr;
- });
-
- if(PUSHRECV)
- nextOutput = *((volatile void * volatile *)args.ThisPtrToNextOutput);
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
+ if (pushrecv) {
+ *ring.sendPtrToPrev = (T*)args.ThisOutput;
+ Wait([=] {
+ return *ring.recvPtrFromNext != nullptr;
+ });
+ sharedNextOutput = *ring.recvPtrFromNext;
+ *ring.recvPtrFromNext = nullptr;
+ }
}
__syncthreads();
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
-
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
-
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
-
- // step 0: copy the resident block from the ThisInput to ThisOutput and also
- // to NextOutput
- int step = 0;
- int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- int outputOffset = chunkOffset + block * args.N;
- int inputOffset = chunkOffset;
- int bufferOffset;
- int sliceSize;
-
- if (!PUSHRECV) {
- bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- block * args.BufferMisalignedN;
- }
-
- // Copy from ThisInput
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- if (!PUSHRECV)
- WAIT_FOR_PREV_CHUNK(chunk, s);
-
- if (PUSHRECV) {
- DoubleCopy<UNROLL, THREADS>(
- args.ThisOutput + outputOffset,
- (volatile T *)nextOutput + outputOffset,
- args.ThisInput + inputOffset,
- sliceSize);
- } else {
- DoubleCopy<UNROLL, THREADS>(
- args.ThisOutput + outputOffset,
- args.NextBuffer + bufferOffset,
- args.ThisInput + inputOffset,
- sliceSize);
- }
- __syncthreads();
-
- outputOffset += sliceSize;
- inputOffset += sliceSize;
- if (!PUSHRECV)
- bufferOffset += sliceSize;
- }
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
+
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T> Prims;
+
+ const int size = args.N;
+ const int nranks = args.nRanks;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int poffset, noffset = 0;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+
+ for (int chunkOffset = 0; chunkOffset < size; chunkOffset += sliceSize) {
+ /////////////// begin AllGather steps ///////////////
+ int offset;
+ int maxOffset = size-chunkOffset;
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring.userRank[0];
+ offset = chunkOffset + rankDest * size;
+
+ if (thisInput == thisOutput) {
+ Prims::Copy(
+ thisInput + offset,
+ pushrecv ? sharedNextOutput + offset : nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
} else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
+ Prims::DoubleCopy(
+ thisInput + chunkOffset,
+ thisOutput + offset,
+ pushrecv ? sharedNextOutput + offset : nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
}
- // steps j with 0 < j < k - 1:
- // copy a block that was pushed to this GPU to the next GPU
- for (step = 1; step < args.NumGPUs - 1; ++step) {
- block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- outputOffset = chunkOffset + block * args.N;
- if (!PUSHRECV) {
- bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- block * args.BufferMisalignedN;
- }
+ NEXT_STEP; // Increases step, poffset, noffset
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- if (PUSHRECV) {
- Copy<UNROLL, THREADS>(
- (volatile T *)nextOutput + outputOffset,
- args.ThisOutput + outputOffset,
- sliceSize);
- } else {
- DoubleCopy<UNROLL, THREADS>(
- args.NextBuffer + bufferOffset,
- args.ThisOutput + outputOffset,
- args.ThisBuffer + bufferOffset,
- sliceSize);
- }
- __syncthreads();
-
- outputOffset += sliceSize;
- if (!PUSHRECV)
- bufferOffset += sliceSize;
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
- }
+ // k-2 steps: copy to next GPU
+ if (pushrecv) {
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring.userRank[nranks-j];
+ offset = chunkOffset + rankDest * size;
- if (!PUSHRECV) {
- step = args.NumGPUs - 1;
- block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- outputOffset = chunkOffset + block * args.N;
- bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- block * args.BufferMisalignedN;
+ Prims::Copy(
+ thisOutput + offset,
+ sharedNextOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
- // Make final copy from buffer to dest.
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- Copy<UNROLL, THREADS>(
- args.ThisOutput + outputOffset,
- args.ThisBuffer + bufferOffset,
- sliceSize);
-
- __syncthreads();
-
- outputOffset += sliceSize;
- bufferOffset += sliceSize;
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_CHUNK_DONE(chunk, s);
- }
+ NEXT_STEP;
}
+ } else {
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring.userRank[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ Prims::DoubleCopy(
+ prevInput + poffset,
+ thisOutput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
+ }
+
+ // Make final copy from buffer to dest.
+ rankDest = ring.userRank[1];
+ offset = chunkOffset + rankDest * size;
+
+ // Here we need to copy from buffer to this output.
+ Prims::Copy(
+ prevInput + poffset,
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
}
// wait for the last data to be pushed to us
- if (tid < THREADS) {
- if (PUSHRECV)
- WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
- else
- WAIT_FOR_PREV_CHUNK(args.NumChunks, NUM_SUBCHUNKS-1);
-
- if (tid == 0) {
- args.ThisNewDataAvailableFlag[0] = 0;
- args.ThisChunkDoneFlag[0] = 0;
- *args.ThisPtrToNextOutput = nullptr;
- }
- }
-}
+ if (tid == 0) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
-template<typename T>
-ncclResult_t ncclAllGatherWithType(const void* sendbuff, void* recvbuff,
- int count, ncclComm* comm, int numUnroll, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
- int index = comm->ncclId;
-
- int blockSizeInBytes = count * sizeof(T);
- int misalignedBytes = blockSizeInBytes % alignof(uint64_t);
-
- assert((int)((misalignedBytes / sizeof(T)) * sizeof(T)) == misalignedBytes);
-
- int misalignedN = misalignedBytes / sizeof(T);
- assert(misalignedN < (int)(sizeof(uint64_t) / sizeof(T)));
-
- int paddingN = (misalignedN > 0) ? sizeof(uint64_t) / sizeof(T) : 0;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- int bufferN = comm->buffSize / sizeof(T);
- // we only need buffer for k slices and k paddings
- int bufferNPerSlice = (bufferN - comm->nDev * NUM_SUBCHUNKS * paddingN)
- / (comm->nDev * NUM_SUBCHUNKS);
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- AllGatherKernelArgs<T> args;
-
- args.ThisId = index;
- args.NumGPUs = comm->nDev;
- args.N = count;
-
- /* Block j is coming from sendbuff[j], which lives on device with logical
- * index comm->ringFromUser[j]. But the block ordering does not necessarily
- * follow the ring ordering. Hence the order in which a particular GPU
- * processes the different blocks (the correspondence between the step in
- * the reduction algorithm and the block on which a GPU operates in that
- * particular step) is not the same as the ring order.
- *
- * Say we have 4 GPUs and comm->userFromRing = { 1, 2, 0, 3 }. Then there are 3
- * step in the all-gather algorithm and block 0 comes from device 2, block 1
- * from 0, block 2 from device 1, and block 3 comes from device 3. In the
- * first step of the algorithm, each GPU must copy its own block from its
- * sendbuff to the appropriate location in its recvbuff. The blocks that a
- * GPU has to process in the next steps is determined by the previous step
- * because each GPU only hands off data to the next GPU in the ring.
- *
- * In the above example, we get the following table of which block is
- * processed by each GPU in a given step. The columns correspond to the
- * different GPUs while the rows are the steps in the algorithm.
- *
- * GPU 0 1 2 3
- * step
- * 0 1 2 0 3
- * 1 3 1 2 0
- * 2 0 3 1 2
- *
- * We note the the rows in the above table are just comm->userFromRing in the
- * first step and the list is cyclicly permuted to the right for each next
- * step. The columns, which are what the individual GPUs need to know, are
- * comm->userFromRing traversed backwards and starting at index k for GPU k.
- * These columns are what we put into args.BlockVsStep to tell the GPU which
- * block it needs to be processing at a particular step. */
- args.UserFromRing = comm->devUserFromRing;
-
- args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
- args.SliceSize = std::min(maxSliceSize, args.SliceSize);
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // don't reduce this if we cut the slice size in half below, because if that
- // happens, the last chunk will be larger than the other chunks, and we will
- // need the extra buffer space
- args.BufferSliceStride = args.SliceSize + paddingN;
-
- args.BufferMisalignedN = misalignedN;
-
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
- }
+ // Wait for last update from prev then reset the flag
+ waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
+ *ring.recvFlagFromPrev = 0;
- args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
+ incrementOpCounter(&args);
+ }
+}
- args.ThisInput = (const T*)sendbuff;
- args.ThisOutput = (volatile T*)recvbuff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
+#define THREADS 384
+#define UNROLL 8
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
+template<class FUNC, typename T>
+ncclResult_t RingAllGather(const void* sendbuff, void* recvbuff,
+ const int count, ncclComm* comm, cudaStream_t stream) {
+ if (count == 0)
+ return ncclSuccess;
- if (comm->nDev == 1) {
+ if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
- if( comm->useRemoteRecv ) {
- AllGatherKernel<NUM_THREADS, UNROLL_COUNT, true, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- AllGatherKernel<NUM_THREADS, UNROLL_COUNT, false, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
+ KernelArgs<T> args;
+ ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
+ LAUNCH_KERNEL(AllGatherKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
+
return ncclSuccess;
}
-class AllGatherFunctor {
-public:
- ncclResult_t operator()(const void* sendbuff, void* recvbuff,
- int count, ncclDataType_t datatype, ncclRedOp_t /*dummy operation*/,
- int /*dummy root*/, ncclComm* comm, cudaStream_t stream) {
- int numUnroll = 16; // this is optimal on dt07 with 4 GPUs
-
- switch (datatype) {
- case ncclChar:
- return ncclAllGatherWithType<char>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
- case ncclInt:
- return ncclAllGatherWithType<int>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
-#if CUDART_VERSION >= 7050
- case ncclHalf:
- return ncclAllGatherWithType<half>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
-#endif
- case ncclFloat:
- return ncclAllGatherWithType<float>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
- case ncclDouble:
- return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
- case ncclInt64:
- return ncclAllGatherWithType<long long>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
- case ncclUint64:
- return ncclAllGatherWithType<unsigned long long>(sendbuff, recvbuff, count, comm,
- numUnroll, stream);
- }
- return ncclInvalidType;
+template<typename T, template<typename> class RedOp>
+class AllGather {
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
+ return RingAllGather<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, int count, ncclDataType_t datatype,
+ void* recvbuff, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, int count, ncclDataType_t datatype,
void* recvbuff, ncclComm_t comm, cudaStream_t stream) {
- return enqueue(AllGatherFunctor(), sendbuff, recvbuff, count, datatype,
- ncclSum, 0, comm, stream);
+ return enqueue<AllGather, FuncNull>(sendbuff, recvbuff, count, datatype, 0, comm, stream);
}
+
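For reference, here is a small host-side sketch (illustration only, not NCCL code) of the block schedule used by the new AllGatherKernel above: step 0 handles ring.userRank[0] (the GPU's own block) and step j handles ring.userRank[nranks-j]. The ring order {1, 2, 0, 3} is borrowed from the example in the comment removed by this commit, and userRank[i] is assumed to hold the user rank of the GPU i hops ahead on the ring.

#include <cstdio>

int main() {
  const int nranks = 4;
  const int ringOrder[nranks] = {1, 2, 0, 3};   // assumed user rank at each ring position

  for (int pos = 0; pos < nranks; ++pos) {      // one GPU per ring position
    int userRank[nranks];
    for (int i = 0; i < nranks; ++i)
      userRank[i] = ringOrder[(pos + i) % nranks];

    // Step 0 writes (and pushes) the GPU's own block ...
    printf("GPU (user rank %d) writes block %d", userRank[0], userRank[0]);
    // ... then each of the remaining nranks-1 steps writes the block received
    // from upstream, indexed exactly as in the kernel: userRank[nranks-j].
    for (int j = 1; j < nranks; ++j)
      printf(", then block %d", userRank[nranks - j]);
    printf("\n");
  }
  return 0;
}

After the loop, every ring position has written all nranks blocks into its output, which is the all-gather result.
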
diff --git a/src/all_reduce.cu b/src/all_reduce.cu
index 01dc09f..86a297f 100644
--- a/src/all_reduce.cu
+++ b/src/all_reduce.cu
@@ -1,491 +1,233 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
-#include "reduce_kernel.h"
+#include "primitives.h"
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is an independent, complete reduction. Each
- * GPU has a buffer that can fit an entire CHUNK, so that all SUBCHUNKS can be
- * processed without checking that the buffer on the receiving GPU is empty. A
- * SUBCHUNK is split into NUM_GPUS SLICES and each GPU works on a different
- * SLICE at the same time. Before moving on the the next SLICE in the reduction
- * algorithm, the GPU has to check whether it has received the data from the
- * previous GPU it needs for this SLICE. To hide the latency of this
- * communication, each GPU processes all the SLICES of all the SUBCHUNKS in
- * sequence before moving on to the next SLICE. Each SLICE is split into a
- * certain number of UNROLLS (determined by the buffer size) and each thread
- * performs UNROLL_COUNT single-data-element operations inside an UNROLL. As the
- * name suggests, the UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 2
-
-
-// If this is called with STEP, it means that we just finished processing the
-// data for step STEP on this GPU, which is the data required on the next GPU
-// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
-// is available. This is called by one particular consumer warp and so we select
-// the first thread in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
- do { \
- __threadfence_system(); \
- args.NextNewDataAvailableFlag[0] = \
- NUM_SUBCHUNKS*((chunk) * (2 * args.NumGPUs - 2) + (step) + 1)+subchunk; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- 2*((chunk) * (2 * args.NumGPUs - 2) + (step))+subchunk; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-#define WAIT_FOR_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-
-__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *size = bigSliceN;
- *offset = slice * bigSliceN;
- } else {
- *size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- *offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
- }
-}
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
-template<typename T>
-struct AllReduceKernelArgs {
- // general parameters
- int ThisId;
- int NumGPUs;
- int N;
-
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
-
- T ** ThisPtrToNextOutput;
- T ** PrevPtrToThisOutput;
-
- // local and remote input, output, and buffer
- const T * __restrict__ ThisInput;
- volatile T * __restrict__ ThisOutput;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
-
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__shared__ volatile void * nextOutput;
+// Increase Step and poffset/noffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ poffset = noffset; \
+ noffset += sliceSize; \
+ if (noffset == buffSize) noffset = 0;
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
-template<int THREADS, int UNROLL, class FUNC, bool PUSHRECV, typename T>
+template<int THREADS, int UNROLL, class FUNC, typename T>
__launch_bounds__(THREADS+WARP_SIZE, 1)
-__global__ void AllReduceKernel(const AllReduceKernelArgs<T> args) {
- if (args.N == 0) return;
+__global__ void AllReduceKernel(const KernelArgs<T> args) {
const int tid = threadIdx.x;
+ __shared__ T* sharedNextOutput;
+ __shared__ DevRing<T> ring;
+ bool pushrecv = args.pushrecv;
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
- if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr;
- });
-
- *((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput;
-
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr;
- });
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
- if (PUSHRECV)
- nextOutput =
- *((volatile void * volatile *)args.ThisPtrToNextOutput);
+ if (tid == 0) {
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
+ if (pushrecv) {
+ *ring.sendPtrToPrev = (T*)args.ThisOutput;
+ Wait([=] {
+ return *ring.recvPtrFromNext != nullptr;
+ });
+ sharedNextOutput = *ring.recvPtrFromNext;
+ *ring.recvPtrFromNext = nullptr;
+ }
}
__syncthreads();
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = args.NumGPUs * NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
+ const int size = args.N;
+ const int nranks = args.nRanks;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int poffset, noffset = 0;
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+ for (int chunkOffset = 0; chunkOffset < size; chunkOffset += nranks*sliceSize) {
/////////////// begin AllReduce steps ///////////////
-
- // step 0: push data to next GPU
- int step = 0;
- int slice = args.ThisId;
int offset;
- int sliceSize;
-
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- if (s > 0) { slice += args.NumGPUs; }
- getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- if (!PUSHRECV && chunk > 0) {
- WAIT_FOR_CHUNK(chunk, s);
- }
-
- Copy<UNROLL, THREADS>(
- args.NextBuffer + offset,
- args.ThisInput + chunkOffset + offset,
- sliceSize);
-
- __syncthreads();
- }
- } else { // is consumer thread
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
+ int maxOffset;
+ int slice;
- // steps j with 1 <= j < k - 1, where k = number of GPUs:
- // reduce and copy to next GPU
- for (step = 1; step < args.NumGPUs - 1; ++step) {
- if (tid < THREADS) {
- slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- if (s > 0) { slice += args.NumGPUs; }
- getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- Reduce<UNROLL, THREADS, FUNC>(
- args.NextBuffer + offset,
- args.ThisBuffer + offset,
- args.ThisInput + chunkOffset + offset,
- sliceSize);
-
- __syncthreads();
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
+ // step 0: push data to next GPU
+ slice = ring.userRank[nranks-1];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ Prims::Copy(
+ thisInput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP; // Increases step, poffset, noffset
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ slice = ring.userRank[nranks-j];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ Prims::Reduce(
+ prevInput + poffset,
+ thisInput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
// step k - 1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
- step = args.NumGPUs - 1;
-
- if (tid < THREADS) {
- slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- if (s > 0) { slice += args.NumGPUs; }
- getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- if (PUSHRECV) {
- ReduceAndCopy<UNROLL, THREADS, FUNC>(
- (volatile T *)nextOutput + chunkOffset + offset,
- args.ThisOutput + chunkOffset + offset,
- args.ThisBuffer + offset,
- args.ThisInput + chunkOffset + offset,
- sliceSize);
- } else {
- ReduceAndCopy<UNROLL, THREADS, FUNC>(
- args.NextBuffer + offset,
- args.ThisOutput + chunkOffset + offset,
- args.ThisBuffer + offset,
- args.ThisInput + chunkOffset + offset,
- sliceSize);
- }
-
- __syncthreads();
+ slice = ring.userRank[0];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ Prims::ReduceCopy(
+ prevInput + poffset,
+ thisInput + offset,
+ pushrecv ? (sharedNextOutput + offset) : (nextOutput + noffset),
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
+
+ if (pushrecv) {
+ // k-2 steps: copy result to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ slice = ring.userRank[nranks - j];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ Prims::Copy(
+ thisOutput + offset,
+ sharedNextOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
} else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
+ // k-2 steps: copy result to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ slice = ring.userRank[nranks - j];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ Prims::DoubleCopy(
+ prevInput + poffset,
+ thisOutput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
- }
- // steps j with k <= j < 2*k-2: copy result to next GPU
- for (step = args.NumGPUs; step < 2 * args.NumGPUs - 2; ++step) {
- if (tid < THREADS) {
- slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- if (s > 0) { slice += args.NumGPUs; }
- getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- if( PUSHRECV ) {
- Copy<UNROLL, THREADS>(
- (volatile T *)nextOutput + chunkOffset + offset,
- args.ThisOutput + chunkOffset + offset,
- sliceSize);
- } else {
- DoubleCopy<UNROLL, THREADS>(
- args.NextBuffer + offset,
- args.ThisOutput + chunkOffset + offset,
- args.ThisBuffer + offset,
- sliceSize);
- }
-
- __syncthreads();
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
- }
-
- if (!PUSHRECV) {
// Make final copy from buffer to dest.
- if (tid < THREADS) {
- slice = (args.NumGPUs + slice - 1) % args.NumGPUs;
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- if (s > 0) { slice += args.NumGPUs; }
- getSliceSizeAndOffset(&sliceSize, &offset, slice, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- WAIT_FOR_NEW_DATA(chunk, s, step);
-
- // Here we need to copy from buffer to this output.
- Copy<UNROLL, THREADS>(
- args.ThisOutput + chunkOffset + offset,
- args.ThisBuffer + offset,
- sliceSize);
-
- __syncthreads();
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- if(chunk+1 < args.NumChunks) {
- SIGNAL_CHUNK_DONE(chunk, s);
- }
- }
- }
+ slice = ring.userRank[1];
+ offset = chunkOffset + slice * sliceSize;
+ maxOffset = size-offset;
+
+ // Here we need to copy from buffer to this output.
+ Prims::Copy(
+ prevInput + poffset,
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
}
// wait for the last data to be pushed to us
- if (tid < THREADS) {
- if(PUSHRECV) {
- WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
- }
+ if (tid == 0) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
- if (tid == 0) {
- args.ThisNewDataAvailableFlag[0] = 0;
- if(!PUSHRECV) {
- args.ThisChunkDoneFlag[0] = 0;
- }
- *args.ThisPtrToNextOutput = nullptr;
- }
+ // Wait for last update from prev then reset the flag
+ waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
+ *ring.recvFlagFromPrev = 0;
+
+ incrementOpCounter(&args);
}
}
+#define THREADS 512
+#define UNROLL 8
+
template<class FUNC, typename T>
-ncclResult_t ncclAllReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
+ncclResult_t RingAllReduce(const void* sendbuff, void* recvbuff,
const int count, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
- int index = comm->ncclId;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int bufferN = comm->buffSize / sizeof(T);
- int bufferNPerSlice = bufferN / (NUM_SUBCHUNKS * comm->nDev);
- int sliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
-
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- AllReduceKernelArgs<T> args;
-
- args.ThisId = index;
- args.NumGPUs = comm->nDev;
- args.N = count;
-
- args.SliceSize = sliceSize;
- int subchunkSize = comm->nDev * args.SliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * subchunkSize;
-
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- int subchunkSize = comm->nDev * args.SliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * subchunkSize;
-
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
- }
-
- args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
- args.ThisInput = (const T*)sendbuff;
- args.ThisOutput = (volatile T*)recvbuff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
-
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
-
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
-
- if (comm->nDev == 1) {
+ if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
- if( comm->useRemoteRecv ) {
- AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, true, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- AllReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, false, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
+ KernelArgs<T> args;
+ ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
+ LAUNCH_KERNEL(AllReduceKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
- return ncclSuccess;
-}
-
-template<typename T>
-ncclResult_t ncclAllReduceWithType(const void* sendbuff,
- void* recvbuff, int count, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
- switch (op) {
- case ncclSum:
- return ncclAllReduceWithTypeAndFunc<FuncSum<T>, T>(
- sendbuff, recvbuff, count, comm, stream);
- case ncclProd:
- return ncclAllReduceWithTypeAndFunc<FuncProd<T>, T>(
- sendbuff, recvbuff, count, comm, stream);
- case ncclMax:
- return ncclAllReduceWithTypeAndFunc<FuncMax<T>, T>(
- sendbuff, recvbuff, count, comm, stream);
- case ncclMin:
- return ncclAllReduceWithTypeAndFunc<FuncMin<T>, T>(
- sendbuff, recvbuff, count, comm, stream);
- }
- return ncclInvalidOperation;
+ return ncclSuccess;
}
-class AllReduceFunctor {
-public:
- ncclResult_t operator()(const void* sendbuff, void* recvbuff,
- int count, ncclDataType_t datatype, ncclRedOp_t op, int /*root*/,
- ncclComm* comm, cudaStream_t stream) {
-
- switch (datatype) {
- case ncclChar:
- return ncclAllReduceWithType<char>(sendbuff, recvbuff, count, op,
- comm, stream);
- case ncclInt:
- return ncclAllReduceWithType<int>(sendbuff, recvbuff, count, op,
- comm, stream);
-#if CUDART_VERSION >= 7050
- case ncclHalf:
- return ncclAllReduceWithType<half>(sendbuff, recvbuff, count, op,
- comm, stream);
-#endif
- case ncclFloat:
- return ncclAllReduceWithType<float>(sendbuff, recvbuff, count, op,
- comm, stream);
- case ncclDouble:
- return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
- comm, stream);
- case ncclInt64:
- return ncclAllReduceWithType<long long>(sendbuff, recvbuff, count, op,
- comm, stream);
- case ncclUint64:
- return ncclAllReduceWithType<unsigned long long int>(sendbuff, recvbuff, count, op,
- comm, stream);
- }
-
- return ncclInvalidType;
+template<typename T, template <typename> class RedOp>
+class AllReduce {
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
+ return RingAllReduce<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, int count,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
- return enqueue(AllReduceFunctor(), sendbuff, recvbuff, count, datatype, op, 0,
- comm, stream);
+ return enqueue<AllReduce>(sendbuff, recvbuff, count, datatype, op, 0, comm, stream);
}
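
The public entry point keeps its NCCL 1.x signature, now declared through NCCL_API above. A minimal usage sketch follows (illustration only; ncclCommInitAll and the usual error checking come from the existing API and are not part of this diff):

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  const int nDev = 2;                      // assumes two visible GPUs
  int devs[nDev] = {0, 1};
  ncclComm_t comms[nDev];
  ncclCommInitAll(comms, nDev, devs);      // error checking omitted for brevity

  const int count = 1 << 20;
  float* sendbuff[nDev];
  float* recvbuff[nDev];
  cudaStream_t streams[nDev];
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&sendbuff[i], count * sizeof(float));
    cudaMalloc(&recvbuff[i], count * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // One call per device; the refactored ring kernel runs underneath.
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    ncclAllReduce(sendbuff[i], recvbuff[i], count, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  }
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
  }
  return 0;                                // cleanup omitted
}
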
diff --git a/src/broadcast.cu b/src/broadcast.cu
index a57f4da..5843922 100644
--- a/src/broadcast.cu
+++ b/src/broadcast.cu
@@ -1,392 +1,165 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
-#include <algorithm>
-
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
+#include "primitives.h"
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
- * split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
- * single-data-element operations inside an UNROLL. As the name suggests, the
- * UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 4
-
-// if this is called with CHUNK, it means that we just finished pushing the data
-// of chunk CHUNK to the next GPU, so it can proceed with CHUNK
-// We add 1 to chunk so that the initial flag of 0 doesn't allow the non-root
-// GPUs to proceed before the flag is incremented from the upstream GPU. This
-// is called by one particular consumer warp and so we select the first thread
-// in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk) \
- do { \
- __threadfence_system(); \
- args.NextNewDataAvailableFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-#define WAIT_FOR_NEW_DATA(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// If this is called with CHUNK, it means that this GPU has just finished
-// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- bool newDataAvailable = \
- ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- bool chunkDone = \
- ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk)+subchunk + 1 - NUM_SUBCHUNKS; \
- return newDataAvailable && chunkDone; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *size = bigSliceN;
- *offset = slice * bigSliceN;
- } else {
- *size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- *offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
- }
-
-// if (threadIdx.x == 0)
-// printf("[size=%d] [offset=%d] slice=%d numSlices=%d "
-// "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
-// "lastSliceN=%d\n", *size, *offset, slice, numSlices, numBigSlices,
-// numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-}
-
-template<typename T>
-struct BroadcastKernelArgs {
- // general parameters
- int ThisId;
- int N;
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
- int BufferSliceStride;
+// Increase Step and boffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ boffset += sliceSize; \
+ if (boffset == buffSize) boffset = 0;
- T ** ThisPtrToNextData;
- T ** PrevPtrToThisData;
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
- // local and remote data
- T * __restrict__ ThisData;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
+template<int THREADS, int UNROLL, class FUNC, typename T>
+__launch_bounds__(THREADS+WARP_SIZE, 1)
+__global__ void BroadcastKernel(const KernelArgs<T> args) {
+ const int tid = threadIdx.x;
+ __shared__ T* sharedNextOutput;
+ __shared__ DevRing<T> ring;
+ bool pushrecv = args.pushrecv;
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__shared__ volatile void * nextData;
-enum BcastRole {ROOT=0, MIDDLE=1, END=2};
-
-template<int THREADS, int UNROLL, bool PUSHRECV, int ROLE, typename T>
-__global__ void BroadcastKernel(const BroadcastKernelArgs<T> args) {
- if (args.N == 0) return;
- int tid = threadIdx.x;
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisData) == nullptr; // Wait for previous processor to be done
- });
-
- *((T * volatile *)args.PrevPtrToThisData) = (T*)args.ThisData; // Tell Previous I'm starting
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextData) != nullptr; // Wait till I've been told next started
- });
-
- if (PUSHRECV)
- nextData = *((volatile void * volatile *)args.ThisPtrToNextData); // Grab next's pointer if needed.
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
+ if (pushrecv) {
+ *ring.sendPtrToPrev = (T*)args.ThisOutput;
+ Wait([=] {
+ return *ring.recvPtrFromNext != nullptr;
+ });
+ sharedNextOutput = *ring.recvPtrFromNext;
+ *ring.recvPtrFromNext = nullptr;
+ }
}
__syncthreads();
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
-
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
-
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
-
- int offset;
- int sliceSize;
-
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndOffset(&sliceSize, &offset, s, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- if (PUSHRECV) {
- if (ROLE != ROOT)
- WAIT_FOR_NEW_DATA(chunk, s);
-
- if (ROLE != END)
- Copy<UNROLL, THREADS>(
- (volatile T *)nextData + chunkOffset + offset,
- args.ThisData + chunkOffset + offset,
- sliceSize);
- } else { // PUSH2BUFF
- if (ROLE == ROOT) {
- WAIT_FOR_CHUNK(chunk, s);
-
- Copy<UNROLL, THREADS>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- sliceSize);
- } else if (ROLE == MIDDLE) {
- WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, s);
-
- DoubleCopy<UNROLL, THREADS>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- } else { // ROLE == END
- WAIT_FOR_NEW_DATA(chunk, s);
-
- Copy<UNROLL, THREADS>(
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- }
- }
- __syncthreads();
- }
- } else { // Consumer thread
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- if (ROLE != END)
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s);
-
- // signal chunk done if we don't push into the receive buffer and this
- // is no the last chunk and this is not root
- if ((!PUSHRECV) && (ROLE != ROOT) && (chunk + 1 < args.NumChunks)) {
- SIGNAL_CHUNK_DONE(chunk, s);
- }
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, (1-NUM_BUFCHUNKS)*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, 0);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, 0);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
+
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T> Prims;
+
+ const int size = args.N;
+ const int rank = ring.userRank[0];
+ const int nextRank = ring.userRank[1];
+ const int root = args.root;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int boffset = 0;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+
+ for (int offset = 0; offset < size; offset += sliceSize) {
+ int maxOffset = size-offset;
+ if (rank == root) {
+ Prims::Copy(
+ thisInput + offset,
+ pushrecv ? sharedNextOutput + offset : nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext,
+ postReadyToNext);
+ } else if (nextRank == root) {
+ if (pushrecv) maxOffset = 0; // Only wait for signals
+ Prims::Copy(
+ prevInput + boffset,
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitReadyFromPrev,
+ postDoneToPrev);
+ } else {
+ if (pushrecv) {
+ Prims::Copy(
+ thisOutput + offset,
+ sharedNextOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+ } else {
+ Prims::DoubleCopy(
+ prevInput + boffset,
+ thisOutput + offset,
+ nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
}
}
+ NEXT_STEP; // Increases step, boffset
}
- // reset flags
+ // wait for the last data to be pushed to us
if (tid == 0) {
- args.ThisNewDataAvailableFlag[0] = 0;
- args.ThisChunkDoneFlag[0] = 0;
- *args.ThisPtrToNextData = nullptr;
- }
-}
-
-template<typename T>
-ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
- ncclComm* comm, int numUnroll, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
- int index = comm->ncclId;
- int rootId = comm->ringFromUser[root];
-
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int bufferN = comm->buffSize / sizeof(T);
- // we only need buffer for k slices and k paddings
- int bufferNPerSlice = bufferN / NUM_SUBCHUNKS;
- int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
-
- BroadcastKernelArgs<T> args;
-
- args.ThisId = index;
- args.N = count;
+ if (nextRank != root) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
+ }
- args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
+ if (rank != root) {
+ // reset the flag
+ *ring.recvFlagFromPrev = 0;
+ }
- // if we don't directly push into the remote receive buffer, make sure slice
- // fits into the temporary buffer
- if (!comm->useRemoteRecv) {
- // Larger transfers help QPI more than tag updates hurt P2P.
- args.SliceSize *= 8;
+ incrementOpCounter(&args);
}
+}
- args.SliceSize = std::min(maxSliceSize, args.SliceSize);
- args.BufferSliceStride = args.SliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
+#define THREADS 256
+#define UNROLL 8
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
+template<class FUNC, typename T>
+ncclResult_t RingBroadcast(void* buff, const int count, const int root,
+ ncclComm* comm, cudaStream_t stream) {
+ if (count == 0)
+ return ncclSuccess;
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
+ if (comm->nRanks != 1) {
+ KernelArgs<T> args;
+ ArgsSetup(&args, buff, buff, root, count, comm);
+ LAUNCH_KERNEL(BroadcastKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
-// printf("sliceSize = %i, chunkSize = %i, numChunks = %i\n", args.SliceSize, args.ChunkSize, args.NumChunks);
-
- args.ThisPtrToNextData = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisData = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
-
- args.ThisData = (T*)buff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
-
- // we need 2 * NUM_SUBCHUNKS flags, so use the first NUM_SUBCHUNKS flags
- // to signal the next GPU that new data is available and the following
- // NUM_SUBCHUNKS to signal the previous GPU that a chunk is finished
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
-
- if (comm->nDev != 1) {
- if (comm->useRemoteRecv) {
- if (index == (rootId + comm->nDev - 1) % comm->nDev) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else if (index == rootId) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
- } else {
- if (index == (rootId + comm->nDev - 1) % comm->nDev) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else if (index == rootId) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
- }
- }
return ncclSuccess;
}
-class BroadcastFunctor {
-public:
- ncclResult_t operator()(const void* /*dummy sendbuff*/,
- void* buff, int count, ncclDataType_t datatype, ncclRedOp_t /*dummy operation*/,
- int root, ncclComm* comm, cudaStream_t stream) {
- int numUnroll = 4;
-
- switch (datatype) {
- case ncclChar:
- return ncclBcastWithType<char>(buff, count, root, comm, numUnroll, stream);
- case ncclInt:
- return ncclBcastWithType<int>(buff, count, root, comm, numUnroll, stream);
-#ifdef CUDA_HAS_HALF
- case ncclHalf:
- return ncclBcastWithType<half>(buff, count, root, comm, numUnroll, stream);
-#endif
- case ncclFloat:
- return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
- case ncclDouble:
- return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
- case ncclInt64:
- return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
- case ncclUint64:
- return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
- }
- return ncclInvalidType;
+template<typename T, template<typename> class RedOp>
+class Broadcast {
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int root, ncclComm* comm, cudaStream_t stream) {
+ return RingBroadcast<RedOp<T>, T>(recvbuff, count, root, comm, stream);
}
};
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype, int root,
+ ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
- return enqueue(BroadcastFunctor(), nullptr, buff, count, datatype, ncclSum,
- root, comm, stream);
+ return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream);
}
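
Note on the refactored broadcast above: the public entry point keeps its signature, so callers are unchanged. The following is a minimal single-process usage sketch (illustration only, not part of this patch); it assumes two visible GPUs and initializes communicators with ncclCommInitAll, whose declaration appears in the core.cu diff below.

// Illustration only: minimal single-process ncclBcast usage, assuming 2 GPUs.
// Build with: nvcc bcast_demo.cu -lnccl
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>

int main() {
  const int nDev = 2;
  const int count = 1 << 20;                 // elements per rank
  int devs[nDev] = {0, 1};
  ncclComm_t comms[nDev];
  float* buf[nDev];
  cudaStream_t streams[nDev];

  ncclCommInitAll(comms, nDev, devs);

  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&buf[i], count * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Every rank calls ncclBcast on its own buffer; rank 0 is the root.
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    ncclBcast(buf[i], count, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }

  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
    cudaFree(buf[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  std::printf("broadcast complete\n");
  return 0;
}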
diff --git a/src/common_kernel.h b/src/common_kernel.h
index c213575..95c9eb4 100644
--- a/src/common_kernel.h
+++ b/src/common_kernel.h
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
@@ -245,7 +245,7 @@ __device__ inline void ReduceOrCopy(const int tid,
volatile T * __restrict__ dest0, volatile T * __restrict__ dest1,
const volatile T * __restrict__ src0, const volatile T * __restrict__ src1,
int N) {
- if (N==0) {
+ if (N<=0) {
return;
}
@@ -455,5 +455,76 @@ __device__ inline void CalcLastChunk(int * const bigSliceN,
*numSmallSlices + 1;
}
+// Kernel launch
+template<typename T>
+struct KernelArgs {
+ // general parameters
+ int nRanks;
+ int root;
+ int buffSize;
+ int N;
+ int opIndex;
+ volatile int * __restrict__ opCounter;
+ bool pushrecv;
+
+ // some pre-computed sizes
+ int SliceSize;
+ int SliceOffset;
+ int ChunkSize;
+ int NumChunks;
+
+ // local and remote input, output, and buffer
+ const T * __restrict__ ThisInput;
+ T * __restrict__ ThisOutput;
+
+ DevRing<char>* ring;
+};
+
+template<typename T>
+void ArgsSetup(KernelArgs<T> *args, const void* sendbuff, void* recvbuff,
+ const int root, const int count, ncclComm *comm) {
+ args->nRanks = comm->nRanks;
+ args->root = root;
+ args->buffSize = comm->buffSize;
+ args->N = count;
+ args->opIndex = comm->opSched;
+ args->opCounter = comm->opCounter;
+ args->ThisInput = (const T*)sendbuff;
+ args->ThisOutput = (T*)recvbuff;
+ args->ring = comm->devRing;
+ args->pushrecv = comm->globalMemSpace;
+}
+
+#define LAUNCH_KERNEL(K, THREADS, UNROLL, FUNC, T, \
+ args, stream) do { \
+ dim3 grid(1, 1, 1); \
+ dim3 block(THREADS+1, 1, 1); \
+ void* argptrs[] = {&args}; \
+ CUDACHECK(cudaLaunchKernel( \
+ (void*)K<THREADS, UNROLL, FUNC, T>, \
+ grid, block, argptrs, 0, stream)); \
+} while (0)
+
+template <typename T>
+__device__ inline void incrementOpCounter(const KernelArgs<T> *args) {
+ // increment comm's operation counts
+ __threadfence_system(); // Technically need to ensure that cleared flags
+ // are visible before incrementing op counter.
+ *args->opCounter = args->opIndex+1;
+}
+
+template <int THREADS, typename T> __device__ __forceinline__
+void LoadRing(const DevRing<char>* src, DevRing<T>* dst) {
+ enum { NUM_WORDS = sizeof(DevRing<char>) / sizeof(long long) };
+ static_assert(sizeof(DevRing<char>) % sizeof(long long) == 0, "Bad alignment");
+ static_assert(THREADS >= NUM_WORDS, "Not enough threads to load DevRing");
+ static_assert(sizeof(DevRing<char>) == sizeof(DevRing<T>), "DevRing size mismatch");
+ long long* lldst = reinterpret_cast<long long*>(dst);
+ const long long* llsrc = reinterpret_cast<const long long*>(src);
+ if (threadIdx.x < NUM_WORDS) {
+ lldst[threadIdx.x] = llsrc[threadIdx.x];
+ }
+}
+
#endif // COMMON_KERNEL_H_
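
The opCounter plumbing added above (KernelArgs::opIndex/opCounter plus incrementOpCounter) is what lets consecutive collectives order themselves without the old host-side event queue: each enqueued operation gets an index, and the kernel publishes opIndex+1 through a zero-copy host counter when it finishes. The sketch below is a host-side analogy only, with std::atomic standing in for the mapped counter and a std::thread standing in for the kernel; in the library it is the next kernel (via DevRing::prevOpCounter/nextOpCounter) that polls the counter, not the host.

// Host-side analogy of the opCounter handshake (illustration only).
#include <atomic>
#include <cstdio>
#include <thread>

int main() {
  std::atomic<int> opCounter{0};   // stands in for the zero-copy ncclMem::opCounter
  int opSched = 0;                 // stands in for ncclComm::opSched

  for (int op = 0; op < 3; ++op) {
    const int myIndex = opSched++;              // index assigned at enqueue time (host)
    std::thread kernel([&opCounter, myIndex] {
      // A launched operation first checks that its predecessor has finished...
      while (opCounter.load(std::memory_order_acquire) < myIndex) { /* spin */ }
      // ...does its collective work, then publishes completion (incrementOpCounter):
      opCounter.store(myIndex + 1, std::memory_order_release);
    });
    kernel.join();   // the host never blocks on the counter itself
  }
  std::printf("final opCounter = %d\n", opCounter.load());
  return 0;
}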
diff --git a/src/copy_kernel.h b/src/copy_kernel.h
index 0ef39c2..8464699 100644
--- a/src/copy_kernel.h
+++ b/src/copy_kernel.h
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
diff --git a/src/core.cu b/src/core.cu
index 2eca735..be4be06 100644
--- a/src/core.cu
+++ b/src/core.cu
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
#include <stdio.h>
@@ -20,7 +20,7 @@
DebugLevel ncclDebugLevel;
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclGetUniqueId, ncclUniqueId* out);
ncclResult_t ncclGetUniqueId(ncclUniqueId* out) {
pid_t pid = getpid();
static int count = 0;
@@ -83,7 +83,7 @@ typedef struct {
int rank;
int ndev;
int cudaDev;
- int ncclId;
+ int sortId;
pid_t pid;
ncclMem* hostptr;
ncclMem* devptr;
@@ -94,15 +94,13 @@ typedef struct {
static int compRanks(const void* a, const void* b) {
const RankEntry* A = (const RankEntry*)a;
const RankEntry* B = (const RankEntry*)b;
- if (A->ncclId < B->ncclId) return -1;
- if (A->ncclId > B->ncclId) return 1;
+ if (A->sortId < B->sortId) return -1;
+ if (A->sortId > B->sortId) return 1;
return 0;
}
static void orderRanks(RankEntry* ranks, int count) {
qsort(ranks, count, sizeof(RankEntry), compRanks);
- for(int i=0; i<count; ++i)
- ranks[i].ncclId = i;
}
@@ -110,7 +108,7 @@ typedef struct {
union {
struct {
volatile int bar;
- int ringDirectFail;
+ int globalMemSpaceBroke;
};
char pad[16];
};
@@ -156,7 +154,7 @@ static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId,
return ncclSuccess;
}
-static void syncRingDirect(RankGather* gather, int* ringDirectOk) {
+static void syncRingDirect(RankGather* gather, int* globalMemSpaceOk) {
int bar_tmp = gather->bar - 1;
int ndev = gather->ranks[0].ndev;
bool swapped;
@@ -169,7 +167,7 @@ static void syncRingDirect(RankGather* gather, int* ringDirectOk) {
sched_yield();
__sync_synchronize();
- *ringDirectOk = gather->ringDirectFail ? 0 : 1;
+ *globalMemSpaceOk = gather->globalMemSpaceBroke ? 0 : 1;
}
static ncclResult_t closeGather(RankGather* gather, int ndev) {
@@ -264,13 +262,13 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm)
return ncclUnhandledCudaError;
}
// Order by nvml index
- if (wrapNvmlDeviceGetIndex(nvmlHandle, (unsigned*)&info->ncclId) != ncclSuccess) {
+ if (wrapNvmlDeviceGetIndex(nvmlHandle, (unsigned*)&info->sortId) != ncclSuccess) {
WARN("rank %d failed to get nvml device index for device %d", rank, comm->cudaDev);
return ncclUnhandledCudaError;
}
info->rank = rank;
- info->ndev = comm->nDev;
+ info->ndev = comm->nRanks;
info->cudaDev = comm->cudaDev;
info->pid = getpid();
info->buffSize = comm->buffSize;
@@ -285,109 +283,104 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm)
}
-static const int CLEANUP_NONE = 0;
-static const int CLEANUP_CUIPC = 1;
-static const int CLEANUP_UNMAP = 2;
-
static ncclResult_t commClearMaps(ncclComm_t comm) {
ncclResult_t res, retval = ncclSuccess;
cudaError_t cures;
- for(int d=0; d<comm->nDev; ++d) {
- switch(comm->ptrs[d].remoteCleanup) {
- case CLEANUP_NONE:
- break;
- case CLEANUP_CUIPC:
- cures = cudaIpcCloseMemHandle((void*)comm->ptrs[d].cleanupHandle);
- if (cures != cudaSuccess) {
- WARN("rank %d failed to close IPC handle to rank %d",
- comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
- retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
- }
- break;
- case CLEANUP_UNMAP:
- cures = cudaHostUnregister(comm->ptrs[d].cleanupHandle);
- if (cures != cudaSuccess) {
- WARN("rank %d failed to unregister handle to rank %d",
- comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
+ for(int d=0; d<comm->nRanks; ++d) {
+ if (comm->ptrs[d].hostCleanup != NULL) {
+ cures = cudaHostUnregister(comm->ptrs[d].hostCleanup);
+ if (cures != cudaSuccess) {
+ WARN("rank %d failed to unregister handle to device %d",
+ comm->rank, d);
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
- }
- res = shmUnmap(comm->ptrs[d].cleanupHandle, offsetof(ncclMem, buff) + comm->buffSize);
- if (res != ncclSuccess) {
- WARN("rank %d failed to unmap handle to rank %d",
- comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
+ }
+ res = shmUnmap(comm->ptrs[d].hostCleanup, offsetof(ncclMem, buff) + comm->buffSize);
+ if (res != ncclSuccess) {
+ WARN("rank %d failed to unmap handle to device %d",
+ comm->rank, d);
retval = (retval == ncclSuccess) ? res : retval;
- }
- break;
- default:
- WARN("Unknown cleanup type %d", comm->ptrs[d].remoteCleanup);
+ }
+ comm->ptrs[d].hostCleanup = NULL;
+ }
+
+ if (comm->ptrs[d].devCleanup != NULL) {
+ cures = cudaIpcCloseMemHandle((void*)comm->ptrs[d].devCleanup);
+ if (cures != cudaSuccess) {
+ WARN("rank %d failed to close IPC handle to device %d: %s",
+ comm->rank, d, cudaGetErrorString(cures));
+ retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
+ }
}
- comm->ptrs[d].remoteCleanup = CLEANUP_NONE;
- comm->ptrs[d].cleanupHandle = NULL;
}
if (comm->userFromRing != NULL)
- memset(comm->userFromRing, 0, sizeof(int)*comm->nDev);
- if (comm->ringFromUser != NULL)
- memset(comm->ringFromUser, 0, sizeof(int)*comm->nDev);
+ memset(comm->userFromRing, 0, sizeof(int)*comm->nRanks);
+ if (comm->ncclFromRing != NULL)
+ memset(comm->ncclFromRing, 0, sizeof(int)*comm->nRanks);
if (comm->devUserFromRing != NULL) {
- cudaError_t err = cudaMemset(comm->devUserFromRing, 0, sizeof(int)*comm->nDev);
- if (err != cudaSuccess) {
- WARN("Faild to clear dev map: %s", cudaGetErrorString(err));
+ cures = cudaMemset(comm->devUserFromRing, 0, sizeof(int)*comm->nRanks);
+ if (cures != cudaSuccess) {
+      WARN("Failed to clear dev map: %s", cudaGetErrorString(cures));
+ retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
+ }
+ }
+
+ if (comm->devRing != NULL) {
+ cures = cudaMemset(comm->devRing, 0, sizeof(DevRing<char>));
+ if (cures != cudaSuccess) {
+ WARN("Failed to clear devRing: %s", cudaGetErrorString(cures));
retval = (retval == ncclSuccess) ? ncclUnhandledCudaError : retval;
}
}
return retval;
}
-static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int rank, RankEntry* ranks, int* ringDirectFailed) {
- int ndev = comm->nDev;
+static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int rank, RankEntry* ranks, int* globalMemSpaceBroke) {
+ int ndev = comm->nRanks;
+ comm->rank = rank;
+
+ if (ndev > MAXRANKS) {
+ WARN("%d ranks exceeds MAXRANKS of %d", ndev, MAXRANKS);
+ return ncclUnsupportedDeviceCount;
+ }
+
+ // Check for inconsistencies between ranks
+ // If two ranks use the same rank, then one slot of
+ // ranks[] will be left unset with zero ndev/buffSize.
for(int i=0; i<ndev; ++i) {
- // Check for inconsistencies between ranks
- // If two ranks use the same rank, then one slot of
- // ranks[] will be left unset with zero ndev/buffSize.
if (ranks[i].buffSize != comm->buffSize
- || ranks[i].ndev != comm->nDev) {
+ || ranks[i].ndev != comm->nRanks) {
commClearMaps(comm);
return ncclRankMismatch;
}
-
- // Create rank<->nccl maps
- int iRank = ranks[i].rank;
- comm->userFromRing[i] = iRank;
- comm->ringFromUser[iRank] = i;
}
- if (cudaMemcpy(comm->devUserFromRing, comm->userFromRing, ndev*sizeof(int),
- cudaMemcpyHostToDevice) != cudaSuccess) {
- WARN("rank %d failed to copy maps to device", rank);
- commClearMaps(comm);
- return ncclUnhandledCudaError;
- }
-
- int myId = -1;
+ // Find self among ranks of gather
+ int myNcclId = -1;
for (int i=0; i<ndev; ++i) {
if(ranks[i].rank == rank) {
- myId = i;
+ myNcclId = i;
break;
}
}
-
- if (myId == -1) {
+ if (myNcclId == -1) {
WARN("rank %d not found in communicator", rank);
return ncclInvalidRank;
}
- comm->ncclId = myId;
- int myDev = ranks[myId].cudaDev;
- pid_t myPid = ranks[myId].pid;
- comm->useRemoteRecv = 1; // Assume we directly write to result ptrs.
+ for(int ringPos=0; ringPos<ndev; ++ringPos) {
+ int ncclPos = (ringPos+myNcclId) % ndev; // ring order relative to self
+ int userRank = ranks[ncclPos].rank;
+ comm->userFromRing[ringPos] = userRank;
+ comm->ncclFromRing[ringPos] = ncclPos;
+ }
+
+ int myDev = ranks[myNcclId].cudaDev;
+ pid_t myPid = ranks[myNcclId].pid;
- // The order that we link with peers must ensure that
- // P2P slots are used for high-priority links first.
- for (int j=0; j<ndev; ++j) {
- int i = (myId - 1 + ndev + j) % ndev;
+ for (int i=0; i<ndev; ++i) {
int iRank = ranks[i].rank;
int iDev = ranks[i].cudaDev;
pid_t iPid = ranks[i].pid;
@@ -399,84 +392,127 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran
canpeer = 0;
}
+ cudaError_t err;
+ ncclMem* remoteHostBuff;
+
+ comm->ptrs[i].type = NodeRef::HOST; // Assume host buffer
+ comm->ptrs[i].devCleanup = NULL;
+ comm->ptrs[i].hostCleanup = NULL;
+
if (iPid == myPid) {
- if (myDev == iDev) {
+ remoteHostBuff = ranks[i].hostptr;
+
+ if (myDev == iDev) { // shared device
INFO("rank access %d -> %d via common device", rank, iRank);
- comm->ptrs[i].local = ranks[myId].devptr;
+ comm->ptrs[i].type = NodeRef::DEVICE;
+ comm->ptrs[i].local = ranks[myNcclId].devptr;
comm->ptrs[i].remote = ranks[i].devptr;
- comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
- } else {
- int peer_enabled = canpeer;
- if (canpeer) {
- cudaError_t p2pErr = cudaDeviceEnablePeerAccess(iDev, 0);
- if (p2pErr == cudaErrorPeerAccessAlreadyEnabled) {
- cudaGetLastError();
- } else if (p2pErr != cudaSuccess) {
- INFO("peer access failed between rank %d (dev %d) and rank %d (dev %d)\n",
- rank, myDev, iRank, iDev);
- peer_enabled = 0;
- }
- }
-
- if (peer_enabled) {
- INFO("rank access %d -> %d via P2P device mem", rank, iRank);
- comm->ptrs[i].local = ranks[myId].devptr;
- comm->ptrs[i].remote = ranks[i].devptr;
- comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
- } else { // go through hostmem
- INFO("rank access %d -> %d via zero-copy host mem", rank, iRank);
- if (j <= 2)
- *ringDirectFailed = 1;
- if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myId].hostptr, 0) != cudaSuccess) {
- WARN("rank %d failed to map zero copy buffer to device", rank);
- commClearMaps(comm);
- return ncclUnhandledCudaError;
- }
- if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, ranks[i].hostptr, 0) != cudaSuccess) {
- WARN("rank %d failed to map %d's zero copy buffer to device", rank, iRank);
- commClearMaps(comm);
- return ncclUnhandledCudaError;
- }
- comm->ptrs[i].remoteCleanup = CLEANUP_NONE;
- }
- }
- } else { // multi-process!
- *ringDirectFailed = 1;
- if (canpeer || myDev == iDev) {
- INFO("rank access %d -> %d via Ipc P2P device mem", rank, iRank);
- comm->ptrs[i].local = ranks[myId].devptr;
- if (cudaIpcOpenMemHandle((void**)(&comm->ptrs[i].remote),
- ranks[i].devipc, cudaIpcMemLazyEnablePeerAccess) != cudaSuccess) {
- WARN("rank %d failed to open Ipc handle to rank %d", rank, iRank);
+ } else if (canpeer) {
+ INFO("rank access %d -> %d via P2P device mem", rank, iRank);
+ err = cudaDeviceEnablePeerAccess(iDev, 0);
+ if (err == cudaErrorPeerAccessAlreadyEnabled) {
+ cudaGetLastError();
+ } else if (err != cudaSuccess) {
+ WARN("rank %d failed to peer with device %d: %s",
+ rank, iDev, cudaGetErrorString(err));
commClearMaps(comm);
return ncclUnhandledCudaError;
}
- comm->ptrs[i].remoteCleanup = CLEANUP_CUIPC;
- comm->ptrs[i].cleanupHandle = comm->ptrs[i].remote;
- } else { // go through hostmem
- INFO("rank access %d -> %d via zero copy host shm", rank, iRank);
- if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myId].hostptr, 0) != cudaSuccess) {
- WARN("rank %d failed to obtain dev ptr to sysmem buffer", rank);
- commClearMaps(comm);
- return ncclUnhandledCudaError;
- }
- char rankname[1024];
- sprintf(rankname, "%s-%d", commId->internal, ranks[i].rank);
- if (openHostMemShm(rankname, (ncclMem**)&comm->ptrs[i].cleanupHandle, ranks[i].buffSize)
- != ncclSuccess) {
- WARN("rank %d failed to open sysmem buffer of rank %d", rank, iRank);
- commClearMaps(comm);
- return ncclUnhandledCudaError;
- }
- if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, comm->ptrs[i].cleanupHandle, 0) != cudaSuccess) {
- WARN("rank %d failed to obtain dev ptr for rank %d", rank, iRank);
+ comm->ptrs[i].type = NodeRef::DEVICE;
+ comm->ptrs[i].local = ranks[myNcclId].devptr;
+ comm->ptrs[i].remote = ranks[i].devptr;
+ }
+ } else { // Separate processes
+ *globalMemSpaceBroke = 1;
+ char rankname[1024];
+ sprintf(rankname, "%s-%d", commId->internal, ranks[i].rank);
+ if (openHostMemShm(rankname, &remoteHostBuff, ranks[i].buffSize)
+ != ncclSuccess) {
+ WARN("rank %d failed to open sysmem buffer of rank %d", rank, iRank);
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
+ }
+ comm->ptrs[i].hostCleanup = remoteHostBuff;
+
+ // TODO: Extend to same device (MPS) case.
+ // At present that would go through host mem.
+ if (canpeer) {
+ INFO("rank access %d -> %d via IPC device mem", rank, iRank);
+ comm->ptrs[i].type = NodeRef::DEVICE;
+ comm->ptrs[i].local = ranks[myNcclId].devptr;
+ err = cudaIpcOpenMemHandle((void**)(&comm->ptrs[i].remote),
+ ranks[i].devipc, cudaIpcMemLazyEnablePeerAccess);
+ if (err != cudaSuccess) {
+ WARN("rank %d failed to open Ipc handle to rank %d: %s",
+ rank, iRank, cudaGetErrorString(err));
commClearMaps(comm);
return ncclUnhandledCudaError;
}
- comm->ptrs[i].remoteCleanup = CLEANUP_UNMAP;
+ comm->ptrs[i].devCleanup = comm->ptrs[i].remote;
+ }
+ }
+
+ err = cudaHostGetDevicePointer(&comm->ptrs[i].opCounter,
+ &(remoteHostBuff->opCounter), 0);
+ if (err != cudaSuccess) {
+ WARN("rank %d failed to obtain %d's zero copy pointer: %s",
+ rank, iRank, cudaGetErrorString(err));
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
+ }
+
+ if (comm->ptrs[i].type == NodeRef::HOST) {
+ *globalMemSpaceBroke = 1;
+ INFO("rank access %d -> %d via zero-copy host mem", rank, iRank);
+ if (cudaHostGetDevicePointer(&comm->ptrs[i].local, ranks[myNcclId].hostptr, 0) != cudaSuccess) {
+ WARN("rank %d failed to map zero copy buffer to device", rank);
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
+ }
+ if (cudaHostGetDevicePointer(&comm->ptrs[i].remote, remoteHostBuff, 0) != cudaSuccess) {
+ WARN("rank %d failed to map %d's zero copy buffer to device", rank, iRank);
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
}
}
}
+
+ // Setup device-side ring view
+ if (cudaMemcpy(comm->devUserFromRing, comm->userFromRing, ndev*sizeof(int),
+ cudaMemcpyHostToDevice) != cudaSuccess) {
+ WARN("rank %d failed to copy maps to device", rank);
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
+ }
+
+ DevRing<char> ringTemp;
+ memcpy(ringTemp.userRank, comm->userFromRing, ndev*sizeof(int));
+
+ int prevIdx = comm->ncclFromRing[comm->nRanks-1];
+ int nextIdx = comm->ncclFromRing[1 % comm->nRanks];
+ NodeRef* prevPtrs = comm->ptrs+prevIdx;
+ NodeRef* nextPtrs = comm->ptrs+nextIdx;
+
+ ringTemp.prevOpCounter = prevPtrs->opCounter;
+ ringTemp.nextOpCounter = nextPtrs->opCounter;
+ ringTemp.sendFlagToNext = nextPtrs->remote->flags;
+ ringTemp.recvFlagFromPrev = prevPtrs->local->flags;
+ ringTemp.sendFlagToPrev = prevPtrs->remote->flags+1;
+ ringTemp.recvFlagFromNext = nextPtrs->local->flags+1;
+
+ ringTemp.recvPtrFromNext = (char**)&nextPtrs->local->recvPtrs;
+ ringTemp.sendPtrToPrev = (char**)&prevPtrs->remote->recvPtrs;
+
+ ringTemp.recvBuffer = prevPtrs->local->buff;
+ ringTemp.sendBuffer = nextPtrs->remote->buff;
+
+ if (cudaMemcpy(comm->devRing, &ringTemp, sizeof(ringTemp),
+ cudaMemcpyHostToDevice) != cudaSuccess) {
+ WARN("rank %d failed to copy ring maps to device", rank);
+ commClearMaps(comm);
+ return ncclUnhandledCudaError;
+ }
+
return ncclSuccess;
}
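
The re-indexing above replaces the old user<->ring maps: ranks are first sorted by NVML index, then each rank rotates that order so that ring position 0 is itself, position 1 its send peer, and position nRanks-1 its receive peer. A small self-contained check of that arithmetic follows (illustration only; the rank values are made up):

// Illustration only: the ring re-indexing performed in commBuildMaps.
#include <cstdio>

int main() {
  const int nRanks = 4;
  // Hypothetical user ranks in nccl (NVML-sorted) order.
  const int userRankOfNcclId[nRanks] = {2, 0, 3, 1};
  const int myNcclId = 2;   // suppose this process sorted to nccl index 2 (user rank 3)

  int userFromRing[nRanks];
  int ncclFromRing[nRanks];
  for (int ringPos = 0; ringPos < nRanks; ++ringPos) {
    const int ncclPos = (ringPos + myNcclId) % nRanks;   // ring order relative to self
    userFromRing[ringPos] = userRankOfNcclId[ncclPos];
    ncclFromRing[ringPos] = ncclPos;
  }

  const int prevIdx = ncclFromRing[nRanks - 1];   // nccl index of the rank we receive from
  const int nextIdx = ncclFromRing[1 % nRanks];   // nccl index of the rank we send to
  std::printf("me=%d recv-from=%d send-to=%d\n",
              userFromRing[0], userRankOfNcclId[prevIdx], userRankOfNcclId[nextIdx]);
  return 0;   // prints: me=3 recv-from=0 send-to=1
}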
@@ -495,23 +531,24 @@ static void initDebug() {
ncclDebugLevel = ABORT;
INFO("NCCL debug level set to ABORT");
}
-
}
static void commFree(ncclComm_t comm) {
if (comm == NULL)
return;
- for(int i=0; i<MAXQUEUE; ++i) {
- if (comm->events.isDone[i] != NULL)
- if (cudaEventDestroy(comm->events.isDone[i]) != cudaSuccess)
- INFO("failed to destroy cuda event %d", i);
- }
+ if (comm->doneEvent != NULL)
+ if (cudaEventDestroy(comm->doneEvent) != cudaSuccess)
+ INFO("ncclComm failed to destroy doneEvent");
ncclResult_t res = commClearMaps(comm);
if (res != ncclSuccess)
INFO("failed to cleanup comm maps");
+ if (comm->devRing != NULL)
+ if (cudaFree(comm->devRing) != cudaSuccess)
+ INFO("commFree failed to free devRing");
+
if (comm->userFromRing != NULL)
free(comm->userFromRing);
@@ -519,8 +556,8 @@ static void commFree(ncclComm_t comm) {
if (cudaFree(comm->devUserFromRing) != cudaSuccess)
INFO("commFree failed to free dev maps");
- if (comm->ringFromUser != NULL)
- free(comm->ringFromUser);
+ if (comm->ncclFromRing != NULL)
+ free(comm->ncclFromRing);
if (comm->devMem != NULL && cudaFree(comm->devMem) != cudaSuccess)
INFO("Failed to free devMap");
@@ -550,7 +587,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
return ncclInvalidRank;
}
- size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(ncclNodeRef);
+ size_t commBytes = offsetof(ncclComm, ptrs) + ndev*sizeof(NodeRef);
struct ncclComm* comm = (struct ncclComm*)malloc(commBytes);
if (comm == NULL) {
WARN("comm allocation failed");
@@ -558,21 +595,23 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
}
memset(comm, 0, commBytes);
- comm->nDev = ndev;
+ comm->nRanks = ndev;
cudaGetDevice(&comm->cudaDev);
const char* str = getenv("NCCL_BUFFSIZE");
+ int buffsize;
if (str != NULL) {
errno = 0;
- comm->buffSize = strtol(str, NULL, 10);
- if (errno == ERANGE || comm->buffSize == 0) {
+ buffsize = strtol(str, NULL, 10);
+ if (errno == ERANGE || buffsize == 0) {
INFO("rank %d invalid NCCL_BUFFSIZE: %s, using default %lu",
rank, str, DEFAULT_BUFFER_SIZE_BYTES);
- comm->buffSize = DEFAULT_BUFFER_SIZE_BYTES;
+ buffsize = DEFAULT_BUFFER_SIZE_BYTES;
}
} else {
- comm->buffSize = DEFAULT_BUFFER_SIZE_BYTES;
+ buffsize = DEFAULT_BUFFER_SIZE_BYTES;
}
+ comm->buffSize = buffsize;
INFO("rank %d using buffSize = %lu", rank, comm->buffSize);
@@ -583,7 +622,14 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
commFree(comm);
return res;
}
- if (cudaMalloc(&comm->devUserFromRing, ndev*sizeof(int)) != cudaSuccess) {
+
+ if (cudaMalloc(&comm->devRing, sizeof(DevRing<char>)) != cudaSuccess) {
+ WARN("rank %d failed to allocate device-side ring views", rank);
+ commFree(comm);
+ return ncclCudaMallocFailed;
+ }
+
+ if (cudaMalloc(&comm->devUserFromRing, ndev*sizeof(int)) != cudaSuccess ) {
WARN("rank %d failed to allocated device maps", rank);
commFree(comm);
return ncclCudaMallocFailed;
@@ -596,20 +642,17 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
return ncclSystemError;
}
- comm->ringFromUser = (int*)malloc(ndev*sizeof(int));
- if (comm->ringFromUser == NULL) {
+ comm->ncclFromRing = (int*)malloc(ndev*sizeof(int));
+ if (comm->ncclFromRing == NULL) {
WARN("rank %d failed to allocate host maps", rank);
commFree(comm);
return ncclSystemError;
}
- EventQueue* eq = &comm->events;
- for(int i=0; i<MAXQUEUE; ++i) {
- if (cudaEventCreateWithFlags(eq->isDone+i, cudaEventDisableTiming) != cudaSuccess) {
- WARN("rank %d failed to create nccl event %d", rank, i);
- commFree(comm);
- return ncclUnhandledCudaError;
- }
+ if (cudaEventCreateWithFlags(&comm->doneEvent, cudaEventDisableTiming) != cudaSuccess) {
+ WARN("ncclComm on rank %d failed to create doneEvent", rank);
+ commFree(comm);
+ return ncclUnhandledCudaError;
}
if(commId == NULL) {
@@ -627,10 +670,46 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, const ncclUniqueId*
comm->hostMemState = ShmMapped | ShmLinked;
}
+ if (cudaHostGetDevicePointer(&comm->opCounter, &comm->hostMem->opCounter, 0) != cudaSuccess) {
+ WARN("ncclComm on rank %d failed to map opCounter to device", rank);
+ commFree(comm);
+ return ncclUnhandledCudaError;
+ }
+
*comret = comm;
return ncclSuccess;
}
+static ncclResult_t devCommUpdate(ncclComm_t comm) {
+ // Copy the comm on the device
+ size_t commBytes = offsetof(ncclComm, ptrs) + comm->nRanks*sizeof(NodeRef);
+ if (cudaMemcpy(comm->devComm, comm, commBytes, cudaMemcpyHostToDevice) != cudaSuccess) {
+ WARN("failed to copy device comm");
+ return ncclUnhandledCudaError;
+ }
+ // Fix the host pointer to be accessible from the device
+ void* dptr;
+ if (cudaHostGetDevicePointer(&dptr, comm->hostMem, 0) != cudaSuccess) {
+ WARN("failed to get device pointer for host mem");
+ return ncclUnhandledCudaError;
+ }
+ if (cudaMemcpy(&comm->devComm->hostMem, &dptr, sizeof(dptr), cudaMemcpyHostToDevice) != cudaSuccess) {
+ WARN("failed to update host pointer");
+ return ncclUnhandledCudaError;
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t devCommSetup(ncclComm_t comm) {
+ // Fully duplicate the comm on the device
+ size_t commBytes = offsetof(ncclComm, ptrs) + comm->nRanks*sizeof(NodeRef);
+ if (cudaMalloc(&comm->devComm, commBytes) != cudaSuccess) {
+ WARN("failed to allocated device comm");
+ return ncclCudaMallocFailed;
+ }
+ return devCommUpdate(comm);
+}
+
static ncclResult_t commUnlinkHostMem(ncclComm_t comm, ncclUniqueId commId, int rank) {
char rankname[1024];
sprintf(rankname, "%s-%d", commId.internal, rank);
@@ -643,12 +722,12 @@ static void showVersion() {
static int shown = 0;
if (shown == 0 && ncclDebugLevel >= VERSION) {
printf("NCCL version %d.%d.%d compiled with CUDA %d.%d\n", NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH, CUDA_MAJOR, CUDA_MINOR);
- fflush(stdout); \
+ fflush(stdout);
shown = 1;
}
}
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank);
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
if (myrank == 0) showVersion();
@@ -693,14 +772,14 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
goto cleanup;
}
- res = commBuildMaps(*newcomm, &commId, myrank, gath->ranks, &gath->ringDirectFail);
+ res = commBuildMaps(*newcomm, &commId, myrank, gath->ranks, &gath->globalMemSpaceBroke);
if (res != ncclSuccess) {
WARN("rank %d failed to build comm maps", myrank);
goto cleanup;
}
- syncRingDirect(gath, &((*newcomm)->useRemoteRecv));
- INFO("PushToRecv algos are %s\n", (*newcomm)->useRemoteRecv ? "enabled" : "disabled");
+ syncRingDirect(gath, &((*newcomm)->globalMemSpace));
+ INFO("Global device memory space is %s", (*newcomm)->globalMemSpace ? "enabled" : "disabled");
res = closeGather(gath, ndev); // includes a barrier
gath = NULL;
@@ -709,6 +788,13 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
goto cleanup;
}
+ res = devCommSetup(*newcomm);
+ if (res != ncclSuccess) {
+ WARN("rank %d failed to copy dcomm", myrank);
+ goto cleanup;
+ }
+
+ res = ncclSuccess;
goto final;
cleanup:
@@ -727,7 +813,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
return res;
}
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclCommInitAll, ncclComm_t* comms, int ndev, const int* devlist);
ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
initDebug();
@@ -741,7 +827,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
char busId[13];
nvmlDevice_t nvmlHandle;
int affinity_set = 0;
- int ringDirectFail = 0; // Assume direct access to recv ptr OK
+ int globalMemSpaceBroke = 0; // Assume direct access to recv ptr OK
res = wrapSymbols();
if (res != ncclSuccess) {
@@ -812,16 +898,24 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
for(rank=0; rank<ndev; ++rank) {
comm = comms[rank];
cudaSetDevice(comm->cudaDev);
- res = commBuildMaps(comm, NULL, rank, ranks, &ringDirectFail);
+ res = commBuildMaps(comm, NULL, rank, ranks, &globalMemSpaceBroke);
if (res != ncclSuccess) {
WARN("rank %d failed to build comm maps", rank);
goto cleanup;
}
}
- INFO("PushToRecv algos are %s\n", (ringDirectFail) ? "disabled" : "enabled");
+ INFO("Global device memory space is %s", (globalMemSpaceBroke) ? "disabled" : "enabled");
for(rank=0; rank<ndev; ++rank) {
- comms[rank]->useRemoteRecv = ringDirectFail ? 0 : 1;
+ comms[rank]->globalMemSpace = globalMemSpaceBroke ? 0 : 1;
+ }
+
+ for(rank=0; rank<ndev; ++rank) {
+ res = devCommSetup(comms[rank]);
+ if (res != ncclSuccess) {
+ WARN("rank %d failed to copy dcomm", rank);
+ goto cleanup;
+ }
}
free(ranks);
@@ -845,8 +939,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
return res;
}
-
-extern "C" DSOGLOBAL
+NCCL_API(void, ncclCommDestroy, ncclComm_t comm);
void ncclCommDestroy(ncclComm_t comm) {
if (comm == NULL)
return;
@@ -865,7 +958,7 @@ void ncclCommDestroy(ncclComm_t comm) {
cudaSetDevice(savedDevice);
}
-extern "C" DSOGLOBAL
+NCCL_API(const char*, ncclGetErrorString, ncclResult_t code);
const char* ncclGetErrorString(ncclResult_t code) {
switch (code) {
case ncclSuccess : return "no error";
@@ -887,21 +980,21 @@ const char* ncclGetErrorString(ncclResult_t code) {
return "unknown result code";
}
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
- *count = comm->nDev;
+ *count = comm->nRanks;
return ncclSuccess;
}
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid);
ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
*devid = comm->cudaDev;
return ncclSuccess;
}
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank);
ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
- *rank = comm->userFromRing[comm->ncclId];
+ *rank = comm->rank;
return ncclSuccess;
}
diff --git a/src/core.h b/src/core.h
index 591b934..bbabf49 100644
--- a/src/core.h
+++ b/src/core.h
@@ -1,19 +1,17 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
#ifndef CORE_H_
#define CORE_H_
+
#include "nccl.h"
#include <cstdio>
#include <cuda_runtime.h>
-#define MAXFLAGS 8
-#define MAXQUEUE 4 // Maximum number of queued collectives per communicator.
-#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
// DIE on error
#define CUDACHECK(cmd) do { \
@@ -25,55 +23,78 @@
} \
} while(false)
-#define NCCL_MEM_PAD_ALIGN 4096
-typedef struct {
- cudaEvent_t isDone[MAXQUEUE];
- int back; // Last event used
-} EventQueue;
+#define MAXRANKS 32
+#define DEFAULT_BUFFER_SIZE_BYTES (1UL << 21)
+#define NCCL_MEM_PAD_ALIGN 65536
+
struct ncclMem {
union { // Pad this block so that devBuff is correctly aligned
struct {
- int flags[MAXFLAGS];
- void* recvPtrs[MAXFLAGS];
+ int flags[2];
+ void* recvPtrs;
+ int opCounter; // Used to determine when remote Communicators are ready.
+ // Only used in host memory.
};
char pad[NCCL_MEM_PAD_ALIGN];
};
- // devBuff will likely be bigger ; we only use its offset/address.
- char buff[NCCL_MEM_PAD_ALIGN];
+ // devBuff will be bigger ; we only use its offset/address.
+ char buff[1];
+};
+
+template <typename T>
+struct alignas(long long) DevRing {
+ volatile int* __restrict__ prevOpCounter;
+ volatile int* __restrict__ nextOpCounter;
+ volatile int* __restrict__ sendFlagToNext;
+ volatile int* __restrict__ sendFlagToPrev;
+ volatile int* __restrict__ recvFlagFromNext;
+ volatile int* __restrict__ recvFlagFromPrev;
+
+ T* volatile * __restrict__ recvPtrFromNext;
+ T* volatile * __restrict__ sendPtrToPrev;
+ T* __restrict__ recvBuffer;
+ T* __restrict__ sendBuffer;
+
+ int userRank[MAXRANKS];
};
-struct ncclNodeRef {
- ncclMem* remote;
- ncclMem* local;
- int remoteCleanup;
- void* cleanupHandle;
+struct NodeRef {
+ ncclMem* remote; // TODO: Verify if these
+ ncclMem* local; // are still needed.
+ enum {DEVICE, HOST} type;
+ ncclMem* devCleanup; // Used only when remote comm uses same process & GPU
+ ncclMem* hostCleanup; // Used whenever target is in different process
+ int* opCounter; // TODO: see if this can be removed too.
};
+
struct ncclComm {
- int nDev; // number of devices in communicator
- int cudaDev; // cuda device index
- int ncclId; // nccl logical index
+ int rank; // my rank in the communicator
+ int nRanks; // number of GPUs in communicator
+ int cudaDev; // my cuda device index
// Device and Host allocated chunks. Stored here to correctly free() memory.
ncclMem* devMem;
ncclMem* hostMem;
int hostMemState;
+ int opSched; // Scheduling operation index
+ int* opCounter; // Counter of completed operations
- // Placed between calling and internal device streams.
- EventQueue events;
+ cudaStream_t prevStream; // cache last used stream
+ cudaEvent_t doneEvent; // orders operations in different streams
// Maps an internal nccl index to user-specified rank order. This is necessary
// since we need to know how the user expects data to be ordered across
- // devices.
+ // devices. Ordered from current device.
int* userFromRing;
// copy of the above stored on each device
int* devUserFromRing;
- // Inverse of userFromRing. Maps user specified index to internal nccl index.
- int* ringFromUser;
+ // Ring order
+ int* ncclFromRing; // TODO: REMOVE IF NOT NEEDED BEYOND CORE.CU
// Size of temp buffer in bytes.
size_t buffSize;
@@ -81,13 +102,20 @@ struct ncclComm {
// Whether we have remote access to the recvbuff pointers passed from remote
// GPUs. In single process mode this can be used as long as QPI links are
// not present. In multi-process, we never push to a remote recvbuff.
- int useRemoteRecv;
+ int globalMemSpace;
+
+ // Device copy of the communicator
+ struct ncclComm *devComm; // TODO: Remove this if not useful
+
+ // Device-side ring view
+ DevRing<char>* devRing;
// Device-to-device communication structures to access remote or local device
// memory. Actual allocation larger than 1.
- ncclNodeRef ptrs[1];
+ NodeRef ptrs[1];
};
+
typedef enum {NONE=0, VERSION=1, WARN=2, INFO=3, ABORT=4} DebugLevel;
extern DebugLevel ncclDebugLevel;
@@ -96,6 +124,7 @@ extern DebugLevel ncclDebugLevel;
printf("WARN %s:%d ", __FILE__, __LINE__); \
printf(__VA_ARGS__); \
printf("\n"); \
+ fflush(stdout); \
if (ncclDebugLevel >= ABORT) abort(); \
} \
} while(0)
@@ -103,10 +132,26 @@ extern DebugLevel ncclDebugLevel;
#define INFO(...) do { \
if (ncclDebugLevel >= INFO) { \
printf("INFO "); printf(__VA_ARGS__); printf("\n"); \
+ fflush(stdout); \
} \
} while(0)
-#define DSOGLOBAL __attribute__((visibility("default")))
+#ifdef PROFAPI
+#define NCCL_API(ret, func, args...) \
+ __attribute__ ((visibility("default"))) \
+ __attribute__ ((alias(#func))) \
+ ret p##func (args); \
+ extern "C" \
+ __attribute__ ((visibility("default"))) \
+ __attribute__ ((weak)) \
+ ret func(args)
+#else
+#define NCCL_API(ret, func, args...) \
+ extern "C" \
+ __attribute__ ((visibility("default"))) \
+ ret func(args)
+#endif // end PROFAPI
+
#endif // end include guard
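
On the new NCCL_API macro: when built with -DPROFAPI, every public entry point is emitted as a weak, default-visibility symbol plus a p-prefixed alias, so a profiling library can interpose the strong name and still reach the real implementation through the alias. Roughly, NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count) preprocesses to the following (reformatted; illustration of the expansion, not standalone code):

// Expansion with -DPROFAPI (whitespace added for readability):
__attribute__ ((visibility("default")))
__attribute__ ((alias("ncclCommCount")))
ncclResult_t pncclCommCount(const ncclComm_t comm, int* count);

extern "C"
__attribute__ ((visibility("default")))
__attribute__ ((weak))
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count)
// ...the function body from core.cu follows immediately.

// Expansion without PROFAPI: only the extern "C", default-visibility
// (non-weak) declaration of ncclCommCount is emitted.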
diff --git a/src/enqueue.h b/src/enqueue.h
index afa1cbd..01c44c2 100644
--- a/src/enqueue.h
+++ b/src/enqueue.h
@@ -1,57 +1,111 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
#ifndef enqueue_h_
#define enqueue_h_
#include "core.h"
+#include "reduce_kernel.h"
-int getRingIndex(const ncclComm_t comm, int device);
-void lockEventQueue(EventQueue* eq);
-void releaseEventQueue(EventQueue* eq);
-void CUDART_CB freeEvent(cudaStream_t stream, cudaError_t status, void* eq_void);
-
-/* Syncronize with user stream and launch the collective.
- * All work is performed asynchronously with the host thread.
- * The actual collective should be a functor with the
- * folloaing signature.
- * ncclResult_t collective(void* sendbuff, void* recvbuff,
- * int count, ncclDataType_t type, ncclRedOp_t op,
- * int root, ncclComm_t comm);
- * Unneeded arguments should be ignored. The collective may
- * assume that the appropriate cuda device has been set. */
-template<typename ColFunc>
-ncclResult_t enqueue(ColFunc colfunc,
- const void* sendbuff,
+/* Synchronize previous collective (if in different stream) and enqueue
+ * collective. Work is performed asynchronously with the host thread.
+ * The ColFunc class should be templated on the datatype and reduction
+ * operator (if applicable) and define a static entry() method as
+ * follows.
+ * template <typename T, template <typename> class RedOp>
+ * class CollectiveFunctor {
+ * public:
+ * static ncclResult_t entry(const void* sendbuff, void* recvbuff, int count,
+ * int root, ncclComm* comm, cudaStream_t stream);
+ * };
+ * The entry() method can assume that the appropriate cuda device has been set. */
+template< template<typename, template<typename> class> class ColFunc,
+ typename T,
+ template<typename> class Op >
+ncclResult_t enqueue(const void* sendbuff,
void* recvbuff,
int count,
- ncclDataType_t type,
- ncclRedOp_t op,
int root,
ncclComm_t comm,
cudaStream_t stream)
{
- int curDevice;
- CUDACHECK( cudaGetDevice(&curDevice) );
+ if (stream != comm->prevStream) { // sync required for calls in different streams
+ comm->prevStream = stream;
+ CUDACHECK( cudaStreamWaitEvent(stream, comm->doneEvent, 0) );
+ }
- // No need for a mutex here because we assume that all enqueue operations happen in a fixed
- // order on all devices. Thus, thread race conditions SHOULD be impossible.
- EventQueue* eq = &comm->events;
+ ncclResult_t ret;
+ ret = ColFunc<T, Op>::entry(sendbuff, recvbuff, count, root, comm, stream);
- // Ensure that previous collective is complete
- cudaError_t flag = cudaEventQuery(eq->isDone[eq->back]);
- if( flag == cudaErrorNotReady )
- CUDACHECK( cudaStreamWaitEvent(stream, eq->isDone[eq->back], 0) );
+ // Always have to record done event because we don't know what stream next
+ // collective will be in.
+ CUDACHECK( cudaEventRecord(comm->doneEvent, stream) );
+ comm->opSched += 1;
+ return ret;
+}
- // Launch the collective here
- ncclResult_t ret = colfunc(sendbuff, recvbuff, count, type, op, root, comm, stream);
- eq->back = (eq->back + 1) % MAXQUEUE;
- CUDACHECK( cudaEventRecord(eq->isDone[eq->back], stream) );
- return ret;
+// This version decodes type
+template< template<typename, template<typename> class> class ColFunc,
+ template<typename> class Op >
+ncclResult_t enqueue(const void* sendbuff,
+ void* recvbuff,
+ int count,
+ ncclDataType_t type,
+ int root,
+ ncclComm_t comm,
+ cudaStream_t stream)
+{
+ switch(type) {
+ case ncclChar:
+ return enqueue<ColFunc, char, Op>(sendbuff, recvbuff, count, root, comm, stream);
+ case ncclInt:
+ return enqueue<ColFunc, int, Op>(sendbuff, recvbuff, count, root, comm, stream);
+#ifdef CUDA_HAS_HALF
+ case ncclHalf:
+ return enqueue<ColFunc, half, Op>(sendbuff, recvbuff, count, root, comm, stream);
+#endif
+ case ncclFloat:
+ return enqueue<ColFunc, float, Op>(sendbuff, recvbuff, count, root, comm, stream);
+ case ncclDouble:
+ return enqueue<ColFunc, double, Op>(sendbuff, recvbuff, count, root, comm, stream);
+ case ncclInt64:
+ return enqueue<ColFunc, long long, Op>(sendbuff, recvbuff, count, root, comm, stream);
+ case ncclUint64:
+ return enqueue<ColFunc, unsigned long long, Op>(sendbuff, recvbuff, count, root, comm, stream);
+ default:
+ WARN("Invalid ncclType %d", type);
+ return ncclInvalidType;
+ }
+}
+
+// This version decodes both type and reduction op
+template< template<typename, template<typename> class> class ColFunc>
+ncclResult_t enqueue(const void* sendbuff,
+ void* recvbuff,
+ int count,
+ ncclDataType_t type,
+ ncclRedOp_t op,
+ int root,
+ ncclComm_t comm,
+ cudaStream_t stream)
+{
+ switch(op) {
+ case ncclSum:
+ return enqueue<ColFunc, FuncSum>(sendbuff, recvbuff, count, type, root, comm, stream);
+ case ncclProd:
+ return enqueue<ColFunc, FuncProd>(sendbuff, recvbuff, count, type, root, comm, stream);
+ case ncclMax:
+ return enqueue<ColFunc, FuncMax>(sendbuff, recvbuff, count, type, root, comm, stream);
+ case ncclMin:
+ return enqueue<ColFunc, FuncMin>(sendbuff, recvbuff, count, type, root, comm, stream);
+ default:
+ WARN("Invalid ncclRedOp: %d", op);
+ return ncclInvalidOperation;
+ }
}
#endif // End include guard
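
The comment at the top of this header defines the contract every collective now follows: the chain of enqueue<> overloads peels off the reduction op, then the datatype, and finally calls the functor's static entry(). A hedged sketch of how a new collective would plug in under this scheme (the names MyCollective, RingMyCollective and ncclMyCollective are hypothetical and not part of the patch):

// Hypothetical collective following the ColFunc contract described above.
template<typename T, template<typename> class RedOp>
class MyCollective {
 public:
  static ncclResult_t entry(const void* sendbuff, void* recvbuff,
      int count, int root, ncclComm* comm, cudaStream_t stream) {
    // Launch a ring kernel templated on RedOp<T>, as RingBroadcast does in broadcast.cu.
    return RingMyCollective<RedOp<T>, T>(sendbuff, recvbuff, count, root, comm, stream);
  }
};

NCCL_API(ncclResult_t, ncclMyCollective, const void* sendbuff, void* recvbuff,
    int count, ncclDataType_t datatype, ncclRedOp_t op, int root,
    ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclMyCollective(const void* sendbuff, void* recvbuff,
    int count, ncclDataType_t datatype, ncclRedOp_t op, int root,
    ncclComm_t comm, cudaStream_t stream) {
  // enqueue<> decodes op into Func{Sum,Prod,Max,Min}, then datatype into T,
  // and finally calls MyCollective<T, Op>::entry().
  return enqueue<MyCollective>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}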
diff --git a/src/libwrap.cu b/src/libwrap.cu
index ced3f73..5cfa546 100644
--- a/src/libwrap.cu
+++ b/src/libwrap.cu
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
#include "libwrap.h"
@@ -25,7 +25,6 @@ ncclResult_t wrapSymbols(void) {
return ncclSuccess;
static void* nvmlhandle = NULL;
- static void* cuhandle = NULL;
void* tmp;
void** cast;
@@ -38,20 +37,11 @@ ncclResult_t wrapSymbols(void) {
}
}
- cuhandle = dlopen("libcuda.so", RTLD_NOW);
- if (!cuhandle) {
- cuhandle = dlopen("libcuda.so.1", RTLD_NOW);
- if (!cuhandle) {
- WARN("Failed to open libcuda.so[.1]");
- goto teardown;
- }
- }
-
#define LOAD_SYM(handle, symbol, funcptr) do { \
cast = (void**)&funcptr; \
tmp = dlsym(handle, symbol); \
if (tmp == NULL) { \
- WARN("dlsym failed on %s - %s", symbol, dlerror()); \
+ WARN("dlsym failed on %s - %s", symbol, dlerror());\
goto teardown; \
} \
*cast = tmp; \
@@ -76,7 +66,6 @@ ncclResult_t wrapSymbols(void) {
nvmlInternalDeviceSetCpuAffinity = NULL;
nvmlInternalDeviceClearCpuAffinity = NULL;
- if (cuhandle != NULL) dlclose(cuhandle);
if (nvmlhandle != NULL) dlclose(nvmlhandle);
return ncclSystemError;
}
@@ -84,7 +73,7 @@ ncclResult_t wrapSymbols(void) {
ncclResult_t wrapNvmlInit(void) {
if (nvmlInternalInit == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalInit();
@@ -98,7 +87,7 @@ ncclResult_t wrapNvmlInit(void) {
ncclResult_t wrapNvmlShutdown(void) {
if (nvmlInternalShutdown == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalShutdown();
@@ -112,7 +101,7 @@ ncclResult_t wrapNvmlShutdown(void) {
ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_t* device) {
if (nvmlInternalDeviceGetHandleByPciBusId == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceGetHandleByPciBusId(pciBusId, device);
@@ -126,7 +115,7 @@ ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_
ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index) {
if (nvmlInternalDeviceGetIndex == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceGetIndex(device, index);
@@ -140,7 +129,7 @@ ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index) {
ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device) {
if (nvmlInternalDeviceSetCpuAffinity == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceSetCpuAffinity(device);
@@ -154,7 +143,7 @@ ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device) {
ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device) {
if (nvmlInternalInit == NULL) {
- WARN("lib wrapper not initilaized.");
+ WARN("lib wrapper not initialized.");
return ncclLibWrapperNotSet;
}
RetCode ret = nvmlInternalDeviceClearCpuAffinity(device);
@@ -165,3 +154,4 @@ ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device) {
}
return ncclSuccess;
}
+
diff --git a/src/libwrap.h b/src/libwrap.h
index 787912f..9397392 100644
--- a/src/libwrap.h
+++ b/src/libwrap.h
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
@@ -14,6 +14,15 @@
typedef struct nvmlDevice_st* nvmlDevice_t;
+/**
+ * Generic enable/disable enum.
+ */
+typedef enum nvmlEnableState_enum
+{
+ NVML_FEATURE_DISABLED = 0, //!< Feature disabled
+ NVML_FEATURE_ENABLED = 1 //!< Feature enabled
+} nvmlEnableState_t;
+
ncclResult_t wrapSymbols(void);
ncclResult_t wrapNvmlInit(void);
@@ -22,6 +31,7 @@ ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_
ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index);
ncclResult_t wrapNvmlDeviceSetCpuAffinity(nvmlDevice_t device);
ncclResult_t wrapNvmlDeviceClearCpuAffinity(nvmlDevice_t device);
+ncclResult_t wrapNvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t *device);
#endif // End include guard
diff --git a/src/primitives.h b/src/primitives.h
new file mode 100644
index 0000000..4d7b86b
--- /dev/null
+++ b/src/primitives.h
@@ -0,0 +1,206 @@
+/*************************************************************************
+ * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef PRIMITIVES_H_
+#define PRIMITIVES_H_
+
+#include <type_traits>
+#include "copy_kernel.h" // for FuncPassA
+#include "reduce_kernel.h" // for reduction funcs
+
+
+/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
+ *
+ * In order to reduce the repetition of template arguments, the operations
+ * are bundled as static methods of the Primitives class.
+ *
+ * Each primitive operation copies/reduces a contiguous buffer and syncs
+ * an optional set of flags against a sub-step counter. The sync value is
+ * based on the step parameter. Sync flags must be of type WaitFlag or
+ * PostFlag. The primitive routines wait for all WaitFlag args to attain
+ * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
+ * corresponding substep by previous step) before executing the transfer.
+ * After each substep is transferred, all PostFlag arguments get updated to
+ * the value SUBSTEPS*step+substep+1.
+ */
+
+
+class WaitFlag {
+ volatile int * const flag;
+ const int shift;
+ public:
+ __device__ __forceinline__
+ WaitFlag(volatile int * const flag, const int shift) : flag(flag), shift(shift) { }
+ __device__ __forceinline__
+ void wait(int val) { while (*flag < (val + shift)) /*SPIN*/; }
+};
+
+
+class PostFlag {
+ volatile int * const flag;
+ const int shift;
+ public:
+ __device__ __forceinline__
+ PostFlag(volatile int* const flag, const int shift) : flag(flag), shift(shift) { }
+ __device__ __forceinline__
+ void post(int val) { *flag = (val + shift); }
+};
+
+
+// Helper to check if any argument is of type T.
+// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
+template<typename T> __device__ __forceinline__
+bool AnyAre() { return false; }
+
+template<typename T, typename FIRST_T, typename... TAIL_Ts>
+__device__ __forceinline__
+bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
+ return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
+}
+
+
+// Wait on all WaitFlags, ignore PostFlags
+__device__ __forceinline__
+void WaitOnFlags(int val) { }
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void WaitOnFlags(int val, WaitFlag flag, TAIL_Ts... tail) {
+ flag.wait(val);
+ WaitOnFlags(val, tail...);
+}
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void WaitOnFlags(int val, PostFlag, TAIL_Ts... tail) {
+ WaitOnFlags(val, tail...);
+}
+
+
+// Post all PostFlags, ignore WaitFlags
+__device__ __forceinline__
+void PostToFlags(int val) { }
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostToFlags(int val, WaitFlag flag, TAIL_Ts... tail) {
+ PostToFlags(val, tail...);
+}
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostToFlags(int val, PostFlag flag, TAIL_Ts... tail) {
+ flag.post(val);
+ PostToFlags(val, tail...);
+}
+
+
+// Create pointer arithmetic syntax that doesn't break for nullptr_t
+template <typename Tptr> __device__ __forceinline__
+Tptr ptradd(Tptr ptr, int i) {
+ return ptr + i;
+}
+
+__device__ __forceinline__
+nullptr_t ptradd(nullptr_t ptr, int i) {
+ return nullptr;
+}
+
+
+// Implementation of primitive types
+template <int THREADS, int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
+class Primitives {
+ private:
+ template <typename SRC2_T, // either T* or nullptr_t
+ typename DST2_T, // either T* or nullptr_t
+ typename... SYNC_Ts> // either WaitFunc or PostFunc
+ static __device__ __forceinline__ void
+ GenericOp(const T* src1,
+ const SRC2_T src2,
+ T* dst1,
+ DST2_T dst2,
+ int len, int maxoffset, int step, SYNC_Ts... flags) {
+
+ enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value };
+ enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value };
+ static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
+ "src2 must be of type T* or nullptr_t");
+ static_assert(noDst2 || std::is_same<DST2_T, T*>::value,
+ "dst2 must be of type T* or nullptr_t");
+
+ using OpType = typename std::conditional<noSrc2, FuncPassA<T>, REDOP>::type;
+
+ if (threadIdx.x < THREADS) {
+ int sliceSize = len / SUBSTEPS;
+ int sliceOffset = 0;
+ #pragma unroll 1
+ for (int sub=0; sub<SUBSTEPS; ++sub) {
+ if (AnyAre<WaitFlag>(flags...)) {
+ if (threadIdx.x == 0) {
+ WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
+ }
+ asm volatile ("bar.sync 1, %0;" :: "r"(THREADS));
+ }
+ ReduceOrCopy
+ <
+ UNROLL,
+ THREADS,
+ OpType,
+ T,
+ !std::is_same<DST2_T, nullptr_t>::value, // HAS_DEST1
+ !std::is_same<SRC2_T, nullptr_t>::value // HAS_SRC1
+ >
+ (
+ threadIdx.x,
+ ptradd(dst1, sliceOffset),
+ ptradd(dst2, sliceOffset),
+ ptradd(src1, sliceOffset),
+ ptradd(src2, sliceOffset),
+ min(sliceSize, maxoffset-sliceOffset)
+ );
+ if (AnyAre<PostFlag>(flags...)) {
+ __syncthreads();
+ }
+ sliceOffset += sliceSize;
+ }
+ } else {
+ for(int sub=0; sub<SUBSTEPS; ++sub) {
+ if (AnyAre<PostFlag>(flags...)) {
+ __syncthreads();
+ __threadfence_system();
+ PostToFlags(SUBSTEPS*step + sub + 1, flags...);
+ }
+ }
+ }
+ }
+
+ public:
+ template <typename... SYNC_Ts>
+ static __device__ __forceinline__ void
+ Copy(const T* src, T* dst,
+ int len, int step, SYNC_Ts... flags) {
+ GenericOp(src, nullptr, dst, nullptr, len, step, flags...);
+ }
+
+ template <typename... SYNC_Ts>
+ static __device__ __forceinline__ void
+ DoubleCopy(const T* src, T* dst1, T* dst2,
+ int len, int step, SYNC_Ts... flags) {
+ GenericOp(src, nullptr, dst1, dst2, len, step, flags...);
+ }
+
+ template <typename... SYNC_Ts>
+ static __device__ __forceinline__ void
+ Reduce(const T* src1, const T* src2, T* dst,
+      int len, int maxOffset, int step, SYNC_Ts... flags) {
+    GenericOp(src1, src2, dst, nullptr, len, maxOffset, step, flags...);
+ }
+
+ template <typename... SYNC_Ts>
+ static __device__ __forceinline__ void
+ ReduceCopy(const T* src1, const T* src2, T* dst1, T* dst2,
+      int len, int maxOffset, int step, SYNC_Ts... flags) {
+    GenericOp(src1, src2, dst1, dst2, len, maxOffset, step, flags...);
+ }
+};
+
+#endif // end include guard
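
The variadic helpers above are what let GenericOp accept an arbitrary mix of WaitFlag and PostFlag arguments and act on each kind separately: AnyAre<T> answers whether the pack contains a given flag type, WaitOnFlags spins only on the WaitFlags, and PostToFlags writes only the PostFlags. The same dispatch can be checked on the host with the small standalone C++ sketch below; the WaitFlag/PostFlag structs are printing stand-ins of my own (the real ones wrap volatile flag memory shared with a neighbor GPU), and only the template machinery mirrors primitives.h.

#include <cstdio>
#include <type_traits>

// Printing stand-ins for the device-side flag wrappers (assumption: the real
// ones spin on / write to volatile flag memory).
struct WaitFlag { int id; void wait(int val) { std::printf("wait flag %d until >= %d\n", id, val); } };
struct PostFlag { int id; void post(int val) { std::printf("post flag %d = %d\n", id, val); } };

// AnyAre<T>(flags...): true if any argument in the pack has type T.
template <typename T> bool AnyAre() { return false; }
template <typename T, typename FIRST_T, typename... TAIL_Ts>
bool AnyAre(FIRST_T, TAIL_Ts... tail) {
  return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
}

// Wait on all WaitFlags, ignore PostFlags.
void WaitOnFlags(int) { }
template <typename... TAIL_Ts> void WaitOnFlags(int val, WaitFlag f, TAIL_Ts... tail) { f.wait(val); WaitOnFlags(val, tail...); }
template <typename... TAIL_Ts> void WaitOnFlags(int val, PostFlag, TAIL_Ts... tail) { WaitOnFlags(val, tail...); }

// Post all PostFlags, ignore WaitFlags.
void PostToFlags(int) { }
template <typename... TAIL_Ts> void PostToFlags(int val, WaitFlag, TAIL_Ts... tail) { PostToFlags(val, tail...); }
template <typename... TAIL_Ts> void PostToFlags(int val, PostFlag f, TAIL_Ts... tail) { f.post(val); PostToFlags(val, tail...); }

int main() {
  WaitFlag waitNext{0}, waitPrev{1};
  PostFlag postNext{2}, postPrev{3};
  // The same kind of mixed pack a collective hands to GenericOp.
  if (AnyAre<WaitFlag>(waitNext, waitPrev, postNext, postPrev))
    WaitOnFlags(7, waitNext, waitPrev, postNext, postPrev);
  if (AnyAre<PostFlag>(waitNext, waitPrev, postNext, postPrev))
    PostToFlags(7, waitNext, waitPrev, postNext, postPrev);
  return 0;
}

Running it shows the wait pass touching only the two WaitFlags and the post pass touching only the two PostFlags, which is why every collective can pass one flat flag pack to each primitive call.
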
diff --git a/src/reduce.cu b/src/reduce.cu
index 77190b9..f281ce8 100644
--- a/src/reduce.cu
+++ b/src/reduce.cu
@@ -1,393 +1,150 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
-#include <algorithm>
-
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
-#include "reduce_kernel.h"
-
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
- * split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
- * single-data-element operations inside an UNROLL. As the name suggests, the
- * UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 4
-
-// if this is called with CHUNK, it means that we just finished pushing the data
-// of chunk CHUNK to the next GPU, so it can proceed with CHUNK
-// We add 1 to chunk so that the initial flag of 0 doesn't allow the non-root
-// GPUs to proceed before the flag is incremented from the upstream GPU. This
-// is called by one particular consumer warp and so we select the first thread
-// in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk) \
- do { \
- __threadfence_system(); \
- args.NextNewDataAvailableFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-#define WAIT_FOR_NEW_DATA(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// If this is called with CHUNK, it means that this GPU has just finished
-// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- bool newDataAvailable = \
- ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- bool chunkDone = \
- ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk)+subchunk + 1 - NUM_SUBCHUNKS; \
- return newDataAvailable && chunkDone; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *size = bigSliceN;
- *offset = slice * bigSliceN;
- } else {
- *size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- *offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
- }
-
-// if (threadIdx.x == 0)
-// printf("[size=%d] [offset=%d] slice=%d numSlices=%d "
-// "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
-// "lastSliceN=%d\n", *size, *offset, slice, numSlices, numBigSlices,
-// numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-}
+#include "primitives.h"
-template<typename T>
-struct ReduceKernelArgs {
- // general parameters
- int ThisId;
- int N;
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
- int BufferSliceStride;
+// Increase Step and boffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ boffset += sliceSize; \
+ if (boffset == buffSize) boffset = 0;
- T ** ThisPtrToNextData;
- T ** PrevPtrToThisData;
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
- // local and remote data
- T * __restrict__ Output;
- const T * __restrict__ ThisData;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
+template<int THREADS, int UNROLL, class FUNC, typename T>
+__launch_bounds__(THREADS+WARP_SIZE, 1)
+__global__ void ReduceKernel(const KernelArgs<T> args) {
+ const int tid = threadIdx.x;
+ __shared__ DevRing<T> ring;
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__shared__ volatile void * nextData;
-enum ReduceRole {BEGIN=0, MIDDLE=1, END=2};
-
-template<int THREADS, int UNROLL, class FUNC, int ROLE, typename T>
-__global__ void ReduceKernel(const ReduceKernelArgs<T> args) {
- if (args.N == 0) return;
- int tid = threadIdx.x;
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisData) == nullptr; // Wait for previous processor to be done
- });
-
- *((T * volatile *)args.PrevPtrToThisData) = (T*)args.ThisData; // Tell Previous I'm starting
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextData) != nullptr; // Wait till I've been told next started
- });
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
}
__syncthreads();
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
-
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
-
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
-
- int offset;
- int sliceSize;
-
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndOffset(&sliceSize, &offset, s, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- if (ROLE == BEGIN) {
- WAIT_FOR_CHUNK(chunk, s);
-
- Copy<UNROLL, THREADS>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- sliceSize);
- } else if (ROLE == MIDDLE) {
- WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, s);
-
- Reduce<UNROLL, THREADS, FUNC>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- } else { // ROLE == END
- WAIT_FOR_NEW_DATA(chunk, s);
-
- Reduce<UNROLL, THREADS, FUNC>(
- args.Output + chunkOffset + offset,
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- }
- __syncthreads();
- }
- } else { // Consumer thread
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- if (ROLE != END)
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s);
-
- // signal chunk done if we don't push into the receive buffer and this
- // is no the last chunk and this is not root
- if ((ROLE != BEGIN) && (chunk + 1 < args.NumChunks)) {
- SIGNAL_CHUNK_DONE(chunk, s);
- }
- }
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, (1-NUM_BUFCHUNKS)*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, 0);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, 0);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
+
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
+
+ const int size = args.N;
+ const int nranks = args.nRanks;
+ const int rank = ring.userRank[0];
+ const int prevRank = ring.userRank[nranks-1];
+ const int root = args.root;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int boffset = 0;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+
+ for (int offset = 0; offset < size; offset += sliceSize) {
+ int maxOffset = size-offset;
+ if (prevRank == root) {
+ Prims::Copy(
+ thisInput + offset,
+ nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext,
+ postReadyToNext);
+ } else if (rank == root) {
+ Prims::Reduce(
+ prevInput + boffset,
+ thisInput + offset,
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitReadyFromPrev,
+ postDoneToPrev);
+ } else {
+ Prims::ReduceCopy(
+ thisInput + offset,
+ prevInput + boffset,
+ thisOutput + offset,
+ nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
}
+ NEXT_STEP; // Increases step, boffset
}
- // reset flags
+ // wait for the last data to be pushed to us
if (tid == 0) {
- args.ThisNewDataAvailableFlag[0] = 0;
- args.ThisChunkDoneFlag[0] = 0;
- *args.ThisPtrToNextData = nullptr;
+ if (rank != root) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
+ }
+
+ if (prevRank != root) {
+ // reset the flag
+ *ring.recvFlagFromPrev = 0;
+ }
+
+ incrementOpCounter(&args);
}
}
+#define THREADS 512
+#define UNROLL 8
+
template<class FUNC, typename T>
-ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
- const int count, const int root, ncclComm* comm, cudaStream_t stream) {
+ncclResult_t RingReduce(const void* sendbuff, void* recvbuff, const int count, const int root,
+ ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
- int index = comm->ncclId;
-
- const int numUnroll = 4;
- int rootId = comm->ringFromUser[root];
-
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int bufferN = comm->buffSize / sizeof(T);
- // we only need buffer for k slices and k paddings
- int bufferNPerSlice = bufferN / NUM_SUBCHUNKS;
- int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
-
- ReduceKernelArgs<T> args;
-
- args.ThisId = index;
- args.N = count;
-
- args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
-
- if(!comm->useRemoteRecv) {
- // Proxy for QPI. Reduce never pushes directly to recv.
- // But larger transfers help QPI more than tag updates hurt P2P.
- args.SliceSize *= 8;
- }
-
- // make sure slice fits into the temporary buffer
- args.SliceSize = std::min(maxSliceSize, args.SliceSize);
- args.BufferSliceStride = args.SliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
- }
-
- args.ThisPtrToNextData = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisData = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
-
- args.Output = (T*)recvbuff;
- args.ThisData = (const T*) sendbuff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
-
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
-
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
-
- if (comm->nDev == 1) {
+ if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
- if (index == (rootId + 1) % comm->nDev) {
- ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, BEGIN, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else if (index == rootId) {
- ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, END, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- ReduceKernel<NUM_THREADS, UNROLL_COUNT, FUNC, MIDDLE, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
+ KernelArgs<T> args;
+ ArgsSetup(&args, sendbuff, recvbuff, root, count, comm);
+ LAUNCH_KERNEL(ReduceKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
- return ncclSuccess;
-}
-template <typename T>
-ncclResult_t ncclReduceWithType(const void* sendbuff,
- void* recvbuff, int count, ncclRedOp_t op, int root,
- ncclComm* comm, cudaStream_t stream) {
-
- switch (op) {
- case ncclSum:
- return ncclReduceWithTypeAndFunc<FuncSum<T>, T>(
- sendbuff, recvbuff, count, root, comm, stream);
- case ncclProd:
- return ncclReduceWithTypeAndFunc<FuncProd<T>, T>(
- sendbuff, recvbuff, count, root, comm, stream);
- case ncclMax:
- return ncclReduceWithTypeAndFunc<FuncMax<T>, T>(
- sendbuff, recvbuff, count, root, comm, stream);
- case ncclMin:
- return ncclReduceWithTypeAndFunc<FuncMin<T>, T>(
- sendbuff, recvbuff, count, root, comm, stream);
- }
- return ncclInvalidOperation;
+ return ncclSuccess;
}
-
+template<typename T, template<typename> class RedOp>
class ReduceFunctor {
-public:
- ncclResult_t operator()(const void* sendbuff,
- void* recvbuff, int count, ncclDataType_t datatype, ncclRedOp_t op,
- int root, ncclComm* comm, cudaStream_t stream) {
-
- switch (datatype) {
- case ncclChar:
- return ncclReduceWithType<char>(sendbuff, recvbuff, count, op, root, comm, stream);
- case ncclInt:
- return ncclReduceWithType<int>(sendbuff, recvbuff, count, op, root, comm, stream);
-#ifdef CUDA_HAS_HALF
- case ncclHalf:
- return ncclReduceWithType<half>(sendbuff, recvbuff, count, op, root, comm, stream);
-#endif
- case ncclFloat:
- return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
- case ncclDouble:
- return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
- case ncclInt64:
- return ncclReduceWithType<long long>(sendbuff, recvbuff, count, op, root, comm, stream);
- case ncclUint64:
- return ncclReduceWithType<unsigned long long>(sendbuff, recvbuff, count, op, root, comm, stream);
- }
- return ncclInvalidType;
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int root, ncclComm* comm, cudaStream_t stream) {
+ return RingReduce<RedOp<T>, T>(sendbuff, recvbuff, count, root, comm, stream);
}
};
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, int count,
+ ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, int count,
- ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm,
- cudaStream_t stream) {
- return enqueue(ReduceFunctor(), sendbuff, recvbuff, count, datatype, op,
- root, comm, stream);
+ ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
+ return enqueue<ReduceFunctor>(sendbuff, recvbuff, count, datatype, op, root, comm, stream);
}
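
Each rank in the refactored ReduceKernel plays exactly one role per chunk, chosen from its position relative to the root, and NEXT_STEP rotates boffset through the NUM_BUFCHUNKS slots of the communication buffer. Below is a host-side C++ sketch of both, assuming a plain 0 -> 1 -> ... -> n-1 -> 0 ring so that ring.userRank[nranks-1] is simply (rank-1) mod nranks; the ranks, root and buffer sizes are example values of my own.

#include <cstdio>

int main() {
  const int nranks = 4, root = 2;

  // Role selection, as in ReduceKernel: the rank right after the root only copies
  // into the ring, the root reduces into its output, everyone else reduce-copies.
  for (int rank = 0; rank < nranks; ++rank) {
    int prevRank = (rank + nranks - 1) % nranks;   // stands in for ring.userRank[nranks-1]
    const char* role = (prevRank == root) ? "Copy (first rank after the root)"
                     : (rank == root)     ? "Reduce (root, produces the final result)"
                                          : "ReduceCopy (middle of the ring)";
    std::printf("rank %d: %s\n", rank, role);
  }

  // NEXT_STEP from reduce.cu: bump the step counter and rotate the buffer offset
  // through the NUM_BUFCHUNKS slices of the buffer.
  const int buffSize = 8, sliceSize = buffSize / 2;  // NUM_BUFCHUNKS == 2
  int step = 0, boffset = 0;
  for (int i = 0; i < 5; ++i) {
    std::printf("step=%d boffset=%d\n", step, boffset);
    step++;
    boffset += sliceSize;
    if (boffset == buffSize) boffset = 0;
  }
  return 0;
}
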
diff --git a/src/reduce_kernel.h b/src/reduce_kernel.h
index 2ad6e21..f2cd512 100644
--- a/src/reduce_kernel.h
+++ b/src/reduce_kernel.h
@@ -1,7 +1,7 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
@@ -12,6 +12,13 @@
#include <limits>
template<typename T>
+struct FuncNull {
+ __device__ T operator()(const T x, const T y) const {
+ return 0;
+ }
+};
+
+template<typename T>
struct FuncSum {
__device__ T operator()(const T x, const T y) const {
return x + y;
@@ -192,30 +199,46 @@ struct FuncMin<char> {
template<>
struct FuncSum<half> {
__device__ half2 operator()(const half2 x, const half2 y) const {
+#if __CUDA_ARCH__ >= 530
+ return __hadd2(x, y);
+#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x + fy.x;
fr.y = fx.y + fy.y;
return __float22half2_rn(fr);
+#endif
}
__device__ half operator()(const half x, const half y) const {
+#if __CUDA_ARCH__ >= 530
+ return __hadd(x, y);
+#else
return __float2half( __half2float(x) + __half2float(y) );
+#endif
}
};
template<>
struct FuncProd<half> {
__device__ half2 operator()(const half2 x, const half2 y) const {
+#if __CUDA_ARCH__ >= 530
+ return __hmul2(x, y);
+#else
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
fr.x = fx.x * fy.x;
fr.y = fx.y * fy.y;
return __float22half2_rn(fr);
+#endif
}
__device__ half operator()(const half x, const half y) const {
+#if __CUDA_ARCH__ >= 530
+ return __hmul(x, y);
+#else
return __float2half( __half2float(x) * __half2float(y) );
+#endif
}
};
@@ -225,15 +248,15 @@ struct FuncMax<half> {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
- fr.x = fx.x > fy.x ? fx.x : fy.x;
- fr.y = fx.y > fy.y ? fx.y : fy.y;
+ fr.x = fmaxf(fx.x, fy.x);
+ fr.y = fmaxf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
- fm = fx > fy ? fx : fy;
+ fm = fmaxf(fx, fy);
return __float2half(fm);
}
};
@@ -244,15 +267,15 @@ struct FuncMin<half> {
float2 fx, fy, fr;
fx = __half22float2(x);
fy = __half22float2(y);
- fr.x = fx.x < fy.x ? fx.x : fy.x;
- fr.y = fx.y < fy.y ? fx.y : fy.y;
+ fr.x = fminf(fx.x, fy.x);
+ fr.y = fminf(fx.y, fy.y);
return __float22half2_rn(fr);
}
__device__ half operator()(const half x, const half y) const {
float fx, fy, fm;
fx = __half2float(x);
fy = __half2float(y);
- fm = fx < fy ? fx : fy;
+ fm = fminf(fx, fy);
return __float2half(fm);
}
};
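
The half-precision functors above switch at compile time, via __CUDA_ARCH__, between native fp16 intrinsics (sm_53 and newer) and a float round-trip fallback. A minimal standalone CUDA sketch of the same dispatch, with a test kernel and values of my own choosing:

// Build: nvcc -arch=sm_52 fp16_sum_sketch.cu   (float fallback path)
//        nvcc -arch=sm_53 fp16_sum_sketch.cu   (native fp16 intrinsic path)
#include <cstdio>
#include <cuda_fp16.h>

struct FuncSumHalf {  // same arch dispatch as FuncSum<half> in reduce_kernel.h
  __device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530
    return __hadd(x, y);                                     // native fp16 add
#else
    return __float2half(__half2float(x) + __half2float(y));  // float round-trip
#endif
  }
};

__global__ void sumOnce(float* out) {
  FuncSumHalf op;
  half a = __float2half(2.5f);
  half b = __float2half(1.0f);
  out[0] = __half2float(op(a, b));   // convert back so the host can read it easily
}

int main() {
  float h = 0.f, *d = nullptr;
  cudaMalloc(&d, sizeof(float));
  sumOnce<<<1, 1>>>(d);
  cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("2.5 + 1.0 = %f\n", h);   // expect 3.500000 on either path
  cudaFree(d);
  return 0;
}

Either compilation path should print 3.5; only the generated instructions differ.
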
diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu
index a30e7cd..75f203b 100644
--- a/src/reduce_scatter.cu
+++ b/src/reduce_scatter.cu
@@ -1,496 +1,166 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
-#include <cassert>
-
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
-#include "reduce_kernel.h"
-
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is an independent, complete reduction. Each
- * GPU has a buffer that can fit an entire CHUNK, so that all SUBCHUNKS can be
- * processed without checking that the buffer on the receiving GPU is empty. A
- * SUBCHUNK is split into NUM_GPUS SLICES and each GPU works on a different
- * SLICE at the same time. Before moving on the the next SLICE in the reduction
- * algorithm, the GPU has to check whether it has received the data from the
- * previous GPU it needs for this SLICE. To hide the latency of this
- * communication, each GPU processes all the SLICES of all the SUBCHUNKS in
- * sequence before moving on to the next SLICE. Each SLICE is split into a
- * certain number of UNROLLS (determined by the buffer size) and each thread
- * performs UNROLL_COUNT single-data-element operations inside an UNROLL. As the
- * name suggests, the UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 2
-
-/*
- * numGPUs BLOCKs consisting of recvcount words each
- * BLOCK is split up into NumChunks CHUNKs
- * CHUNK is split up into NUM_SUBCHUNKS SUBCHUNKs
- * SUBCHUNK consists of exactly one SLICE
- * SLICE is most efficiently processed in multiples of UNROLL_SIZE
- *
- * The algorithm has numGPUs steps and each step processes a SLICE (i.e.
- * SUBCHUNK) of a different BLOCK. Only data of the BLOCKs not resident on the
- * GPU need to be communicated, hence (numGPUs - 1) BLOCKs. So the buffer needs
- * to have room for (numGPUs - 1) SLICEs.
- */
-
-
-// do not encode the subchunk number into the flag, because there is a separate
-// flag for each subchunk
-
-// If this is called with STEP, it means that we just finished processing the
-// data for step STEP on this GPU, which is the data required on the next GPU
-// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
-// is available. This is called by one particular consumer warp and so we select
-// the first thread in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step) \
- do { \
- args.NextNewDataAvailableFlag[0] = \
- 2*((chunk) * args.NumGPUs + (step)) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_NEW_DATA(chunk, subchunk, step) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- 2*((chunk) * args.NumGPUs + (step)) + subchunk - 1; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// If this is called with CHUNK, it means that this GPU has just finished
-// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- args.PrevChunkDoneFlag[0] = 2*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- 2*(chunk) + subchunk - 1; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-
-__device__ inline void getSliceSizeAndChunkSize(int *sliceSize, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *sliceSize = bigSliceN;
- } else {
- *sliceSize = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- }
-
-/* if (threadIdx.x == 0)
- printf("[sliceSize=%d] slice=%d numSlices=%d "
- "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
- "lastSliceN=%d\n", *sliceSize, slice, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-*/
-}
-
-template<typename T>
-struct ReduceScatterKernelArgs {
- // general parameters
- int ThisId;
- int NumGPUs;
- int N;
- int * UserFromRing;
-
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
+#include "primitives.h"
- int BufferSliceStride;
- int BufferMisalignedN;
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
- T ** ThisPtrToNextOutput;
- T ** PrevPtrToThisOutput;
+// Increase Step and poffset/noffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ poffset = noffset; \
+ noffset += sliceSize; \
+ if (noffset == buffSize) noffset = 0;
- // local and remote input, output, and buffer
- const T * __restrict__ ThisInput;
- volatile T * __restrict__ ThisOutput;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
-
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__device__ inline int GetBlock(const int index, const int step,
- const int * const userFromRing, const int numGPUs) {
- return userFromRing[(numGPUs + index - 1 - step) % numGPUs];
-}
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
template<int THREADS, int UNROLL, class FUNC, typename T>
-__global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) {
- if (args.N == 0) return;
- int tid = threadIdx.x;
+__launch_bounds__(THREADS+WARP_SIZE, 1)
+__global__ void ReduceScatterKernel(const KernelArgs<T> args) {
+ const int tid = threadIdx.x;
+ __shared__ DevRing<T> ring;
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
- if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr; // Wait for previous processor to be done
- });
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
- *((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput; // Tell Previous I'm starting
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr; // Wait till I've been told next started
- });
+ if (tid == 0) {
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
}
__syncthreads();
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
-
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
-
-
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, -NUM_BUFCHUNKS*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, -1*NUM_SUBSTEPS);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, -1*NUM_SUBSTEPS);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
+
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T, FUNC> Prims;
+
+ const int size = args.N;
+ const int nranks = args.nRanks;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int poffset, noffset = 0;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+
+ for (int chunkOffset = 0; chunkOffset < size; chunkOffset += sliceSize) {
+ /////////////// begin ReduceScatter steps ///////////////
+ int offset;
+ int maxOffset = size-chunkOffset;
+ int rankDest;
// step 0: push data to next GPU
- int step = 0;
- int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- int blockOffset = chunkOffset + block * args.N;
- int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- ((block * args.BufferMisalignedN) % alignof(PackType));
- int sliceSize;
-
- if (tid < NUM_THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- WAIT_FOR_CHUNK(chunk, s);
- Copy<UNROLL, THREADS>(
- args.NextBuffer + bufferOffset,
- args.ThisInput + blockOffset,
- sliceSize);
- __syncthreads();
- bufferOffset += sliceSize;
- blockOffset += sliceSize;
- }
- } else { // Is consumer
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
-
- // steps j with 0 < j < k - 1, where k = number of GPUs: reduce and copy to
- // next GPU
- for (step = 1; step < args.NumGPUs - 1; ++step) {
- int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- int blockOffset = chunkOffset + block * args.N;
- int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- ((block * args.BufferMisalignedN) % alignof(PackType));
-
- if (tid < NUM_THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
- WAIT_FOR_NEW_DATA(chunk, s, step);
- Reduce<UNROLL, THREADS, FUNC>(
- args.NextBuffer + bufferOffset,
- args.ThisBuffer + bufferOffset,
- args.ThisInput + blockOffset,
- sliceSize);
- __syncthreads();
- bufferOffset += sliceSize;
- blockOffset += sliceSize;
- }
- } else {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
- }
- }
+ rankDest = ring.userRank[nranks-1];
+ offset = chunkOffset + rankDest * size;
+
+ Prims::Copy(
+ thisInput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP; // Increases step, poffset, noffset
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ rankDest = ring.userRank[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ Prims::Reduce(
+ prevInput + poffset,
+ thisInput + offset,
+ nextOutput + noffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
// step k - 1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
- step = args.NumGPUs - 1;
- block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
- blockOffset = chunkOffset + block * args.N;
- bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
- ((block * args.BufferMisalignedN) % alignof(PackType));
-
- if (tid < NUM_THREADS) {
- int outputOffset = 0;
- for (int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
- numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
- WAIT_FOR_NEW_DATA(chunk, s, step);
- Reduce<UNROLL, THREADS, FUNC>(
- args.ThisOutput + (chunkOffset + outputOffset),
- args.ThisBuffer + bufferOffset,
- args.ThisInput + blockOffset,
- sliceSize);
- __syncthreads();
- outputOffset += sliceSize;
- bufferOffset += sliceSize;
- blockOffset += sliceSize;
- }
- } else {
- for (int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
-
- // signal that chunk is done if this is not the last chunk
- if (chunk + 1 < args.NumChunks) {
- SIGNAL_CHUNK_DONE(chunk, s);
- }
- }
- }
+ rankDest = ring.userRank[0];
+ offset = chunkOffset + rankDest * size;
+
+ Prims::Reduce(
+ prevInput + poffset,
+ thisInput + offset,
+ thisOutput + chunkOffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+
+ NEXT_STEP;
}
// wait for the last data to be pushed to us
- if (tid < NUM_THREADS) {
- WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);
+ if (tid == 0) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
- if (tid == 0) {
- args.ThisNewDataAvailableFlag[tid] = 0;
- args.ThisChunkDoneFlag[tid] = 0;
- *args.ThisPtrToNextOutput = nullptr;
- }
+ // Wait for last update from prev then reset the flag
+ waitReadyFromPrev.wait(NUM_SUBSTEPS*(step+1));
+ *ring.recvFlagFromPrev = 0;
+
+ incrementOpCounter(&args);
}
}
+#define THREADS 512
+#define UNROLL 8
+
template<class FUNC, typename T>
-ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
- void* recvbuff, const int recvcount, ncclComm* comm, cudaStream_t stream) {
- if (recvcount == 0) {
+ncclResult_t RingReduceScatter(const void* sendbuff, void* recvbuff,
+ const int count, ncclComm* comm, cudaStream_t stream) {
+ if (count == 0)
return ncclSuccess;
- }
- int index = comm->ncclId;
-
- int blockSizeInBytes = recvcount * sizeof(T);
- int misalignedBytes = blockSizeInBytes % alignof(uint64_t);
-
- assert((int)((misalignedBytes / sizeof(T)) * sizeof(T)) == misalignedBytes);
-
- int misalignedN = misalignedBytes / sizeof(T);
- assert(misalignedN < (int)(sizeof(uint64_t) / sizeof(T)));
-
- int paddingN = (misalignedN > 0) ? sizeof(uint64_t) / sizeof(T) : 0;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int bufferN = comm->buffSize / sizeof(T);
- // we only need buffer for k slices and k*k paddings (we need k paddings per
- // block and we have k blocks)
- int bufferNPerSlice = (bufferN - NUM_SUBCHUNKS * comm->nDev * paddingN) /
- (NUM_SUBCHUNKS * comm->nDev);
- int sliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
-
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- ReduceScatterKernelArgs<T> args;
-
- args.ThisId = index;
- args.NumGPUs = comm->nDev;
- args.N = recvcount;
- /* Block j must end up in recvbuff[j], which lives on device with logical
- * index comm->ringFromUser[j]. But the block ordering does not necessarily
- * follow the ring ordering. Hence the order in which a particular GPU
- * processes the different blocks (the correspondence between the step in
- * the reduction algorithm and the block on which a GPU operates in that
- * particular step) is not the same as the ring order.
- *
- * Say we have 4 GPUs and comm->userFromRing = { 1, 2, 0, 3 }. Then there are 4
- * step in the reduction algorithm and block 0 needs to end up device 2,
- * block 1 on device 0, block 2 on device 1, and block 3 needs to end up on
- * device 3. In the last step of the algorithm, each GPU must be processing
- * the block that will end up on that GPU. The blocks that a GPU has to
- * process in the previous steps is determined by the next step because each
- * GPU only hands off data to the next GPU in the ring.
- *
- * In the above example, we get the following table of which block is
- * processed by each GPU in a given step. The columns correspond to the
- * different GPUs while the rows are the steps in the algorithm.
- *
- * GPU 0 1 2 3
- * step
- * 0 3 1 2 0
- * 1 0 3 1 2
- * 2 2 0 3 1
- * 3 1 2 0 3
- *
- * We note the the rows in the above table are just comm->userFromRing in the last
- * step and the list is cyclicly permuted to the left for each previous
- * step. The columns, which are what the individual GPUs need to know, are
- * comm->userFromRing traversed backwards and starting at index k-1 for GPU k.
- * These columns are what we put into args.BlockVsStep to tell the GPU which
- * block it needs to be processing at a particular step. */
- args.UserFromRing = comm->devUserFromRing;
-
- args.SliceSize = sliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // don't reduce this if we cut the slice size in half below, because if that
- // happens, the last chunk will be larger than the other chunks, and we will
- // need the extra buffer space
- args.BufferSliceStride = args.SliceSize + paddingN;
-
- args.BufferMisalignedN = misalignedN;
-
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
-
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
- }
-
- args.ThisPtrToNextOutput = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisOutput = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
-
- args.ThisInput = (const T*)sendbuff;
- args.ThisOutput = (volatile T*)recvbuff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
-
- // we need 2 * NUM_SUBCHUNKS flags, so use the first NUM_SUBCHUNKS flags
- // to signal the next GPU that new data is available and the following
- // NUM_SUBCHUNKS to signal the previous GPU that a chunk is finished
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
-
- if (comm->nDev == 1) {
+ if (comm->nRanks == 1) {
if (sendbuff != recvbuff)
- CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, recvcount*sizeof(T), cudaMemcpyDeviceToDevice, stream));
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, count*sizeof(T), cudaMemcpyDeviceToDevice, stream));
} else {
- ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
+ KernelArgs<T> args;
+ ArgsSetup(&args, sendbuff, recvbuff, 0, count, comm);
+ LAUNCH_KERNEL(ReduceScatterKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
- return ncclSuccess;
-}
-template<typename T>
-ncclResult_t ncclReduceScatterWithType(const void* sendbuff, void* recvbuff,
- int recvcount, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
- switch (op) {
- case ncclSum:
- return ncclReduceScatterWithTypeAndFunc<FuncSum<T>, T>(
- sendbuff, recvbuff, recvcount, comm, stream);
- case ncclProd:
- return ncclReduceScatterWithTypeAndFunc<FuncProd<T>, T>(
- sendbuff, recvbuff, recvcount, comm, stream);
- case ncclMax:
- return ncclReduceScatterWithTypeAndFunc<FuncMax<T>, T>(
- sendbuff, recvbuff, recvcount, comm, stream);
- case ncclMin:
- return ncclReduceScatterWithTypeAndFunc<FuncMin<T>, T>(
- sendbuff, recvbuff, recvcount, comm, stream);
- }
- return ncclInvalidOperation;
+ return ncclSuccess;
}
-class ReduceScatterFunctor {
-public:
- ncclResult_t operator()(const void* sendbuff, void* recvbuff,
- int recvcount, ncclDataType_t datatype, ncclRedOp_t op, int /*root*/,
- ncclComm* comm, cudaStream_t stream) {
-
- switch (datatype) {
- case ncclChar:
- return ncclReduceScatterWithType<char>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
- case ncclInt:
- return ncclReduceScatterWithType<int>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
-#ifdef CUDA_HAS_HALF
- case ncclHalf:
- return ncclReduceScatterWithType<half>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
-#endif
- case ncclFloat:
- return ncclReduceScatterWithType<float>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
- case ncclDouble:
- return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
- case ncclInt64:
- return ncclReduceScatterWithType<long long>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
- case ncclUint64:
- return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff, recvcount,
- op, comm, stream);
- }
- return ncclInvalidType;
+template<typename T, template <typename> class RedOp>
+class ReduceScatter {
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int /*root*/, ncclComm* comm, cudaStream_t stream) {
+ return RingReduceScatter<RedOp<T>, T>(sendbuff, recvbuff, count, comm, stream);
}
};
-extern "C" DSOGLOBAL
-ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
- int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm,
- cudaStream_t stream) {
- return enqueue(ReduceScatterFunctor(), sendbuff, recvbuff, recvcount,
- datatype, op, 0, comm, stream);
+NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, int recvcount,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
+ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
+ return enqueue<ReduceScatter>(sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream);
}
+
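
For reference, the per-chunk schedule the new ReduceScatterKernel walks through is: one Copy of the block destined for the previous rank in the ring (it still has nranks-1 hops to travel), then k-2 Reduce-and-forward steps, then a final Reduce into this rank's own output block. The block indices come from ring.userRank; the host-side sketch below assumes an identity ring where userRank[i] = (rank + i) % nranks, which is a simplification of whatever ordering LoadRing actually provides.

#include <cstdio>

int main() {
  const int nranks = 4;
  for (int rank = 0; rank < nranks; ++rank) {
    // Assumed ring layout: userRank[0] is this rank, followed by the ring order.
    int userRank[8];
    for (int i = 0; i < nranks; ++i) userRank[i] = (rank + i) % nranks;

    std::printf("rank %d:", rank);
    std::printf(" copy block %d", userRank[nranks - 1]);                // step 0: push to next GPU
    for (int j = 2; j < nranks; ++j)                                    // k-2 reduce-and-forward steps
      std::printf(" -> reduce block %d", userRank[nranks - j]);
    std::printf(" -> reduce block %d into own output\n", userRank[0]);  // final step
  }
  return 0;
}
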