Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/collectives.h')
-rw-r--r--src/include/collectives.h81
1 files changed, 42 insertions, 39 deletions
diff --git a/src/include/collectives.h b/src/include/collectives.h
index f854364..9b9022e 100644
--- a/src/include/collectives.h
+++ b/src/include/collectives.h
@@ -8,55 +8,58 @@
#define NCCL_COLLECTIVES_H_
#define FUNC_INDEX_P2P 0
-#define FUNC_INDEX(coll, redop, dtype, al, pr) (1+(((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
+#define FUNC_INDEX(func, redop, ncclType, al, pr) (1+(((((func)*ncclNumOps + (redop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
-#define NCCL_COLL_NAME(coll, op, dtype) \
- coll##_##op##_##dtype
+#define NCCL_FUNC_NAME(func, algo, proto, redop, type) \
+ ncclFunction_##func##_##algo##_##proto##_##redop##_##type
-#define NCCL_KERN_NAME(coll, op, dtype) \
- coll##Kernel_##op##_##dtype
+#define NCCL_KERN_NAME(func, algo, proto, redop, type) \
+ ncclKernel_##func##_##algo##_##proto##_##redop##_##type
+
+#define NCCL_IMPL_NAME(func, algo, proto) \
+ nccl##func##algo##proto
/* Declare all collective operations */
-#define DECL_COLL5(coll, op, dtype) \
- extern __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args); \
- extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl c); \
+#define DECL5(func, algo, proto, redop, type) \
+ extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
+ extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem c); \
-#define DECL_COLL4(coll, op, dtype) \
- DECL_COLL5(coll, op, dtype) \
- DECL_COLL5(coll##LL, op, dtype) \
- DECL_COLL5(coll##LL128, op, dtype)
+#define DECL4(func, algo, redop, type) \
+ DECL5(func, algo, SIMPLE, redop, type) \
+ DECL5(func, algo, LL, redop, type) \
+ DECL5(func, algo, LL128, redop, type)
-#define DECL_COLL3(coll, op, dtype) \
- DECL_COLL4(coll##Ring, op, dtype) \
- DECL_COLL4(coll##Tree, op, dtype) \
- DECL_COLL4(coll##CollNet, op, dtype)
+#define DECL3(func, redop, type) \
+ DECL4(func, RING, redop, type) \
+ DECL4(func, TREE, redop, type) \
+ DECL4(func, COLLNET, redop, type)
-#define DECL_COLL2(coll, op) \
- DECL_COLL3(coll, op, i8) \
- DECL_COLL3(coll, op, u8) \
- DECL_COLL3(coll, op, i32) \
- DECL_COLL3(coll, op, u32) \
- DECL_COLL3(coll, op, i64) \
- DECL_COLL3(coll, op, u64) \
- DECL_COLL3(coll, op, f16) \
- DECL_COLL3(coll, op, f32) \
- DECL_COLL3(coll, op, f64)
+#define DECL2(func, redop) \
+ DECL3(func, redop, int8_t) \
+ DECL3(func, redop, uint8_t) \
+ DECL3(func, redop, int32_t) \
+ DECL3(func, redop, uint32_t) \
+ DECL3(func, redop, int64_t) \
+ DECL3(func, redop, uint64_t) \
+ DECL3(func, redop, half) \
+ DECL3(func, redop, float) \
+ DECL3(func, redop, double)
-#define DECL_COLL(coll) \
- DECL_COLL2(coll, sum) \
- DECL_COLL2(coll, prod) \
- DECL_COLL2(coll, min) \
- DECL_COLL2(coll, max)
+#define DECL(func) \
+ DECL2(func, Sum) \
+ DECL2(func, Prod) \
+ DECL2(func, Min) \
+ DECL2(func, Max)
-#define DECL_ALL_COLLS \
- DECL_COLL2(ncclBroadcast, copy) \
- DECL_COLL(ncclReduce) \
- DECL_COLL2(ncclAllGather, copy) \
- DECL_COLL(ncclReduceScatter) \
- DECL_COLL(ncclAllReduce) \
- DECL_COLL5(ncclSendRecv,copy,i8) \
+#define DECL_ALL \
+ DECL2(Broadcast, Sum) \
+ DECL(Reduce) \
+ DECL2(AllGather, Sum) \
+ DECL(ReduceScatter) \
+ DECL(AllReduce) \
+ DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) \
-DECL_ALL_COLLS
+DECL_ALL
// CHUNKSIZE must be a multiple of SLICESIZE
#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)