diff options
Diffstat (limited to 'src/include/collectives.h')
-rw-r--r-- | src/include/collectives.h | 81 |
1 files changed, 42 insertions, 39 deletions
diff --git a/src/include/collectives.h b/src/include/collectives.h index f854364..9b9022e 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -8,55 +8,58 @@ #define NCCL_COLLECTIVES_H_ #define FUNC_INDEX_P2P 0 -#define FUNC_INDEX(coll, redop, dtype, al, pr) (1+(((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) +#define FUNC_INDEX(func, redop, ncclType, al, pr) (1+(((((func)*ncclNumOps + (redop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) -#define NCCL_COLL_NAME(coll, op, dtype) \ - coll##_##op##_##dtype +#define NCCL_FUNC_NAME(func, algo, proto, redop, type) \ + ncclFunction_##func##_##algo##_##proto##_##redop##_##type -#define NCCL_KERN_NAME(coll, op, dtype) \ - coll##Kernel_##op##_##dtype +#define NCCL_KERN_NAME(func, algo, proto, redop, type) \ + ncclKernel_##func##_##algo##_##proto##_##redop##_##type + +#define NCCL_IMPL_NAME(func, algo, proto) \ + nccl##func##algo##proto /* Declare all collective operations */ -#define DECL_COLL5(coll, op, dtype) \ - extern __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args); \ - extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl c); \ +#define DECL5(func, algo, proto, redop, type) \ + extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \ + extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem c); \ -#define DECL_COLL4(coll, op, dtype) \ - DECL_COLL5(coll, op, dtype) \ - DECL_COLL5(coll##LL, op, dtype) \ - DECL_COLL5(coll##LL128, op, dtype) +#define DECL4(func, algo, redop, type) \ + DECL5(func, algo, SIMPLE, redop, type) \ + DECL5(func, algo, LL, redop, type) \ + DECL5(func, algo, LL128, redop, type) -#define DECL_COLL3(coll, op, dtype) \ - DECL_COLL4(coll##Ring, op, dtype) \ - DECL_COLL4(coll##Tree, op, dtype) \ - DECL_COLL4(coll##CollNet, op, dtype) +#define DECL3(func, redop, type) \ + DECL4(func, RING, redop, type) \ + DECL4(func, TREE, redop, type) \ + DECL4(func, COLLNET, redop, type) -#define DECL_COLL2(coll, op) \ - DECL_COLL3(coll, op, i8) \ - DECL_COLL3(coll, op, u8) \ - DECL_COLL3(coll, op, i32) \ - DECL_COLL3(coll, op, u32) \ - DECL_COLL3(coll, op, i64) \ - DECL_COLL3(coll, op, u64) \ - DECL_COLL3(coll, op, f16) \ - DECL_COLL3(coll, op, f32) \ - DECL_COLL3(coll, op, f64) +#define DECL2(func, redop) \ + DECL3(func, redop, int8_t) \ + DECL3(func, redop, uint8_t) \ + DECL3(func, redop, int32_t) \ + DECL3(func, redop, uint32_t) \ + DECL3(func, redop, int64_t) \ + DECL3(func, redop, uint64_t) \ + DECL3(func, redop, half) \ + DECL3(func, redop, float) \ + DECL3(func, redop, double) -#define DECL_COLL(coll) \ - DECL_COLL2(coll, sum) \ - DECL_COLL2(coll, prod) \ - DECL_COLL2(coll, min) \ - DECL_COLL2(coll, max) +#define DECL(func) \ + DECL2(func, Sum) \ + DECL2(func, Prod) \ + DECL2(func, Min) \ + DECL2(func, Max) -#define DECL_ALL_COLLS \ - DECL_COLL2(ncclBroadcast, copy) \ - DECL_COLL(ncclReduce) \ - DECL_COLL2(ncclAllGather, copy) \ - DECL_COLL(ncclReduceScatter) \ - DECL_COLL(ncclAllReduce) \ - DECL_COLL5(ncclSendRecv,copy,i8) \ +#define DECL_ALL \ + DECL2(Broadcast, Sum) \ + DECL(Reduce) \ + DECL2(AllGather, Sum) \ + DECL(ReduceScatter) \ + DECL(AllReduce) \ + DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) \ -DECL_ALL_COLLS +DECL_ALL // CHUNKSIZE must be a multiple of SLICESIZE #define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4) |