Diffstat (limited to 'src/collectives/device/common.h')
-rw-r--r-- | src/collectives/device/common.h | 204
1 file changed, 113 insertions(+), 91 deletions(-)
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index 46eb9f5..265218a 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
  *
  * See LICENSE.txt for license information
  ************************************************************************/
@@ -10,6 +10,15 @@
 #include "collectives.h"
 #include "devcomm.h"
 
+#if __CUDA_ARCH__ >= 800
+#define COLL_UNROLL 8
+#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1)  // Using balanced tree instead of split tree
+#else
+#define COLL_UNROLL 4
+#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY
+#endif
+
 // Exit If Abort Barrier across CTA: make sure all threads exit consistently
 // Each thread sets a predicate to true if abort == 1
 // all CTA's threads enter the barrier and do a popc on their predicates being True
@@ -19,12 +28,12 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
   asm ("{");
   asm volatile ("   .reg .pred barr_pred;");
   asm volatile ("   setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
-  asm volatile ("   bar.red.popc.u32 %0, 13, barr_pred;" : "=r"(popc));
+  asm volatile ("   bar.red.popc.u32 %0, 0, barr_pred;" : "=r"(popc));
   asm ("}");
   if (popc) { asm volatile ("exit;"); }
 }
 
-typedef void(*ncclKern_t)(struct CollectiveArgs* args);
+typedef void(*ncclKern_t)(struct ncclWorkElem* args);
 extern __device__ ncclKern_t ncclFuncs[];
 
 static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
@@ -32,130 +41,143 @@ static __device__ void load_parallel(void* dst, void* src, size_t size, int tid
   int* s = (int*)src;
   for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
 }
-static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid, struct ncclDevComm* comm) {
+static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork* hostWork, int tid, struct ncclDevComm* comm) {
+  __syncthreads();
+  load_parallel(localWork, hostWork, sizeof(struct ncclWork), tid);
   // Check whether the last operation was aborted and make sure all threads exit
   int abort = tid == 0 ? *(comm->abortFlag) : 0;
   exitIfAbortBarrier(abort);
-  load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid);
-  __syncthreads();
-  if (tid == 0) hostColl->active = 0;
+  if (tid == 0) hostWork->elems[0].active = 0;
 }
 
-extern __device__ volatile uint64_t* ncclShmem;
+template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction {
+  public:
+    __device__ void run(struct ncclWorkElem* args) {}
+};
+
+struct ncclShmemPtrs {
+  void* srcs[NCCL_MAX_DEV_ARITY+1];
+  void* dsts[NCCL_MAX_DEV_ARITY+1];
+};
+
+struct ncclShmemData {
+  union {
+    volatile uint64_t data[NCCL_LL128_SHMEM_SIZE];
+    struct ncclShmemPtrs ptrs[NCCL_MAX_GROUPS];
+  };
+  struct ncclWork localWork;
+};
 
-/* Functions for aggregation case */
-#define IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
-__device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
-  coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(args); \
+extern __device__ struct ncclShmemData *ncclShmem;
+
+template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL, int FINDEX>
+__device__ void ncclKernel(struct ncclWorkElem first) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  __shared__ struct ncclShmemData shmem;
+  ncclShmem = &shmem;
+
+  auto f = ncclFunction<FUNCTION, ALGO, PROTO, REDOP, T, UNROLL>();
+
+  struct ncclDevComm* comm = first.comm;
+  struct ncclChannel* channel = comm->channels+bid;
+  struct ncclWorkElem* w = NULL;
+  uint16_t index = first.index;
+
+  /* To optimize for latency, (only) the first operation is passed as argument.*/
+  if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) w = &first;
+
+  while (1) {
+    if (w == NULL) {
+      w = shmem.localWork.elems;
+      load_coll(&shmem.localWork, channel->workFifo+index, tid, comm);
+    }
+    if (tid < w->nThreads) {
+      if (w->funcIndex == FINDEX) {
+        f.run(w);
+      } else {
+        ncclFuncs[w->funcIndex](w);
+      }
+    }
+    index = (index+1) % NCCL_MAX_OPS;
+    if (w->active == 2) {
+      return;
+    }
+    w = NULL;
+  }
 }
 
+// Only generate kernels for SUM
 #if NCCL_OP == 0
-/* Kernels with the first operation inlined */
-#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
-__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
-  int tid = threadIdx.x; \
-  int bid = blockIdx.x; \
-  __shared__ volatile uint64_t shmem[NCCL_LL128_SHMEM_SIZE]; \
-  ncclShmem = shmem; \
-  __shared__ struct ncclColl localColl; \
- \
-  struct ncclDevComm* comm = firstColl.args.comm; \
-  struct ncclChannel* channel = comm->channels+bid; \
-  struct ncclColl* c; \
-  if (bid == 0) { \
-    /* To optimize for latency, (only) the first operation is passed as argument.*/ \
-    c = &firstColl; \
-  } else { \
-    c = &localColl; \
-    load_coll(c, channel->devCollectives+channel->collFifoHead, tid, comm); \
-  } \
-  while (1) { \
-    if (tid < c->args.nThreads) { \
-      if (c->funcIndex == fIndex) { \
-        coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
-      } else { \
-        ncclFuncs[c->funcIndex](&c->args); \
-      } \
-    } \
-    int nextIndex = c->nextIndex; \
-    if (tid == 0) channel->collFifoHead = nextIndex; \
- \
-    if (c->active == 2) { \
-      return; \
-    } \
- \
-    /* Load next collective operation*/ \
-    c = &localColl; /* for bid 0 */ \
-    load_coll(c, channel->devCollectives+nextIndex, tid, comm); \
-  } \
+#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \
+__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \
+  ncclKernel<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL, fIndex>(first); \
 }
 #else
-#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
+#define IMPL_COLL_KERN(func, algo, proto, redop, type, fInded)
 #endif
 
+// Examples : AllReduce, RING, LL, Sum, uint8
+#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \
+__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \
+  auto f = ncclFunction<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL>(); \
+  f.run(args); \
+}
+
 // Only generate inline kernels for LL
-#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
-  IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
-  IMPL_COLL_FUNC(coll##LL128, op, ncclFunc, dtype, ctype) \
-  IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
-  IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, al, NCCL_PROTO_LL)) \
+#define IMPL_COLL4(func, algo, redop, type, ncclType) \
+  IMPL_COLL_FUNC(func, algo, LL, redop, type) \
+  IMPL_COLL_FUNC(func, algo, LL128, redop, type) \
+  IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type) \
+  IMPL_COLL_KERN(func, algo, LL, redop, type, FUNC_INDEX(ncclFunc##func, nccl##redop, ncclType, NCCL_ALGO_##algo, NCCL_PROTO_LL)) \
 
-#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
-  IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \
-  IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING)
+#define IMPL_COLL3(func, redop, type, ncclType) \
+  IMPL_COLL4(func, TREE, redop, type, ncclType) \
+  IMPL_COLL4(func, RING, redop, type, ncclType) \
+  IMPL_COLL4(func, COLLNET, redop, type, ncclType)
 
 #if NCCL_TYPE == 0
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int8_t, ncclInt8)
 #elif NCCL_TYPE == 1
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, u8, uint8_t, ncclColl, ncclOp, ncclUint8)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint8_t, ncclUint8)
 #elif NCCL_TYPE == 2
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, i32, int32_t, ncclColl, ncclOp, ncclInt32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int32_t, ncclInt32)
 #elif NCCL_TYPE == 3
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, u32, uint32_t, ncclColl, ncclOp, ncclUint32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint32_t, ncclUint32)
 #elif NCCL_TYPE == 4
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, i64, int64_t, ncclColl, ncclOp, ncclInt64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int64_t, ncclInt64)
 #elif NCCL_TYPE == 5
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint64_t, ncclUint64)
 #elif NCCL_TYPE == 6
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, half, ncclFloat16)
 #elif NCCL_TYPE == 7
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, float, ncclFloat32)
 #elif NCCL_TYPE == 8
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
-  IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, double, ncclFloat64)
 #endif
 
 // Reduction define all functions
 #if NCCL_OP == 0
-#define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, sum, FuncSum, colln, ncclSum);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Sum);
 #elif NCCL_OP == 1
-#define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Prod);
 #elif NCCL_OP == 2
-#define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Min);
 #elif NCCL_OP == 3
-#define IMPL_COLL_R(collf, colln) \
-  IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Max);
 #endif
 
-// Copy primitives only define one
 #if NCCL_OP == 0 && NCCL_TYPE == 0
-#define IMPL_COLL_C(collf, colln) \
-  IMPL_COLL3(collf, copy, FuncSum, i8, int8_t, colln, ncclSum, ncclInt8);
+// Copy primitives only define one function for copy
+#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t, ncclInt8);
+
+// Point-to-point primitives only have one function/kernel.
+#define IMPL_COLL_P(func) \
+  IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \
+  IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, 0);
 #else
-#define IMPL_COLL_C(collf, colln)
+#define IMPL_COLL_C(func)
+#define IMPL_COLL_P(func)
 #endif
 
-#define COLL_UNROLL 4
-
 #endif
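A note on the exitIfAbortBarrier change above: bar.red.popc.u32 makes every thread of the CTA contribute a predicate at a barrier, population-counts the true predicates, and broadcasts the count to all threads; the patch moves the operation from hardware barrier slot 13 to slot 0. As a rough functional sketch (not NCCL's code; exitIfAbortBarrierSketch is a hypothetical name), the same all-or-none exit can be expressed with the standard __syncthreads_or() intrinsic:

// Sketch: equivalent behavior using the __syncthreads_or() intrinsic.
// Every thread contributes its abort predicate; the barrier OR-reduces
// them and broadcasts the result, so all threads of the CTA take the
// exit path together (or none do).
static inline __device__ void exitIfAbortBarrierSketch(int abort) {
  if (__syncthreads_or(abort)) {
    asm volatile ("exit;");   // terminates the calling thread
  }
}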
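To make the generator macros concrete, consider the combination named in the "// Examples" comment: AllReduce, RING, LL, Sum, uint8. Assuming NCCL_FUNC_NAME token-pastes its arguments into a unique symbol (its definition lives in collectives.h, so the exact name below is illustrative), IMPL_COLL_FUNC(AllReduce, RING, LL, Sum, uint8_t) expands along these lines:

// Hypothetical expansion, assuming NCCL_FUNC_NAME pastes its arguments
// into a symbol such as ncclFunction_AllReduce_RING_LL_Sum_uint8_t:
__device__ void ncclFunction_AllReduce_RING_LL_Sum_uint8_t(struct ncclWorkElem* args) {
  // Instantiate the template selected by the macro arguments and run it.
  // A specialization of ncclFunction for this combination is expected to
  // be provided by the per-collective headers (e.g. all_reduce.h).
  auto f = ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL,
                        FuncSum<uint8_t>, uint8_t, COLL_UNROLL>();
  f.run(args);
}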
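For context on how these macros are consumed: each collective's .cu file is compiled once per (NCCL_OP, NCCL_TYPE) pair and expands a single top-level macro to emit every algorithm/protocol variant for that combination. A minimal sketch of such a translation unit, with the include list assumed from the layout this header implies:

// Illustrative translation unit for AllReduce: one IMPL_COLL_R line emits
// the TREE/RING/COLLNET x LL/LL128/SIMPLE device functions for the
// reduction op selected by NCCL_OP (plus the inline LL kernel when
// NCCL_OP == 0, per the "Only generate kernels for SUM" guard above).
#include "all_reduce.h"
#include "common.h"
#include "collectives.h"

IMPL_COLL_R(AllReduce);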