Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/collectives/device')
-rw-r--r--src/collectives/device/all_gather.h9
-rw-r--r--src/collectives/device/all_reduce.h118
-rw-r--r--src/collectives/device/broadcast.h9
-rw-r--r--src/collectives/device/common.h3
-rw-r--r--src/collectives/device/functions.cu3
-rw-r--r--src/collectives/device/primitives.h4
-rw-r--r--src/collectives/device/reduce.h9
-rw-r--r--src/collectives/device/reduce_scatter.h9
8 files changed, 158 insertions, 6 deletions
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index 0ad5ba9..059092c 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -69,6 +69,9 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllGatherTreeKernel(struct CollectiveArgs* args) { }
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclAllGatherCollNetKernel(struct CollectiveArgs* args) { }
+
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -130,6 +133,9 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { }
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllGatherCollNetLLKernel(struct CollectiveArgs* args) { }
+
#include "prims_ll128.h"
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
@@ -193,3 +199,6 @@ __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherTreeLL128Kernel(struct CollectiveArgs* args) { }
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllGatherCollNetLL128Kernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index 2449c2b..173b5fa 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -106,7 +106,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -124,7 +124,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -140,6 +140,62 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
} while(0);
}
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->bid;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ const ssize_t size = args->N;
+ const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
+ int chunkSize = args->lastChunkSize;
+ const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ }
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+ struct ncclTree* tree = &channel->collTreeUp;
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ }
+
+ if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+ struct ncclTree* tree = &channel->collTreeDn;
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.recv(thisOutput+offset, nelem);
+ } else {
+ prims.recvCopySend(thisOutput+offset, nelem);
+ }
+ }
+ }
+}
+
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -271,6 +327,61 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
} while(0);
}
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->bid;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ const ssize_t size = args->N;
+ ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ }
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ if (blockIdx.x < args->nChannels) { // first half of the channels do reduce
+ struct ncclTree* tree = &channel->collTreeUp;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ }
+
+ if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast
+ struct ncclTree* tree = &channel->collTreeDn;
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+ }
+ }
+}
+
#include "prims_ll128.h"
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
@@ -408,3 +519,6 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
}
}
}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllReduceCollNetLL128Kernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h
index de8b989..5146682 100644
--- a/src/collectives/device/broadcast.h
+++ b/src/collectives/device/broadcast.h
@@ -54,6 +54,9 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclBroadcastTreeKernel(struct CollectiveArgs* args) { }
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclBroadcastCollNetKernel(struct CollectiveArgs* args) { }
+
template<int UNUSED, class FUNC, typename T>
__device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -101,6 +104,9 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { }
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclBroadcastCollNetLLKernel(struct CollectiveArgs* args) { }
+
#include "prims_ll128.h"
template<int UNUSED, class FUNC, typename T>
__device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
@@ -148,3 +154,6 @@ __device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclBroadcastTreeLL128Kernel(struct CollectiveArgs* args) { }
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclBroadcastCollNetLL128Kernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index 46eb9f5..6e06369 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -102,7 +102,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \
- IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING)
+ IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING) \
+ IMPL_COLL4(coll##CollNet, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_COLLNET)
#if NCCL_TYPE == 0
#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu
index 034fe96..d10f11e 100644
--- a/src/collectives/device/functions.cu
+++ b/src/collectives/device/functions.cu
@@ -17,7 +17,8 @@ __device__ volatile uint64_t* ncclShmem;
#define NCCL_FUNC4(coll, op, dtype) \
NCCL_FUNC5(coll##Tree, op, dtype), \
- NCCL_FUNC5(coll##Ring, op, dtype)
+ NCCL_FUNC5(coll##Ring, op, dtype), \
+ NCCL_FUNC5(coll##CollNet, op, dtype)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(coll, op) \
diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h
index b624359..c1067bf 100644
--- a/src/collectives/device/primitives.h
+++ b/src/collectives/device/primitives.h
@@ -227,7 +227,7 @@ class ncclPrimitives {
recvStep[i] = conn->step;
recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
recvDirectBuff[i] = NULL;
- if (directBuff && conn->direct) {
+ if (directBuff && (conn->direct & NCCL_DIRECT_GPU)) {
recvDirectBuff[i] = directBuff;
if (tid == 0) *conn->ptrExchange = directBuff;
}
@@ -254,7 +254,7 @@ class ncclPrimitives {
sendStep[i] = conn->step;
sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
sendDirectBuff[i] = NULL;
- if (directBuff && conn->direct) {
+ if (directBuff && (conn->direct & NCCL_DIRECT_GPU)) {
void* volatile* ptr = conn->ptrExchange;
while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
barrier();
diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h
index 0680abe..e36613f 100644
--- a/src/collectives/device/reduce.h
+++ b/src/collectives/device/reduce.h
@@ -50,6 +50,9 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclReduceTreeKernel(struct CollectiveArgs* args) { }
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclReduceCollNetKernel(struct CollectiveArgs* args) { }
+
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -94,6 +97,9 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { }
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceCollNetLLKernel(struct CollectiveArgs* args) { }
+
#include "prims_ll128.h"
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
@@ -138,3 +144,6 @@ __device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceTreeLL128Kernel(struct CollectiveArgs* args) { }
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceCollNetLL128Kernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h
index 1985148..0b0ae81 100644
--- a/src/collectives/device/reduce_scatter.h
+++ b/src/collectives/device/reduce_scatter.h
@@ -64,6 +64,9 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclReduceScatterTreeKernel(struct CollectiveArgs* args) { }
+template<int UNROLL, class FUNC, typename T>
+__device__ void ncclReduceScatterCollNetKernel(struct CollectiveArgs* args) { }
+
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -122,6 +125,9 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { }
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceScatterCollNetLLKernel(struct CollectiveArgs* args) { }
+
#include "prims_ll128.h"
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
@@ -182,3 +188,6 @@ __device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceScatterTreeLL128Kernel(struct CollectiveArgs* args) { }
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceScatterCollNetLL128Kernel(struct CollectiveArgs* args) { }