Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/collectives.h')
-rw-r--r--src/include/collectives.h73
1 files changed, 73 insertions, 0 deletions
diff --git a/src/include/collectives.h b/src/include/collectives.h
new file mode 100644
index 0000000..69c8e74
--- /dev/null
+++ b/src/include/collectives.h
@@ -0,0 +1,73 @@
+/*************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef NCCL_COLLECTIVES_H_
+#define NCCL_COLLECTIVES_H_
+
+#include "core.h"
+#include "info.h"
+
+#define FUNC_INDEX(coll, redop, dtype, al, pr) ((((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
+
+#define NCCL_COLL_NAME(coll, op, dtype) \
+ coll##_##op##_##dtype
+
+#define NCCL_KERN_NAME(coll, op, dtype) \
+ coll##Kernel_##op##_##dtype
+
+/* Declare all collective operations */
+#define DECL_COLL5(coll, op, dtype) \
+ extern __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args); \
+ extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl c); \
+
+#define DECL_COLL4(coll, op, dtype) \
+ DECL_COLL5(coll, op, dtype) \
+ DECL_COLL5(coll##LL, op, dtype) \
+ DECL_COLL5(coll##LL128, op, dtype)
+
+#define DECL_COLL3(coll, op, dtype) \
+ DECL_COLL4(coll##Ring, op, dtype) \
+ DECL_COLL4(coll##Tree, op, dtype)
+
+#define DECL_COLL2(coll, op) \
+ DECL_COLL3(coll, op, i8) \
+ DECL_COLL3(coll, op, u8) \
+ DECL_COLL3(coll, op, i32) \
+ DECL_COLL3(coll, op, u32) \
+ DECL_COLL3(coll, op, i64) \
+ DECL_COLL3(coll, op, u64) \
+ DECL_COLL3(coll, op, f16) \
+ DECL_COLL3(coll, op, f32) \
+ DECL_COLL3(coll, op, f64)
+
+#define DECL_COLL(coll) \
+ DECL_COLL2(coll, sum) \
+ DECL_COLL2(coll, prod) \
+ DECL_COLL2(coll, min) \
+ DECL_COLL2(coll, max)
+
+#define DECL_ALL_COLLS \
+ DECL_COLL2(ncclBroadcast, copy) \
+ DECL_COLL(ncclReduce) \
+ DECL_COLL2(ncclAllGather, copy) \
+ DECL_COLL(ncclReduceScatter) \
+ DECL_COLL(ncclAllReduce) \
+
+DECL_ALL_COLLS
+
+// CHUNKSIZE must be a multiple of SLICESIZE
+#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
+#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
+#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
+#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
+#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
+#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
+#define BROADCAST_SLICESTEPS 1
+#define BROADCAST_CHUNKSTEPS 1
+#define REDUCE_SLICESTEPS 1
+#define REDUCE_CHUNKSTEPS 1
+
+#endif