diff options
Diffstat (limited to 'src/collectives/device')
-rw-r--r-- | src/collectives/device/all_gather.h | 9
-rw-r--r-- | src/collectives/device/all_reduce.h | 118
-rw-r--r-- | src/collectives/device/broadcast.h | 9
-rw-r--r-- | src/collectives/device/common.h | 3
-rw-r--r-- | src/collectives/device/functions.cu | 3
-rw-r--r-- | src/collectives/device/primitives.h | 4
-rw-r--r-- | src/collectives/device/reduce.h | 9
-rw-r--r-- | src/collectives/device/reduce_scatter.h | 9
8 files changed, 158 insertions(+), 6 deletions(-)
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 0ad5ba9..059092c 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -69,6 +69,9 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) { template<int UNROLL, class FUNC, typename T> __device__ void ncclAllGatherTreeKernel(struct CollectiveArgs* args) { } +template<int UNROLL, class FUNC, typename T> +__device__ void ncclAllGatherCollNetKernel(struct CollectiveArgs* args) { } + template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -130,6 +133,9 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { } +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllGatherCollNetLLKernel(struct CollectiveArgs* args) { } + #include "prims_ll128.h" template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) { @@ -193,3 +199,6 @@ __device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclAllGatherTreeLL128Kernel(struct CollectiveArgs* args) { } + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllGatherCollNetLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 2449c2b..173b5fa 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -106,7 +106,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 
FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -124,7 +124,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize; @@ -140,6 +140,62 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { } while(0); } +template<int UNROLL, class FUNC, typename T> +__device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads-WARP_SIZE; + const int bid = args->bid; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + const ssize_t size = args->N; + const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS); + int chunkSize = args->lastChunkSize; + const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nChannels*chunkSize; + + if (loopSize > size) { + chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize; + } + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = 
(T*)args->ThisOutput; + + if (blockIdx.x < args->nChannels) { // first half of the channels do reduce + struct ncclTree* tree = &channel->collTreeUp; + ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Up + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (tree->up == -1) { + prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); + } else if (tree->down[0] == -1) { + prims.send(thisInput+offset, nelem); + } else { + prims.recvReduceSend(thisInput+offset, nelem); + } + } + } + + if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast + struct ncclTree* tree = &channel->collTreeDn; + ncclPrimitives<UNROLL, 1, 1, T, 1, 1, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Down + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (tree->up == -1) { + prims.send(thisOutput+offset, nelem); + } else if (tree->down[0] == -1) { + prims.recv(thisOutput+offset, nelem); + } else { + prims.recvCopySend(thisOutput+offset, nelem); + } + } + } +} + template<int UNUSED, class FUNC, typename T> __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -271,6 +327,61 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { } while(0); } +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = args->nThreads; + const int bid = args->bid; + struct ncclDevComm* comm = args->comm; + struct ncclChannel* channel = comm->channels+blockIdx.x; + const ssize_t size = args->N; + ssize_t 
chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nChannels*chunkSize; + + if (loopSize > size) { + chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize; + } + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + + if (blockIdx.x < args->nChannels) { // first half of the channels do reduce + struct ncclTree* tree = &channel->collTreeUp; + ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Up + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (tree->up == -1) { + LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem); + } else if (tree->down[0] == -1) { + LLprims.send(thisInput+offset, nelem); + } else { + LLprims.recvReduceSend(thisInput+offset, nelem); + } + } + } + + if (blockIdx.x >= args->nChannels) { // second half of the channels do broadcast + struct ncclTree* tree = &channel->collTreeDn; + ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + // Down + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + if (tree->up == -1) { + LLprims.send(thisOutput+offset, nelem); + } else if (tree->down[0] == -1) { + LLprims.recv(thisOutput+offset, nelem); + } else { + LLprims.recvCopySend(thisOutput+offset, nelem); + } + } + } +} + #include "prims_ll128.h" template<int UNUSED, class FUNC, typename T> __device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) { @@ -408,3 +519,6 @@ __device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) { } } } + +template<int 
UNUSED, class FUNC, typename T> +__device__ void ncclAllReduceCollNetLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index de8b989..5146682 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -54,6 +54,9 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) { template<int UNROLL, class FUNC, typename T> __device__ void ncclBroadcastTreeKernel(struct CollectiveArgs* args) { } +template<int UNROLL, class FUNC, typename T> +__device__ void ncclBroadcastCollNetKernel(struct CollectiveArgs* args) { } + template<int UNUSED, class FUNC, typename T> __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -101,6 +104,9 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { } +template<int UNUSED, class FUNC, typename T> +__device__ void ncclBroadcastCollNetLLKernel(struct CollectiveArgs* args) { } + #include "prims_ll128.h" template<int UNUSED, class FUNC, typename T> __device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) { @@ -148,3 +154,6 @@ __device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclBroadcastTreeLL128Kernel(struct CollectiveArgs* args) { } + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclBroadcastCollNetLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 46eb9f5..6e06369 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -102,7 +102,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ #define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \ 
IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \ - IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING) + IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING) \ + IMPL_COLL4(coll##CollNet, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_COLLNET) #if NCCL_TYPE == 0 #define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \ diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 034fe96..d10f11e 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -17,7 +17,8 @@ __device__ volatile uint64_t* ncclShmem; #define NCCL_FUNC4(coll, op, dtype) \ NCCL_FUNC5(coll##Tree, op, dtype), \ - NCCL_FUNC5(coll##Ring, op, dtype) + NCCL_FUNC5(coll##Ring, op, dtype), \ + NCCL_FUNC5(coll##CollNet, op, dtype) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(coll, op) \ diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index b624359..c1067bf 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -227,7 +227,7 @@ class ncclPrimitives { recvStep[i] = conn->step; recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS); recvDirectBuff[i] = NULL; - if (directBuff && conn->direct) { + if (directBuff && (conn->direct & NCCL_DIRECT_GPU)) { recvDirectBuff[i] = directBuff; if (tid == 0) *conn->ptrExchange = directBuff; } @@ -254,7 +254,7 @@ class ncclPrimitives { sendStep[i] = conn->step; sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS); sendDirectBuff[i] = NULL; - if (directBuff && conn->direct) { + if (directBuff && (conn->direct & NCCL_DIRECT_GPU)) { void* volatile* ptr = conn->ptrExchange; while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL); barrier(); diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index 0680abe..e36613f 100644 --- 
a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -50,6 +50,9 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) { template<int UNROLL, class FUNC, typename T> __device__ void ncclReduceTreeKernel(struct CollectiveArgs* args) { } +template<int UNROLL, class FUNC, typename T> +__device__ void ncclReduceCollNetKernel(struct CollectiveArgs* args) { } + template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -94,6 +97,9 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { } +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceCollNetLLKernel(struct CollectiveArgs* args) { } + #include "prims_ll128.h" template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) { @@ -138,3 +144,6 @@ __device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceTreeLL128Kernel(struct CollectiveArgs* args) { } + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceCollNetLL128Kernel(struct CollectiveArgs* args) { } diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 1985148..0b0ae81 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -64,6 +64,9 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) { template<int UNROLL, class FUNC, typename T> __device__ void ncclReduceScatterTreeKernel(struct CollectiveArgs* args) { } +template<int UNROLL, class FUNC, typename T> +__device__ void ncclReduceScatterCollNetKernel(struct CollectiveArgs* args) { } + template<int UNUSED, class FUNC, typename T> __device__ void 
ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -122,6 +125,9 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { } +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceScatterCollNetLLKernel(struct CollectiveArgs* args) { } + #include "prims_ll128.h" template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) { @@ -182,3 +188,6 @@ __device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) { template<int UNUSED, class FUNC, typename T> __device__ void ncclReduceScatterTreeLL128Kernel(struct CollectiveArgs* args) { } + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceScatterCollNetLL128Kernel(struct CollectiveArgs* args) { } |