
github.com/marian-nmt/nccl.git
author    Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-09-22 21:57:56 +0300
committer Sylvain Jeaugey <sjeaugey@nvidia.com>  2016-09-22 21:57:56 +0300
commit    cabd6848e4c07e73f6db2cf74e3db0c1b7191fa9 (patch)
tree      e06cb9e849628b14970ecd6291e9614e6d0c1711 /src/broadcast.cu
parent    e3dbc6110ebefdf5792de0c60fda1d81822d1454 (diff)
Heavy code refactoring to remove a lot of code in collectives (~1000 lines).
Have all collectives use the same args, the same ring, and the same primitives for synchronization between threads with the same pattern.
Diffstat (limited to 'src/broadcast.cu')
-rw-r--r--  src/broadcast.cu | 477
1 file changed, 125 insertions(+), 352 deletions(-)
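
Before the diff itself, a minimal host-side sketch (not part of the patch) of the double-buffered stepping pattern the refactored kernel uses: each slice advances a step counter and rotates the buffer offset through NUM_BUFCHUNKS chunks, which is what the NEXT_STEP macro added below does. The buffer and element counts here are made-up values for illustration.

#include <cstdio>

#define NUM_BUFCHUNKS 2   // same double-buffering factor as the patch

int main() {
  const int buffSize  = 1024;                      // elements in the comm buffer (illustrative)
  const int sliceSize = buffSize / NUM_BUFCHUNKS;  // one slice per buffer chunk
  const int size      = 3000;                      // total elements to broadcast (illustrative)

  int step = 0, boffset = 0;
  for (int offset = 0; offset < size; offset += sliceSize) {
    int maxOffset = size - offset;                 // last slice may be shorter than sliceSize
    printf("step %d: copy up to %d elems, buffer chunk at offset %d\n",
           step, maxOffset < sliceSize ? maxOffset : sliceSize, boffset);
    // NEXT_STEP: advance the flag counter and rotate the buffer offset
    step++;
    boffset += sliceSize;
    if (boffset == buffSize) boffset = 0;
  }
  return 0;
}
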
diff --git a/src/broadcast.cu b/src/broadcast.cu
index a57f4da..5843922 100644
--- a/src/broadcast.cu
+++ b/src/broadcast.cu
@@ -1,392 +1,165 @@
/*************************************************************************
* Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
- * See LICENCE.txt for license information
+ * See LICENSE.txt for license information
************************************************************************/
-#include <algorithm>
-
#include "core.h"
-#include "common_kernel.h"
-#include "copy_kernel.h"
#include "enqueue.h"
+#include "primitives.h"
-/* HIERARCHY
- *
- * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
- * SUBCHUNKS, where each SUBCHUNK is processed independently. A SUBCHUNK is
- * split into numUnroll UNROLLS and each thread performs UNROLL_COUNT
- * single-data-element operations inside an UNROLL. As the name suggests, the
- * UNROLL_COUNT operations within an UNROLL are unrolled.
-*/
-
-// Number of threads used to perform copies, etc. Must be multiple of 32.
-// An additional thread is used to handle threadfences, so the CUDA blocks
-// have dimension NUM_THREADS+1.
-#define NUM_THREADS 256
-
-// Each thread unrolls the innermost loop of the copy or reduction operations
-// to this many single-data-element instructions
-#define UNROLL_COUNT 8
-
-#define UNROLL_SIZE (UNROLL_COUNT * NUM_THREADS)
-
-// To hide the latency associated with the synchronization between different
-// subchunks, we interleave the independent subchunks so that more data can be
-// transferred while the sync is in progress. This is the number of subchunks
-// that are active at the same time
-#define NUM_SUBCHUNKS 4
-
-// if this is called with CHUNK, it means that we just finished pushing the data
-// of chunk CHUNK to the next GPU, so it can proceed with CHUNK
-// We add 1 to chunk so that the initial flag of 0 doesn't allow the non-root
-// GPUs to proceed before the flag is incremented from the upstream GPU. This
-// is called by one particular consumer warp and so we select the first thread
-// in the warp to set the flag.
-#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk) \
- do { \
- __threadfence_system(); \
- args.NextNewDataAvailableFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-#define WAIT_FOR_NEW_DATA(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// If this is called with CHUNK, it means that this GPU has just finished
-// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1
-#define SIGNAL_CHUNK_DONE(chunk, subchunk) \
- do { \
- args.PrevChunkDoneFlag[0] = NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- return ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1 - NUM_SUBCHUNKS; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-// This is called by all producer threads, but only thread 0 spins on the flag,
-// all threads synchronize after thread 0 is done spinning.
-#define WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, subchunk) \
- do { \
- if (tid == 0) { \
- Wait([=] { \
- bool newDataAvailable = \
- ((volatile int *)args.ThisNewDataAvailableFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk) + subchunk + 1; \
- bool chunkDone = \
- ((volatile int *)args.ThisChunkDoneFlag)[0] >= \
- NUM_SUBCHUNKS*(chunk)+subchunk + 1 - NUM_SUBCHUNKS; \
- return newDataAvailable && chunkDone; \
- }); \
- } \
- BAR(sync, 1, NUM_THREADS); \
- } while (0)
-
-__device__ inline void getSliceSizeAndOffset(int *size, int *offset, int slice,
- int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
- int smallSliceN, int lastSliceN) {
- if (slice < numBigSlices) {
- *size = bigSliceN;
- *offset = slice * bigSliceN;
- } else {
- *size = (slice < numBigSlices + numSmallSlices) ? smallSliceN
- : ((slice == numSlices - 1) ? lastSliceN : 0);
- *offset = numBigSlices * bigSliceN + (slice - numBigSlices) * smallSliceN;
- }
-
-// if (threadIdx.x == 0)
-// printf("[size=%d] [offset=%d] slice=%d numSlices=%d "
-// "numBigSlices=%d numSmallSlices=%d bigSliceN=%d smallSliceN=%d "
-// "lastSliceN=%d\n", *size, *offset, slice, numSlices, numBigSlices,
-// numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-}
-
-template<typename T>
-struct BroadcastKernelArgs {
- // general parameters
- int ThisId;
- int N;
+#define NUM_SUBSTEPS 2
+#define NUM_BUFCHUNKS 2
- // some pre-computed sizes
- int SliceSize;
- int ChunkSize;
- int NumChunks;
- int BufferSliceStride;
+// Increase Step and boffset for buffer sync
+#define NEXT_STEP \
+ step++; \
+ boffset += sliceSize; \
+ if (boffset == buffSize) boffset = 0;
- T ** ThisPtrToNextData;
- T ** PrevPtrToThisData;
+#define ALIGN_SIZE(size, align) \
+ size = ((size + (align) - 1) / (align)) * (align);
- // local and remote data
- T * __restrict__ ThisData;
- volatile T * __restrict__ ThisBuffer;
- volatile T * __restrict__ NextBuffer;
+template<int THREADS, int UNROLL, class FUNC, typename T>
+__launch_bounds__(THREADS+WARP_SIZE, 1)
+__global__ void BroadcastKernel(const KernelArgs<T> args) {
+ const int tid = threadIdx.x;
+ __shared__ T* sharedNextOutput;
+ __shared__ DevRing<T> ring;
+ bool pushrecv = args.pushrecv;
- // local and remote flags
- volatile int * __restrict__ ThisNewDataAvailableFlag;
- volatile int * __restrict__ NextNewDataAvailableFlag;
- volatile int * __restrict__ ThisChunkDoneFlag;
- volatile int * __restrict__ PrevChunkDoneFlag;
-};
-
-__shared__ volatile void * nextData;
-enum BcastRole {ROOT=0, MIDDLE=1, END=2};
-
-template<int THREADS, int UNROLL, bool PUSHRECV, int ROLE, typename T>
-__global__ void BroadcastKernel(const BroadcastKernelArgs<T> args) {
- if (args.N == 0) return;
- int tid = threadIdx.x;
+ LoadRing<THREADS>(args.ring, &ring);
+ __syncthreads();
- // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
- // the previous GPU is done with a previous collective operation.
if (tid == 0) {
- Wait([=] {
- return *((T * volatile *)args.PrevPtrToThisData) == nullptr; // Wait for previous processor to be done
- });
-
- *((T * volatile *)args.PrevPtrToThisData) = (T*)args.ThisData; // Tell Previous I'm starting
- Wait([=] {
- return *((T * volatile *)args.ThisPtrToNextData) != nullptr; // Wait till I've been told next started
- });
-
- if (PUSHRECV)
- nextData = *((volatile void * volatile *)args.ThisPtrToNextData); // Grab next's pointer if needed.
+ WaitFlag prevCommOp(ring.prevOpCounter, 0);
+ WaitFlag nextCommOp(ring.nextOpCounter, 0);
+ prevCommOp.wait(args.opIndex);
+ nextCommOp.wait(args.opIndex);
+ if (pushrecv) {
+ *ring.sendPtrToPrev = (T*)args.ThisOutput;
+ Wait([=] {
+ return *ring.recvPtrFromNext != nullptr;
+ });
+ sharedNextOutput = *ring.recvPtrFromNext;
+ *ring.recvPtrFromNext = nullptr;
+ }
}
__syncthreads();
- for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
- // calculate slice size. for all chunks except (possibly) the last one,
- // this will just be args.SliceSize. For the last one, it may be smaller
- int bigSliceN = args.SliceSize;
- int smallSliceN = 0;
- int lastSliceN = 0;
- int numSlices = NUM_SUBCHUNKS;
- int numBigSlices = numSlices;
- int numSmallSlices = 0;
-
- // last chunk
- if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
- CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
- &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
- args.ChunkSize);
-
- // this offset is only applied to Data pointers, not to Buffer pointers,
- // since we only have one buffer per chunk
- int chunkOffset = chunk * args.ChunkSize;
-
- int offset;
- int sliceSize;
-
- if (tid < THREADS) {
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- getSliceSizeAndOffset(&sliceSize, &offset, s, numSlices,
- numBigSlices, numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
-
- if (PUSHRECV) {
- if (ROLE != ROOT)
- WAIT_FOR_NEW_DATA(chunk, s);
-
- if (ROLE != END)
- Copy<UNROLL, THREADS>(
- (volatile T *)nextData + chunkOffset + offset,
- args.ThisData + chunkOffset + offset,
- sliceSize);
- } else { // PUSH2BUFF
- if (ROLE == ROOT) {
- WAIT_FOR_CHUNK(chunk, s);
-
- Copy<UNROLL, THREADS>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- sliceSize);
- } else if (ROLE == MIDDLE) {
- WAIT_FOR_NEW_DATA_AND_CHUNK(chunk, s);
-
- DoubleCopy<UNROLL, THREADS>(
- args.NextBuffer + (s * args.BufferSliceStride),
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- } else { // ROLE == END
- WAIT_FOR_NEW_DATA(chunk, s);
-
- Copy<UNROLL, THREADS>(
- args.ThisData + chunkOffset + offset,
- args.ThisBuffer + (s * args.BufferSliceStride),
- sliceSize);
- }
- }
- __syncthreads();
- }
- } else { // Consumer thread
- for(int s=0; s<NUM_SUBCHUNKS; ++s) {
- __syncthreads();
- if (ROLE != END)
- SIGNAL_NEW_DATA_AVAILABLE(chunk, s);
-
- // signal chunk done if we don't push into the receive buffer and this
- // is no the last chunk and this is not root
- if ((!PUSHRECV) && (ROLE != ROOT) && (chunk + 1 < args.NumChunks)) {
- SIGNAL_CHUNK_DONE(chunk, s);
- }
+ WaitFlag waitDoneFromNext(ring.recvFlagFromNext, (1-NUM_BUFCHUNKS)*NUM_SUBSTEPS);
+ WaitFlag waitReadyFromPrev(ring.recvFlagFromPrev, 0);
+ PostFlag postDoneToPrev(ring.sendFlagToPrev, 0);
+ PostFlag postReadyToNext(ring.sendFlagToNext, 0);
+
+ typedef Primitives<THREADS, UNROLL, NUM_SUBSTEPS, T> Prims;
+
+ const int size = args.N;
+ const int rank = ring.userRank[0];
+ const int nextRank = ring.userRank[1];
+ const int root = args.root;
+ const int buffSize = args.buffSize / sizeof(T);
+ const int sliceSize = buffSize / NUM_BUFCHUNKS;
+
+ int step = 0;
+ int boffset = 0;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = args.ThisInput;
+ T * __restrict__ thisOutput = args.ThisOutput;
+ T * __restrict__ prevInput = ring.recvBuffer;
+ T * __restrict__ nextOutput = ring.sendBuffer;
+
+ for (int offset = 0; offset < size; offset += sliceSize) {
+ int maxOffset = size-offset;
+ if (rank == root) {
+ Prims::Copy(
+ thisInput + offset,
+ pushrecv ? sharedNextOutput + offset : nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext,
+ postReadyToNext);
+ } else if (nextRank == root) {
+ if (pushrecv) maxOffset = 0; // Only wait for signals
+ Prims::Copy(
+ prevInput + boffset,
+ thisOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitReadyFromPrev,
+ postDoneToPrev);
+ } else {
+ if (pushrecv) {
+ Prims::Copy(
+ thisOutput + offset,
+ sharedNextOutput + offset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
+ } else {
+ Prims::DoubleCopy(
+ prevInput + boffset,
+ thisOutput + offset,
+ nextOutput + boffset,
+ sliceSize, maxOffset,
+ step,
+ waitDoneFromNext, waitReadyFromPrev,
+ postReadyToNext, postDoneToPrev);
}
}
+ NEXT_STEP; // Increases step, boffset
}
- // reset flags
+ // wait for the last data to be pushed to us
if (tid == 0) {
- args.ThisNewDataAvailableFlag[0] = 0;
- args.ThisChunkDoneFlag[0] = 0;
- *args.ThisPtrToNextData = nullptr;
- }
-}
-
-template<typename T>
-ncclResult_t ncclBcastWithType(void* buff, const int count, const int root,
- ncclComm* comm, int numUnroll, cudaStream_t stream) {
- if (count == 0)
- return ncclSuccess;
-
- int index = comm->ncclId;
- int rootId = comm->ringFromUser[root];
-
- int nextId = (index + 1) % comm->nDev;
- int prevId = (index + comm->nDev - 1) % comm->nDev;
-
- // There is one slice per GPU, so a slice can be at most bufferN / numGPUs,
- // where bufferN is the number of elements of type T that fit into the buffer.
- // For efficiency, we want the slice size to be a multiple of UNROLL_SIZE
- int bufferN = comm->buffSize / sizeof(T);
- // we only need buffer for k slices and k paddings
- int bufferNPerSlice = bufferN / NUM_SUBCHUNKS;
- int maxSliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;
-
- BroadcastKernelArgs<T> args;
-
- args.ThisId = index;
- args.N = count;
+ if (nextRank != root) {
+ // Wait for last update from next then reset the flag
+ waitDoneFromNext.wait(NUM_SUBSTEPS*(step+NUM_BUFCHUNKS-1));
+ *ring.recvFlagFromNext = 0;
+ }
- args.SliceSize = numUnroll * UNROLL_SIZE * sizeof(PackType) / sizeof(T);
+ if (rank != root) {
+ // reset the flag
+ *ring.recvFlagFromPrev = 0;
+ }
- // if we don't directly push into the remote receive buffer, make sure slice
- // fits into the temporary buffer
- if (!comm->useRemoteRecv) {
- // Larger transfers help QPI more than tag updates hurt P2P.
- args.SliceSize *= 8;
+ incrementOpCounter(&args);
}
+}
- args.SliceSize = std::min(maxSliceSize, args.SliceSize);
- args.BufferSliceStride = args.SliceSize;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
+#define THREADS 256
+#define UNROLL 8
- // avoid a case where we have one or more big chunks and one tiny one
- int remainder = args.N % args.ChunkSize;
- if ((args.N > args.ChunkSize) && (remainder > 0) &&
- (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
- args.SliceSize /= 2;
- args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;
+template<class FUNC, typename T>
+ncclResult_t RingBroadcast(void* buff, const int count, const int root,
+ ncclComm* comm, cudaStream_t stream) {
+ if (count == 0)
+ return ncclSuccess;
- // round down so we end up with a big last chunk
- args.NumChunks = args.N / args.ChunkSize;
- } else {
- // round up
- args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
+ if (comm->nRanks != 1) {
+ KernelArgs<T> args;
+ ArgsSetup(&args, buff, buff, root, count, comm);
+ LAUNCH_KERNEL(BroadcastKernel, THREADS, UNROLL, FUNC, T, args, stream);
}
-// printf("sliceSize = %i, chunkSize = %i, numChunks = %i\n", args.SliceSize, args.ChunkSize, args.NumChunks);
-
- args.ThisPtrToNextData = (T**)&(comm->ptrs[nextId].local->recvPtrs[0]);
- args.PrevPtrToThisData = (T**)&(comm->ptrs[prevId].remote->recvPtrs[0]);
-
- args.ThisData = (T*)buff;
- args.ThisBuffer = (volatile T*)comm->ptrs[prevId].local->buff;
- args.NextBuffer = (volatile T*)comm->ptrs[nextId].remote->buff;
-
- // we need 2 * NUM_SUBCHUNKS flags, so use the first NUM_SUBCHUNKS flags
- // to signal the next GPU that new data is available and the following
- // NUM_SUBCHUNKS to signal the previous GPU that a chunk is finished
- args.ThisNewDataAvailableFlag = comm->ptrs[prevId].local->flags;
- args.NextNewDataAvailableFlag = comm->ptrs[nextId].remote->flags;
- args.ThisChunkDoneFlag = comm->ptrs[nextId].local->flags + 1;
- args.PrevChunkDoneFlag = comm->ptrs[prevId].remote->flags + 1;
-
- if (comm->nDev != 1) {
- if (comm->useRemoteRecv) {
- if (index == (rootId + comm->nDev - 1) % comm->nDev) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, END, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else if (index == rootId) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, ROOT, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, true, MIDDLE, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
- } else {
- if (index == (rootId + comm->nDev - 1) % comm->nDev) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, END, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else if (index == rootId) {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, ROOT, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- } else {
- BroadcastKernel<NUM_THREADS, UNROLL_COUNT, false, MIDDLE, T>
- <<<1, NUM_THREADS + 1, 0, stream>>>(args);
- }
- }
- }
return ncclSuccess;
}
-class BroadcastFunctor {
-public:
- ncclResult_t operator()(const void* /*dummy sendbuff*/,
- void* buff, int count, ncclDataType_t datatype, ncclRedOp_t /*dummy operation*/,
- int root, ncclComm* comm, cudaStream_t stream) {
- int numUnroll = 4;
-
- switch (datatype) {
- case ncclChar:
- return ncclBcastWithType<char>(buff, count, root, comm, numUnroll, stream);
- case ncclInt:
- return ncclBcastWithType<int>(buff, count, root, comm, numUnroll, stream);
-#ifdef CUDA_HAS_HALF
- case ncclHalf:
- return ncclBcastWithType<half>(buff, count, root, comm, numUnroll, stream);
-#endif
- case ncclFloat:
- return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
- case ncclDouble:
- return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
- case ncclInt64:
- return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
- case ncclUint64:
- return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
- }
- return ncclInvalidType;
+template<typename T, template<typename> class RedOp>
+class Broadcast {
+ public:
+ static ncclResult_t entry(const void* sendbuff, void* recvbuff,
+ int count, int root, ncclComm* comm, cudaStream_t stream) {
+ return RingBroadcast<RedOp<T>, T>(recvbuff, count, root, comm, stream);
}
};
-extern "C" DSOGLOBAL
+NCCL_API(ncclResult_t, ncclBcast, void* buff, int count, ncclDataType_t datatype, int root,
+ ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBcast(void* buff, int count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
- return enqueue(BroadcastFunctor(), nullptr, buff, count, datatype, ncclSum,
- root, comm, stream);
+ return enqueue<Broadcast, FuncNull>(nullptr, buff, count, datatype, root, comm, stream);
}