github.com/marian-nmt/nccl.git
author    Sylvain Jeaugey <sjeaugey@nvidia.com>  2019-11-20 01:57:39 +0300
committer GitHub <noreply@github.com>            2019-11-20 01:57:39 +0300
commit    299c554dccf923230321ad7495946543f3e9b457 (patch)
tree      6a70b52080f0570fc87285b3b2300dbd2f2918ad /src/collectives/device
parent    ccb1298148327bacb9b83452ed6ae0b29417e7e2 (diff)
2.5.6-1 (#255)
Add LL128 Protocol.
Rewrite the topology detection and tree/ring creation (#179).
Improve tree performance by sending/receiving from different GPUs.
Add model-based tuning to switch between the different algorithms and protocols.
Rework P2P/SHM detection in containers (#155, #248).
Detect duplicated devices and return an error (#231).
Add tuning for GCP.
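
For context on the headline item: LL128 extends the low-latency (LL) idea of embedding synchronization flags directly in the data stream, but per 128-byte line instead of per 16-byte line. Assuming the line layout implied by the constants used in the new kernels below (NCCL_LL128_LINEELEMS = 16 eight-byte elements per line, NCCL_LL128_DATAELEMS = 15 of them carrying payload, the remaining element carrying flags, consistent with the NCCL_LL128_FLAGTHREAD definition in prims_ll128.h), the payload efficiency works out to

  LL    :   8 payload bytes per  16-byte line ->   8/16  = 50%
  LL128 : 120 payload bytes per 128-byte line -> 120/128 = 93.75%

which is why LL128 can approach Simple-protocol bandwidth while keeping LL-style latency.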
Diffstat (limited to 'src/collectives/device')
-rw-r--r--  src/collectives/device/Makefile           2
-rw-r--r--  src/collectives/device/all_gather.h      72
-rw-r--r--  src/collectives/device/all_reduce.h     183
-rw-r--r--  src/collectives/device/broadcast.h       52
-rw-r--r--  src/collectives/device/common.h          31
-rw-r--r--  src/collectives/device/common_kernel.h    2
-rw-r--r--  src/collectives/device/functions.cu      15
-rw-r--r--  src/collectives/device/op128.h           36
-rw-r--r--  src/collectives/device/primitives.h     481
-rw-r--r--  src/collectives/device/prims_ll.h       259
-rw-r--r--  src/collectives/device/prims_ll128.h    410
-rw-r--r--  src/collectives/device/reduce.h          49
-rw-r--r--  src/collectives/device/reduce_scatter.h  67
13 files changed, 1269 insertions, 390 deletions
diff --git a/src/collectives/device/Makefile b/src/collectives/device/Makefile
index 0ee587b..001059c 100644
--- a/src/collectives/device/Makefile
+++ b/src/collectives/device/Makefile
@@ -68,4 +68,4 @@ $(DEVOBJ) : $(LIBOBJ)
$(NVCC) $(NVCUFLAGS) -dlink $^ -o $@
clean:
- rm -f $(LIBOBJ) $(DEVOBJ) $(DEPFILES) $(DEPENDFILES) $(STATICLIB) test
+ rm -f $(LIBOBJ) $(DEVOBJ) $(DEPFILES) $(DEPENDFILES) $(RULESFILE) $(STATICLIB)
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index 8e78730..0ad5ba9 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -11,7 +11,7 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
@@ -19,15 +19,15 @@ __device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
const ssize_t size = args->N;
const int nranks = comm->nRanks;
const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
- ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, FUNC>
+ prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
@@ -129,3 +129,67 @@ __device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { }
+
+#include "prims_ll128.h"
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int bid = args->bid;
+ const int nthreads = args->nThreads;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
+ //const int rank = comm->rank;
+ const int nranks = comm->nRanks;
+ ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize);
+
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin AllGather steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ if (thisInput + chunkOffset == thisOutput + offset) { // In place
+ LLprims.send(thisInput+chunkOffset, nelem);
+ } else {
+ LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
+ }
+
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+
+ // step k-1: final store
+ rankDest = ring->devUserRanks[1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recv(thisOutput+offset, nelem);
+ }
+}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllGatherTreeLL128Kernel(struct CollectiveArgs* args) { }
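
As a reading aid for the new ncclAllGatherRingLL128Kernel above: per chunk it issues one send, nranks-2 recvCopySend steps, and one final recv, so each rank handles every rank's chunk exactly once. A minimal host-side sketch of that schedule, assuming a plain 0,1,...,nranks-1 ring order instead of NCCL's devUserRanks mapping (illustrative only, not part of the patch):

#include <cstdio>

// Prints which source rank's chunk a given rank handles at each of the
// nranks ring steps of an AllGather, and with which primitive.
void ringAllGatherSchedule(int rank, int nranks) {
  for (int step = 0; step < nranks; ++step) {
    int origin = (rank - step + nranks) % nranks;  // whose chunk passes through this rank now
    const char* op = (step == 0) ? "send"
                   : (step == nranks - 1) ? "recv"
                   : "recvCopySend";
    std::printf("rank %d, step %d: %-12s chunk of rank %d\n", rank, step, op, origin);
  }
}

int main() {
  for (int r = 0; r < 4; ++r) ringAllGatherSchedule(r, 4);
  return 0;
}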
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index 9b058cc..2449c2b 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -11,7 +11,7 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
@@ -27,7 +27,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->ThisOutput;
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
+ prims(tid, args->nThreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*args->nChannels));
@@ -85,23 +85,28 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
template<int UNROLL, class FUNC, typename T>
__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclTree* tree = &channel->tree;
const ssize_t size = args->N;
const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = args->lastChunkSize;
+ int chunkSize = args->lastChunkSize;
+ const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
const ssize_t loopSize = args->nChannels*chunkSize;
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ }
+
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
do {
+ struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -117,8 +122,9 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
} while(0);
do {
+ struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
+ ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -149,6 +155,8 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
//const int rank = comm->rank;
const int nranks = comm->nRanks;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
+
const ssize_t loopSize = args->nChannels*nranks*chunkSize;
// Compute pointers
@@ -156,10 +164,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->ThisOutput;
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- if (size-gridOffset < loopSize) {
- chunkSize = args->lastChunkSize;
- }
- ssize_t chunkOffset = gridOffset + bid*nranks*chunkSize;
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
/////////////// begin AllReduce steps ///////////////
ssize_t offset;
@@ -168,7 +173,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// step 0: push data to next GPU
slice = ring->devUserRanks[nranks-1];
- offset = chunkOffset + slice * chunkSize;
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.send(thisInput+offset, nelem);
@@ -176,7 +181,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// k-2 steps: reduce and copy to next GPU
for (int j=2; j<nranks; ++j) {
slice = ring->devUserRanks[nranks-j];
- offset = chunkOffset + slice * chunkSize;
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceSend(thisInput+offset, nelem);
@@ -185,7 +190,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// step k-1: reduce this buffer and data, which will produce the final
// result that we store in this data and push to the next GPU
slice = ring->devUserRanks[0];
- offset = chunkOffset + slice * chunkSize;
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
@@ -193,7 +198,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
slice = ring->devUserRanks[nranks-j];
- offset = chunkOffset + slice * chunkSize;
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
LLprims.recvCopySend(thisOutput+offset, nelem);
@@ -201,7 +206,7 @@ __device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
// Make final copy from buffer to dest.
slice = ring->devUserRanks[1];
- offset = chunkOffset + slice * chunkSize;
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
// Here we need to copy from buffer to this output.
@@ -216,16 +221,21 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclTree* tree = &channel->tree;
const ssize_t size = args->N;
ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
const ssize_t loopSize = args->nChannels*chunkSize;
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ }
+
// Compute pointers
const T * __restrict__ thisInput = (const T*)args->ThisInput;
T * __restrict__ thisOutput = (T*)args->ThisOutput;
do {
+ struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
@@ -243,6 +253,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
} while(0);
do {
+ struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
@@ -259,3 +270,141 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
}
} while(0);
}
+
+#include "prims_ll128.h"
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int bid = args->bid;
+ const int nthreads = args->nThreads;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
+ //const int rank = comm->rank;
+ const int nranks = comm->nRanks;
+ ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+
+ const ssize_t loopSize = args->nChannels*nranks*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+
+ /////////////// begin AllReduce steps ///////////////
+ ssize_t offset;
+ int nelem;
+ int slice;
+
+ // step 0: push data to next GPU
+ slice = ring->devUserRanks[nranks-1];
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.send(thisInput+offset, nelem);
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ slice = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data and push to the next GPU
+ slice = ring->devUserRanks[0];
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
+
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ slice = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+
+ // Make final copy from buffer to dest.
+ slice = ring->devUserRanks[1];
+ offset = gridOffset + (slice*args->nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ // Here we need to copy from buffer to this output.
+ LLprims.recv(thisOutput+offset, nelem);
+ }
+}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->bid;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* treeUp = &channel->treeUp;
+ struct ncclTree* treeDn = &channel->treeDn;
+ const ssize_t size = args->N;
+ ssize_t chunkSize = args->lastChunkSize;
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
+ const ssize_t loopSize = args->nChannels*chunkSize;
+ int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, args->nChannels*minChunkSize)*minChunkSize;
+ }
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ if (treeUp->up == -1) {
+ // ReduceAndBroadcast : max number of recv is 3, max number of send is 3
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
+ }
+ } else {
+ if (tid < nthreadsSplit) {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (treeUp->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ } else {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, channel, comm, args->opCount);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (treeDn->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+ }
+ }
+ }
+}
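
As a reading aid for the ring AllReduce kernels above (LL and LL128 alike): per chunk the schedule is 1 send, nranks-2 recvReduceSend, 1 recvReduceCopySend, nranks-2 recvCopySend and 1 final recv, i.e. 2*(nranks-1) sends and 2*(nranks-1) receives per rank. Since each step moves roughly N/nranks elements, the per-rank traffic is the usual ring AllReduce figure

  bytes sent per rank ~= 2 * (nranks-1)/nranks * N * sizeof(T)

so for nranks = 4 and N = 1024 float elements (4096 bytes), each rank sends about 2 * 3/4 * 4096 = 6144 bytes.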
diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h
index ae8667f..de8b989 100644
--- a/src/collectives/device/broadcast.h
+++ b/src/collectives/device/broadcast.h
@@ -11,7 +11,7 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
@@ -29,7 +29,7 @@ __device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->ThisOutput;
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
+ prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
@@ -100,3 +100,51 @@ __device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { }
+
+#include "prims_ll128.h"
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int bid = args->bid;
+ const int nthreads = args->nThreads;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
+ const int rank = ring->devUserRanks[0];
+ const int nextRank = ring->devUserRanks[1];
+ const int root = args->root;
+
+ ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize);
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (rank == root) {
+ if (thisInput == thisOutput) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
+ }
+ } else if (nextRank == root) {
+ LLprims.recv(thisOutput + offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput + offset, nelem);
+ }
+ }
+}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclBroadcastTreeLL128Kernel(struct CollectiveArgs* args) { }
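
The new ncclBroadcastRingLL128Kernel above picks one of three roles per rank. A minimal sketch of that branch, assuming a plain next-pointer ring (illustrative helper, not part of the patch):

// Roles in a ring broadcast rooted at 'root': the root pushes data into the
// ring, the rank whose next hop is the root only stores, everyone else
// stores and forwards.
enum class BcastRole { Send, Recv, RecvCopySend };

inline BcastRole ringBroadcastRole(int rank, int nextRank, int root) {
  if (rank == root)     return BcastRole::Send;          // copySend when out of place
  if (nextRank == root) return BcastRole::Recv;          // last link: store only
  return BcastRole::RecvCopySend;                        // middle links: store and forward
}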
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index 8c336bf..46eb9f5 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -7,9 +7,8 @@
#ifndef NCCL_DEVICE_COMMON_H_
#define NCCL_DEVICE_COMMON_H_
-#include "../collectives.h"
+#include "collectives.h"
#include "devcomm.h"
-#include "nccl.h"
// Exit If Abort Barrier across CTA: make sure all threads exit consistently
// Each thread sets a predicate to true if abort == 1
@@ -31,17 +30,19 @@ extern __device__ ncclKern_t ncclFuncs[];
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
int* d = (int*)dst;
int* s = (int*)src;
- // When aggregation is effective, if some threads have aborted inside the LL kernel,
- // make sure the rest of the threads abort as well
- exitIfAbortBarrier(0);
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
- __syncthreads();
}
-static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid) {
+static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid, struct ncclDevComm* comm) {
+ // Check whether the last operation was aborted and make sure all threads exit
+ int abort = tid == 0 ? *(comm->abortFlag) : 0;
+ exitIfAbortBarrier(abort);
load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid);
+ __syncthreads();
if (tid == 0) hostColl->active = 0;
}
+extern __device__ volatile uint64_t* ncclShmem;
+
/* Functions for aggregation case */
#define IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
__device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
@@ -51,10 +52,11 @@ __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
#if NCCL_OP == 0
/* Kernels with the first operation inlined */
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
-__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \
__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
int tid = threadIdx.x; \
int bid = blockIdx.x; \
+ __shared__ volatile uint64_t shmem[NCCL_LL128_SHMEM_SIZE]; \
+ ncclShmem = shmem; \
__shared__ struct ncclColl localColl; \
\
struct ncclDevComm* comm = firstColl.args.comm; \
@@ -65,7 +67,7 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
c = &firstColl; \
} else { \
c = &localColl; \
- load_coll(c, channel->devCollectives+channel->collFifoHead, tid); \
+ load_coll(c, channel->devCollectives+channel->collFifoHead, tid, comm); \
} \
while (1) { \
if (tid < c->args.nThreads) { \
@@ -84,7 +86,7 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
\
/* Load next collective operation*/ \
c = &localColl; /* for bid 0 */ \
- load_coll(c, channel->devCollectives+nextIndex, tid); \
+ load_coll(c, channel->devCollectives+nextIndex, tid, comm); \
} \
}
#else
@@ -93,13 +95,14 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
// Only generate inline kernels for LL
#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
- IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
- IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
+ IMPL_COLL_FUNC(coll##LL128, op, ncclFunc, dtype, ctype) \
+ IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
+ IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, al, NCCL_PROTO_LL)) \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
- IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \
- IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1)
+ IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \
+ IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING)
#if NCCL_TYPE == 0
#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
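
The reworked load_coll above now reads comm->abortFlag once (thread 0) before loading the next collective and then relies on exitIfAbortBarrier, described in the header comment as a predicated bar.red.popc, to make the whole CTA exit consistently. A rough equivalent of that pattern using __syncthreads_or(), purely for illustration (hedged sketch, not the NCCL implementation):

__device__ inline void exitIfAbortBarrierSketch(int abort) {
  // Every thread contributes its abort predicate; if any thread saw a
  // non-zero flag, all threads observe the same result and exit together.
  if (__syncthreads_or(abort)) {
    asm volatile("exit;");
  }
}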
diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h
index 435a598..aa1e936 100644
--- a/src/collectives/device/common_kernel.h
+++ b/src/collectives/device/common_kernel.h
@@ -263,8 +263,6 @@ __device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthread
}
}
-#define WARP_SIZE 32
-
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu
index 010c454..034fe96 100644
--- a/src/collectives/device/functions.cu
+++ b/src/collectives/device/functions.cu
@@ -8,13 +8,16 @@
#include "collectives.h"
#include "common.h"
+__device__ volatile uint64_t* ncclShmem;
+
#define NCCL_FUNC5(coll, op, dtype) \
- NCCL_COLL_NAME(coll, op, dtype), \
- NCCL_COLL_NAME(coll##LL, op, dtype)
+ NCCL_COLL_NAME(coll##LL, op, dtype), \
+ NCCL_COLL_NAME(coll##LL128, op, dtype), \
+ NCCL_COLL_NAME(coll, op, dtype)
#define NCCL_FUNC4(coll, op, dtype) \
- NCCL_FUNC5(coll##Ring, op, dtype), \
- NCCL_FUNC5(coll##Tree, op, dtype)
+ NCCL_FUNC5(coll##Tree, op, dtype), \
+ NCCL_FUNC5(coll##Ring, op, dtype)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(coll, op) \
@@ -50,7 +53,7 @@
NCCL_FUNCS3B(coll, copy), \
NCCL_FUNCS3B(coll, copy)
-// Must be consistent with ncclColl_t
+// Must be consistent with ncclFunc_t
#define NCCL_FUNCS() { \
NCCL_FUNCS2B(ncclBroadcast), \
NCCL_FUNCS2A(ncclReduce), \
@@ -59,7 +62,7 @@
NCCL_FUNCS2A(ncclAllReduce) }
// Must be consistent with the ncclFuncSet enum
-__device__ ncclKern_t ncclFuncs[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = {
+__device__ ncclKern_t ncclFuncs[NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
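
For the resized ncclFuncs table above, the index space is now function x op x type x algorithm x protocol. Assuming the enum sizes current in this release (5 collectives, 4 reduction ops, 9 data types, 2 algorithms: Tree and Ring, 3 protocols: LL, LL128, Simple), the table grows from

  ncclCollCount*ncclNumOps*ncclNumTypes*2*2 = 5*4*9*2*2 = 720 entries

to

  NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS = 5*4*9*2*3 = 1080 entries,

with the protocol varying fastest (LL, LL128, Simple) and the algorithm next (Tree, Ring), matching the reordered NCCL_FUNC5/NCCL_FUNC4 macros.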
diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h
new file mode 100644
index 0000000..9405dc2
--- /dev/null
+++ b/src/collectives/device/op128.h
@@ -0,0 +1,36 @@
+/*************************************************************************
+ * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef OP128_H_
+#define OP128_H_
+
+inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
+ asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];"
+ : "=l"(v0), "=l"(v1) : "l"(ptr));
+}
+
+inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
+ asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};"
+ :: "l"(v0), "l"(v1), "l"(ptr));
+}
+
+inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) {
+ uint64_t* shmemAsmPtr;
+ asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(shmemAsmPtr) : "l"(shmemGenericPtr));
+ return shmemAsmPtr;
+}
+
+inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) {
+ asm volatile("ld.volatile.shared.v2.u64 {%0,%1}, [%2];"
+ : "=l"(v0), "=l"(v1) : "l"(shmemAsmPtr));
+}
+
+inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) {
+ asm volatile("st.volatile.shared.v2.u64 [%2], {%0,%1};"
+ :: "l"(v0), "l"(v1), "l"(shmemAsmPtr));
+}
+
+#endif
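
The new op128.h ends here. As a usage illustration (hypothetical helper, not part of the patch), the two global-memory helpers compose into a 128-bit copy:

// Copies one 16-byte element between global-memory locations using the
// volatile vectorized load/store defined above.
__device__ inline void copy128(uint64_t* dst, const uint64_t* src) {
  uint64_t v0, v1;
  load128(src, v0, v1);
  store128(dst, v0, v1);
}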
diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h
index 7beeaf4..aa3d20d 100644
--- a/src/collectives/device/primitives.h
+++ b/src/collectives/device/primitives.h
@@ -37,15 +37,27 @@ class ncclPrimitives {
private:
const int tid;
const int nthreads;
+ const int wid;
+ const int stepSize;
int nrecv = 0;
int nsend = 0;
- const int stepSize;
- struct ncclConnInfo* recvConn[NRECV];
- struct ncclConnInfo* sendConn[NSEND];
- volatile uint64_t* waitPtr;
+ struct ncclConnInfo* recvConn = NULL;
+ volatile uint64_t* recvConnHeadPtr = NULL;
+ uint64_t recvConnHead;
+ volatile uint64_t* recvConnTailPtr = NULL;
+ uint64_t recvConnTail;
+ uint64_t recvConnTailCache; // Cache last seen value
+
+ struct ncclConnInfo* sendConn = NULL;
+ volatile int* sendConnFifoPtr = NULL;
+ volatile uint64_t* sendConnTailPtr = NULL;
+ uint64_t sendConnTail;
+ volatile uint64_t* sendConnHeadPtr = NULL;
+ uint64_t sendConnHead;
+ uint64_t sendConnHeadCache; // Cache last seen value
+
uint64_t recvStep[NRECV];
uint64_t sendStep[NSEND];
- uint64_t sendConnHead[NSEND];
const T* recvDirectBuff[NRECV];
T* sendDirectBuff[NSEND];
const T* recvBuff[NRECV];
@@ -60,15 +72,18 @@ class ncclPrimitives {
inline __device__ void barrier() {
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
}
+ inline __device__ void subBarrier() {
+ asm volatile ("bar.sync 2, %0;" :: "r"(nthreads-WARP_SIZE));
+ }
uint32_t mismatch = 0;
const uint64_t opCount;
- inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) {
+ inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
if (mismatch) {
// In non-LL, we use _threadfence_system before incrementing opCount, yet we are still waiting for credits here, so there must be a size mismatch
*(comm->fatalDevError) = ncclDevAssertedMismatch;
- } else if (remoteOpCount && *remoteOpCount > opCount) {
+ } else if (conn && *conn->opCountRem > opCount) {
mismatch += 1;
}
}
@@ -76,49 +91,55 @@ class ncclPrimitives {
uint32_t spins = 0;
uint32_t abort = 0;
- inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) {
+ inline __device__ int checkAbort(int i, int send) {
spins++;
- if (spins == SPINS_BEFORE_CHECK_ABORT) {
+ if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
- checkMismatch(remoteOpCount);
+ if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
}
- inline __device__ void waitRecv(int i) {
+ inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
- recvStep[i] += SLICESTEPS;
- if (tid == i) {
- while (*(waitPtr) < recvStep[i]) {
- if (checkAbort(recvConn[i]->opCountRem)) break;
+ if (sendConnHeadPtr) {
+ while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
+ sendConnHeadCache = *sendConnHeadPtr;
+ if (checkAbort(wid, 1)) break;
+ }
+ if (sendConnFifoPtr) {
+ sendConnFifoPtr[sendConnHead%NCCL_STEPS] = nbytes;
}
+ sendConnHead += SLICESTEPS;
}
}
- inline __device__ void waitSend(int i) {
+ inline __device__ void waitRecv() {
spins = 0;
mismatch = 0;
- sendStep[i] += SLICESTEPS;
- if (tid == WARP_SIZE+i) {
- while (sendConnHead[i] + NCCL_STEPS < sendStep[i]) {
- sendConnHead[i] = *waitPtr;
- if (checkAbort(sendConn[i]->opCountRem)) break;
+ if (recvConnTailPtr) {
+ while (recvConnTailCache < recvConnTail + SLICESTEPS) {
+ recvConnTailCache = *recvConnTailPtr;
+ if (checkAbort(wid, 0)) break;
}
+ recvConnTail += SLICESTEPS;
}
}
- inline __device__ void postRecv(int i) {
- *(recvConn[i]->head) = recvStep[i] += SLICESTEPS;
+ inline __device__ void incRecv(int i) {
+ recvStep[i] += SLICESTEPS;
}
-
- inline __device__ void postSend(int i) {
- *(sendConn[i]->tail) = sendStep[i] += SLICESTEPS;
+ inline __device__ void postRecv() {
+ if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += SLICESTEPS;
}
- inline __device__ void postSendSize(int i, int size) {
- if (sendConn[i]->fifo) sendConn[i]->fifo[sendStep[i]%NCCL_STEPS] = size;
+ inline __device__ void incSend(int i) {
+ sendStep[i] += SLICESTEPS;
+ }
+ inline __device__ void postSend() {
+ if (sendConnTailPtr) *sendConnTailPtr = sendConnTail += SLICESTEPS;
}
template <int DIRECTRECV>
@@ -131,11 +152,22 @@ class ncclPrimitives {
return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i);
}
+ template <int DIRECTRECV>
+ inline __device__ int directRecvInc(int i, int directInc, int sliceInc) {
+ return DIRECTRECV && recvDirectBuff[i] ? directInc : sliceInc;
+ }
+
+ template <int DIRECTSEND>
+ inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
+ return DIRECTSEND && sendDirectBuff[i] ? directInc : sliceInc;
+ }
+
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
inline __device__ void
GenericOp(const T* srcPtr, T* dstPtr, int nelem, int directOffset) {
int offset = 0;
- int sliceSize = stepSize * SLICESTEPS;
+ int sliceSize = stepSize*SLICESTEPS;
+ int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32);
const T* srcs[RECV*NRECV+SRC];
srcs[0] = SRC ? srcPtr : directRecvPtr<DIRECTRECV>(0, directOffset);
@@ -151,101 +183,126 @@ class ncclPrimitives {
for (int i=1; i<NSEND && i<nsend; i++) dsts[DST+i] = directSendPtr<DIRECTSEND>(i, directOffset);
}
- #pragma unroll 1
+ bool syncThread = tid >= nthreads-WARP_SIZE;
+
+ #pragma unroll
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
- int realSize = max(0, min(sliceSize, nelem-offset));
- if (tid < nthreads) {
- FOR_SEND(waitSend);
- FOR_RECV(waitRecv);
+ int realSize = max(0, min(dataSize, nelem-offset));
+ if (!syncThread) {
+ if (SEND) waitSend(realSize*sizeof(T));
+ if (RECV) waitRecv();
if (realSize > 0) {
- barrier();
+ subBarrier();
if (DIRECTRECV && recvDirectBuff[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (SEND) {
- ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize);
+ ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads-WARP_SIZE, 1, srcs, nsend, dsts+1, realSize);
}
} else {
- ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
+ ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads-WARP_SIZE, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
}
}
- exitIfAbortBarrier(abort);
- } else {
- exitIfAbortBarrier(abort);
- FOR_SEND(postSendSize, realSize*sizeof(T));
- if (SEND) __threadfence_system();
- FOR_SEND(postSend);
- FOR_RECV(postRecv);
}
- for (int i=0; i<RECV*NRECV+SRC; i++) srcs[i] += sliceSize;
- for (int i=0; i<SEND*NSEND+DST; i++) dsts[i] += sliceSize;
- offset += sliceSize;
+ barrier();
+ FOR_SEND(incSend);
+ FOR_RECV(incRecv);
+ if (syncThread) {
+ if (SEND) {
+ if (realSize > 0 && wid == 0) __threadfence_system();
+ __syncwarp();
+ postSend();
+ }
+ if (RECV) postRecv();
+ }
+ srcs[0] += SRC ? realSize : directRecvInc<DIRECTRECV>(0, realSize, sliceSize);
+ for (int i=1-SRC; i<RECV*NRECV; i++) srcs[SRC+i] += sliceSize;
+ dsts[0] += DST ? realSize : directSendInc<DIRECTSEND>(0, realSize, sliceSize);
+ for (int i=1-DST; i<SEND*NSEND; i++) dsts[DST+i] += directSendInc<DIRECTSEND>(i, realSize, sliceSize);
+ offset += realSize;
}
}
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
- recvConn[i] = conn;
- recvBuff[i] = (const T*)recvConn[i]->buff;
- recvStep[i] = recvConn[i]->step;
+ recvBuff[i] = (const T*)conn->buff;
+ recvStep[i] = conn->step;
recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
- // Return credits in case we rounded up.
- if (tid == nthreads) *recvConn[i]->head = recvStep[i];
- if (tid == i) {
- waitPtr = recvConn[i]->tail;
- *(recvConn[i]->opCountLoc) = opCount;
- }
recvDirectBuff[i] = NULL;
- if (directBuff && recvConn[i]->direct) {
+ if (directBuff && conn->direct) {
recvDirectBuff[i] = directBuff;
- if (tid == 0) *recvConn[i]->ptrExchange = directBuff;
+ if (tid == 0) *conn->ptrExchange = directBuff;
}
+ if (wid == i) recvConn = conn;
+ if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up
nrecv++;
}
+ __device__ __forceinline__ void loadRecvSync() {
+ if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && wid<nrecv) {
+ recvConnTailPtr = recvConn->tail;
+ recvConnTailCache = *recvConnTailPtr;
+ }
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConnHeadPtr = recvConn->head;
+ // Return credits in case we rounded up.
+ *recvConnHeadPtr = recvConnHead;
+ // Update opCount in case we skipped some operations
+ *(recvConn->opCountLoc) = opCount;
+ }
+ }
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i, T* directBuff) {
- sendConn[i] = conn;
- sendBuff[i] = (T*)sendConn[i]->buff;
- sendStep[i] = sendConn[i]->step;
+ sendBuff[i] = (T*)conn->buff;
+ sendStep[i] = conn->step;
sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
- if (tid == WARP_SIZE+i) {
- waitPtr = sendConn[i]->head;
- sendConnHead[i] = *waitPtr;
- *(sendConn[i]->opCountLoc) = opCount;
- }
sendDirectBuff[i] = NULL;
- if (directBuff && sendConn[i]->direct) {
- void* volatile* ptr = sendConn[i]->ptrExchange;
+ if (directBuff && conn->direct) {
+ void* volatile* ptr = conn->ptrExchange;
while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
- __syncthreads();
+ barrier();
if (tid == 0) *ptr = NULL;
}
+ if (wid == i) sendConn = conn;
+ if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up
nsend++;
}
+ __device__ __forceinline__ void loadSendSync() {
+ if (tid < nsend) {
+ sendConnHeadPtr = sendConn->head;
+ sendConnHeadCache = *sendConnHeadPtr;
+ sendConnFifoPtr = sendConn->fifo;
+ *(sendConn->opCountLoc) = opCount;
+ }
+ if (tid >= nthreads-WARP_SIZE && wid<nsend) {
+ sendConnTailPtr = sendConn->tail;
+ }
+ }
- __device__ __forceinline__ void saveRecvConn(int i) {
- if (tid == i) {
- recvConn[i]->step = recvStep[i];
+ __device__ __forceinline__ void saveRecvSync() {
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConn->step = recvConnHead;
+ *(recvConn->opCountLoc) = opCount+1;
__threadfence_system();
- *(recvConn[i]->opCountLoc) += 1;
}
}
- __device__ __forceinline__ void saveSendConn(int i) {
- if (tid == WARP_SIZE+i) {
- sendConn[i]->step = sendStep[i];
+ __device__ __forceinline__ void saveSendSync() {
+ if (tid < nsend) {
+ sendConn->step = sendConnHead;
+ *(sendConn->opCountLoc) = opCount+1;
__threadfence_system();
- *(sendConn[i]->opCountLoc) += 1;
}
}
public:
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
- : comm(comm), tid(tid), nthreads(nthreads), stepSize(stepSize), opCount(opCount) {
- // Make sure step is updated before we read it
- __syncthreads();
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) {
+ // Make sure step is updated before we read it.
+ barrier();
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff);
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i, directBuff);
+ loadRecvSync();
+ loadSendSync();
}
__device__ __forceinline__ void
@@ -305,267 +362,13 @@ class ncclPrimitives {
}
__device__ __forceinline__ ~ncclPrimitives() {
- // Save steps for next collective. Have thread 0 do it to be compatible
- // with the way LL works.
- for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i);
- for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i);
+ // Save steps for the next operation
+ saveRecvSync();
+ saveSendSync();
}
};
-template <typename T, class FUNC, int NRECV, int NSEND>
-class ncclLLPrimitives {
- private:
- const int tid;
- const int nthreads;
- int nrecv = 0;
- int nsend = 0;
- struct ncclConnInfo* recvConn[NRECV];
- struct ncclConnInfo* sendConn[NSEND];
- volatile uint64_t* waitPtr;
- volatile uint64_t* postPtr;
- volatile int* fifoPtr;
- uint64_t recvStep[NRECV];
- uint64_t sendStep[NSEND];
- uint64_t sendConnHead;
- union ncclLLFifoLine* recvBuff[NRECV];
- union ncclLLFifoLine* sendBuff[NSEND];
- struct ncclDevComm* comm;
-
- inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
- inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
- inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
- inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
- inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
- inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }
-
- // Exit If Abort Barrier : make sure all threads exit consistently
- // Each thread sets a predicate to true if val == 1
- // all CTA's threads enter the barrier and do a popc on their predicates being True
- // If any of the thread's predicate was True, all the threads call exit()
- inline __device__ void exitIfAbortLocalBarrier() {
- uint32_t popc;
- asm ("{");
- asm volatile (" .reg .pred barr_pred;");
- asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
- asm volatile (" bar.red.popc.u32 %0, 14, %1, barr_pred;" : "=r"(popc) : "r"(nthreads));
- asm ("}");
- if (popc) {
- // Make sure threads not participating in the operation get the abort and all threads exit
- exitIfAbortBarrier(1);
- }
- }
-
- inline __device__ void barrier() {
- asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
- }
-
- uint32_t mismatch = 0;
- const uint64_t opCount;
+#include "prims_ll.h"
+//#include "prims_ll128.h"
- inline __device__ void checkMismatch(volatile uint64_t* remoteOpCount) {
- if (mismatch > 20) {
- // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
- // Note that we are not using _threadfence_system in LL so the error cannot be asserted
- *(comm->fatalDevError) = ncclDevSuspectedMismatch;
- } else if (remoteOpCount && *remoteOpCount > opCount) {
- mismatch += 1;
- }
- }
-
- uint32_t spins = 0;
- uint32_t abort = 0;
-
- inline __device__ int checkAbort(volatile uint64_t* remoteOpCount) {
- spins++;
- if (spins == SPINS_BEFORE_CHECK_ABORT) {
- abort = *(comm->abortFlag);
- checkMismatch(remoteOpCount);
- spins = 0;
- }
- return abort;
- }
-
- inline __device__ void waitSend(int i, int nbytes) {
- spins = 0;
- mismatch = 0;
- if (tid == WARP_SIZE+i) {
- while (sendConnHead + NCCL_STEPS < sendStep[i] + 1) {
- sendConnHead = *waitPtr;
- if (checkAbort(sendConn[i]->opCountRem)) break;
- }
- if (fifoPtr) {
- int size = ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes;
- fifoPtr[sendStep[i]%NCCL_STEPS] = size;
- }
- }
- }
-
- inline __device__ void postRecv(int i) {
- recvStep[i]++;
- if (tid == i) *postPtr = recvStep[i];
- }
-
- inline __device__ void postSend(int i, int offset) {
- // LL Cleanup : write all flags in the slice to make sure we don't have
- // data corruption when flag loops over.
- if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
- for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
- }
- sendStep[i]++;
- }
-
- __device__ uint64_t readLL(int i, int offset) {
- union ncclLLFifoLine* src = recvPtr(i) + offset;
- uint32_t flag = recvFlag(i);
- uint32_t data1, flag1, data2, flag2;
- spins = 0;
- mismatch = 0;
- do {
- asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
- if (checkAbort(recvConn[i]->opCountRem)) break;
- } while ((flag1 != flag) || (flag2 != flag));
- uint64_t val64 = data1 + (((uint64_t)data2) << 32);
- return val64;
- }
-
- __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
- asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
- }
-
- // Using memcpy handles misaligned pointers.
- __device__ uint64_t readAL(uint64_t* src) {
- uint64_t val;
- memcpy((char*)&val, (char*)src, sizeof(uint64_t));
- return val;
- }
-
- __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
- memcpy((char*)dst, (char*)&val, nbytes);
- }
-
- template <int RECV, int SEND, int SRC, int DST>
- __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
- uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
- FOR_SEND(waitSend, nbytes*2);
- barrier();
- uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
- uint64_t* srcPack = (uint64_t*)srcPtr;
- uint64_t* dstPack = (uint64_t*)dstPtr;
- int offset = tid;
- // Do multiples of 64 bits
- #pragma unroll 2
- for (; offset<npack; offset+=nthreads) {
- // Recv : local, then intra-node, then inter-node
- uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
- if (RECV) {
- if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
- for (int i=1; i<NRECV && i<nrecv; i++) {
- val = MULTI<FUNC, T>()(readLL(i, offset), val);
- }
- }
-
- // Send : inter-node, then intra-node, then local
- if (SEND) {
- for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
- storeLL(sendPtr(0)+offset, val, sendFlag(0));
- }
- if (DST) {
- if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
- // Last incomplete word
- storeAL(dstPack+offset, val, nbytes & 0x7);
- } else {
- storeAL(dstPack+offset, val, sizeof(uint64_t));
- }
- }
- }
- exitIfAbortLocalBarrier();
- FOR_RECV(postRecv);
- FOR_SEND(postSend, offset);
- }
-
- __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
- recvConn[i] = conn;
- recvBuff[i] = recvConn[i]->llBuff;
- recvStep[i] = recvConn[i]->step;
- if (tid == i) {
- postPtr = recvConn[i]->head;
- *(recvConn[i]->opCountLoc) = opCount;
- }
- nrecv++;
- }
-
- __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
- sendConn[i] = conn;
- sendBuff[i] = sendConn[i]->llBuff;
- sendStep[i] = sendConn[i]->step;
- if (tid == WARP_SIZE+i) {
- waitPtr = sendConn[i]->head;
- fifoPtr = sendConn[i]->fifo;
- sendConnHead = *waitPtr;
- *(sendConn[i]->opCountLoc) = opCount;
- }
- nsend++;
- }
-
- __device__ __forceinline__ void saveRecvConn(int i) {
- if (tid == i) {
- recvConn[i]->step = recvStep[i];
- *(recvConn[i]->opCountLoc) += 1;
- __threadfence_block();
- }
- }
-
- __device__ __forceinline__ void saveSendConn(int i) {
- if (tid == WARP_SIZE+i) {
- sendConn[i]->step = sendStep[i];
- *(sendConn[i]->opCountLoc) += 1;
- __threadfence_block();
- }
- }
-
- public:
- __device__ __forceinline__
- ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
- : comm(comm), tid(tid), nthreads(nthreads), opCount(opCount) {
- // Make sure step is updated before we read it.
- barrier();
-
- for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
- for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
- }
-
- __device__ void send(const T* src, int nelem) {
- return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
- }
-
- __device__ void recv(T* dst, int nelem) {
- return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
- }
-
- __device__ void recvReduceSend(const T* src, int nelem) {
- return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
- }
-
- __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
- return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
- }
-
- __device__ void copySend(const T* src, T* dst, int nelem) {
- return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
- }
-
- __device__ void recvCopySend(T* dst, int nelem) {
- return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
- }
-
- __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
- return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
- }
-
- __device__ __forceinline__ ~ncclLLPrimitives() {
- // Save steps for the next operation
- for (int i=0; i<NRECV && i<nrecv; i++) saveRecvConn(i);
- for (int i=0; i<NSEND && i<nsend; i++) saveSendConn(i);
- }
-};
#endif
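
One way to read the reworked ncclPrimitives flow control above: the last warp of the block now acts as a dedicated synchronization warp (syncThread) that posts head/tail updates and fences, while the data warps spin on cached connection counters. The send-side wait condition, sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS, simply bounds the number of slices in flight to the FIFO depth. Assuming NCCL_STEPS = 8 and SLICESTEPS = 1, a sender about to post slice sendConnHead = 10 waits until the receiver's acknowledged head reaches at least 10 + 1 - 8 = 3, i.e. until at least one of the eight FIFO slots is free again.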
diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h
new file mode 100644
index 0000000..f919493
--- /dev/null
+++ b/src/collectives/device/prims_ll.h
@@ -0,0 +1,259 @@
+template <typename T, class FUNC, int NRECV, int NSEND>
+class ncclLLPrimitives {
+ private:
+ const int tid;
+ const int nthreads;
+ const int wid;
+ int nrecv = 0;
+ int nsend = 0;
+ struct ncclConnInfo* recvConn = NULL;
+ volatile uint64_t* recvConnHeadPtr = NULL;
+ uint64_t recvConnHead;
+
+ struct ncclConnInfo* sendConn = NULL;
+ volatile int* sendConnFifoPtr = NULL;
+ volatile uint64_t* sendConnHeadPtr = NULL;
+ uint64_t sendConnHead;
+ uint64_t sendConnHeadCache; // Cache last seen value
+
+ uint64_t recvStep[NRECV];
+ uint64_t sendStep[NSEND];
+ union ncclLLFifoLine* recvBuff[NRECV];
+ union ncclLLFifoLine* sendBuff[NSEND];
+ struct ncclDevComm* comm;
+
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL_SLICE_LINES; }
+ inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
+ inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
+ inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
+ inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }
+
+ inline __device__ void barrier() {
+ asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
+ }
+
+ uint32_t mismatch = 0;
+ const uint64_t opCount;
+
+ inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
+ if (mismatch > 20) {
+ // We have seen that the peer advanced opcount so many times yet we are still waiting for credit of current op, so it is _most likely_ a mismatch
+ // Note that we are not using _threadfence_system in LL so the error cannot be asserted
+ *(comm->fatalDevError) = ncclDevSuspectedMismatch;
+ } else if (conn && *conn->opCountRem > opCount) {
+ mismatch += 1;
+ }
+ }
+
+ uint32_t spins = 0;
+ uint32_t abort = 0;
+
+ inline __device__ int checkAbort(int i, int send) {
+ spins++;
+ if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
+ abort = *(comm->abortFlag);
+ if (wid == i) checkMismatch(send ? sendConn : recvConn);
+ spins = 0;
+ }
+ return abort;
+ }
+
+ inline __device__ void waitSend(int nbytes) {
+ spins = 0;
+ mismatch = 0;
+ if (sendConnHeadPtr) {
+ while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
+ sendConnHeadCache = *sendConnHeadPtr;
+ if (checkAbort(wid, 1)) break;
+ }
+ if (sendConnFifoPtr) {
+ int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? NCCL_LL_SLICE_LINES*sizeof(union ncclLLFifoLine) : nbytes;
+ sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size;
+ }
+ sendConnHead += 1;
+ }
+ barrier();
+ }
+
+ inline __device__ void incRecv(int i) {
+ recvStep[i] += 1;
+ }
+ inline __device__ void postRecv() {
+ barrier();
+ if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
+ }
+
+ inline __device__ void incSend(int i, int offset) {
+ // LL Cleanup : write all flags in the slice to make sure we don't have
+ // data corruption when flag loops over.
+ if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
+ for (int o = offset; o<NCCL_LL_SLICE_LINES; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
+ }
+ sendStep[i]++;
+ }
+
+ __device__ uint64_t readLL(int i, int offset) {
+ union ncclLLFifoLine* src = recvPtr(i) + offset;
+ uint32_t flag = recvFlag(i);
+ uint32_t data1, flag1, data2, flag2;
+ spins = 0;
+ mismatch = 0;
+ do {
+ asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4));
+ if (checkAbort(i, 0)) break;
+ } while ((flag1 != flag) || (flag2 != flag));
+ uint64_t val64 = data1 + (((uint64_t)data2) << 32);
+ return val64;
+ }
+
+ __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
+ asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag));
+ }
+
+ // Using memcpy handles misaligned pointers.
+ __device__ uint64_t readAL(uint64_t* src) {
+ uint64_t val;
+ memcpy((char*)&val, (char*)src, sizeof(uint64_t));
+ return val;
+ }
+
+ __device__ void storeAL(uint64_t* dst, uint64_t val, uint32_t nbytes) {
+ memcpy((char*)dst, (char*)&val, nbytes);
+ }
+
+ template <int RECV, int SEND, int SRC, int DST>
+ __device__ void LLGenericOp(const T* srcPtr, T* dstPtr, int nelem) {
+ uint32_t nbytes = nelem < 0 ? 0 : nelem*sizeof(T);
+ uint32_t npack = DIVUP(nbytes, sizeof(uint64_t));
+ uint64_t* srcPack = (uint64_t*)srcPtr;
+ uint64_t* dstPack = (uint64_t*)dstPtr;
+ int offset = tid;
+
+ // Always waitSend in case of cleanup
+ if (SEND) waitSend(npack*sizeof(union ncclLLFifoLine));
+
+ // Do multiples of 64 bits
+ #pragma unroll 2
+ for (; offset<npack; offset+=nthreads) {
+ // Recv : local, then intra-node, then inter-node
+ uint64_t val = SRC ? readAL(srcPack+offset) : readLL(0, offset);
+ if (RECV) {
+ if (SRC) val = MULTI<FUNC, T>()(readLL(0, offset), val);
+ for (int i=1; i<NRECV && i<nrecv; i++) {
+ val = MULTI<FUNC, T>()(readLL(i, offset), val);
+ }
+ }
+
+ // Send : inter-node, then intra-node, then local
+ if (SEND) {
+ for (int i=1; i<NSEND && i<nsend; i++) storeLL(sendPtr(i)+offset, val, sendFlag(i));
+ storeLL(sendPtr(0)+offset, val, sendFlag(0));
+ }
+ if (DST) {
+ if (((offset*sizeof(uint64_t)) ^ nbytes) < sizeof(uint64_t)) {
+ // Last incomplete word
+ storeAL(dstPack+offset, val, nbytes & 0x7);
+ } else {
+ storeAL(dstPack+offset, val, sizeof(uint64_t));
+ }
+ }
+ }
+ FOR_RECV(incRecv); if (RECV) postRecv();
+ FOR_SEND(incSend, offset);
+ }
+
+ __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
+ recvBuff[i] = conn->llBuff;
+ recvStep[i] = conn->step;
+ if (wid == i) recvConn = conn;
+ nrecv++;
+ }
+ __device__ __forceinline__ void loadRecvSync() {
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConnHeadPtr = recvConn->head;
+ recvConnHead = recvConn->step;
+ // Update opCount in case we skipped some operations
+ *(recvConn->opCountLoc) = opCount;
+ }
+ }
+
+ __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
+ sendBuff[i] = conn->llBuff;
+ sendStep[i] = conn->step;
+ if (wid == i) sendConn = conn;
+ nsend++;
+ }
+ __device__ __forceinline__ void loadSendSync() {
+ if (tid < nsend) {
+ sendConnHeadPtr = sendConn->head;
+ sendConnHeadCache = *sendConnHeadPtr;
+ sendConnHead = sendConn->step;
+ sendConnFifoPtr = sendConn->fifo;
+ *(sendConn->opCountLoc) = opCount;
+ }
+ }
+
+ __device__ __forceinline__ void saveRecvSync() {
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConn->step = recvConnHead;
+ *(recvConn->opCountLoc) = opCount+1;
+ __threadfence_block();
+ }
+ }
+
+ __device__ __forceinline__ void saveSendSync() {
+ if (tid < nsend) {
+ sendConn->step = sendConnHead;
+ *(sendConn->opCountLoc) = opCount+1;
+ __threadfence_block();
+ }
+ }
+
+ public:
+ __device__ __forceinline__
+ ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), opCount(opCount) {
+ // Make sure step is updated before we read it.
+ barrier();
+
+ for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
+ for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
+ loadRecvSync();
+ loadSendSync();
+ }
+
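+ // Each public entry point below instantiates LLGenericOp with four compile-time
+ // flags: RECV (read from peer LL buffers), SEND (write to peer LL buffers),
+ // SRC (read the local source pointer) and DST (write the local destination pointer).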
+ __device__ void send(const T* src, int nelem) {
+ return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
+ }
+
+ __device__ void recv(T* dst, int nelem) {
+ return LLGenericOp<1, 0, 0, 1>(NULL, dst, nelem);
+ }
+
+ __device__ void recvReduceSend(const T* src, int nelem) {
+ return LLGenericOp<1, 1, 1, 0>(src, NULL, nelem);
+ }
+
+ __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
+ return LLGenericOp<1, 0, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ void copySend(const T* src, T* dst, int nelem) {
+ return LLGenericOp<0, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ void recvCopySend(T* dst, int nelem) {
+ return LLGenericOp<1, 1, 0, 1>(NULL, dst, nelem);
+ }
+
+ __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
+ return LLGenericOp<1, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ __forceinline__ ~ncclLLPrimitives() {
+ // Save steps for the next operation
+ saveRecvSync();
+ saveSendSync();
+ }
+};
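For reference, the wire format that readLL()/storeLL() implement can be modeled on the host: each 16-byte ncclLLFifoLine interleaves two 4-byte payload words with two copies of a 4-byte step flag, so a single 128-bit store publishes both the data and its readiness, and the receiver simply spins until both flags match the expected step. A minimal host-side sketch follows; the names LLLine, storeLLModel and readLLModel are illustrative and not part of NCCL.

  #include <cstdint>
  #include <cstdio>

  // Mirrors the field order implied by storeLL(): {data1, flag1, data2, flag2}.
  struct LLLine { uint32_t data1, flag1, data2, flag2; };

  // Publish 8 bytes of payload together with the step flag in one 16-byte line.
  static void storeLLModel(LLLine* dst, uint64_t val, uint32_t flag) {
    dst->data1 = (uint32_t)val;
    dst->flag1 = flag;
    dst->data2 = (uint32_t)(val >> 32);
    dst->flag2 = flag;
  }

  // Spin until both flags match the expected step value, then reassemble the payload.
  // (The device code does this with a volatile 128-bit load in inline PTX.)
  static uint64_t readLLModel(const LLLine* src, uint32_t flag) {
    while (src->flag1 != flag || src->flag2 != flag) { /* spin */ }
    return (uint64_t)src->data1 | ((uint64_t)src->data2 << 32);
  }

  int main() {
    LLLine line = {0, 0, 0, 0};
    storeLLModel(&line, 0x1122334455667788ull, /*flag=*/1);
    printf("%llx\n", (unsigned long long)readLLModel(&line, /*flag=*/1));  // 1122334455667788
    return 0;
  }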
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h
new file mode 100644
index 0000000..40a8cff
--- /dev/null
+++ b/src/collectives/device/prims_ll128.h
@@ -0,0 +1,410 @@
+/*************************************************************************
+ * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "op128.h"
+
+#define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1)
+
+template <typename T, class FUNC, int NRECV, int NSEND>
+class ncclLL128Primitives {
+ private:
+ const int tid;
+ const int nthreads;
+ const int wid;
+ const int warp;
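+ // Every eighth thread (tid%8 == 7) is a flag thread: the upper 64 bits of each
+ // 16-byte chunk it loads/stores in the LL128 buffer carry the step flag rather
+ // than payload data (see NCCL_LL128_FLAGTHREAD above).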
+ const bool flagThread;
+ int nrecv = 0;
+ int nsend = 0;
+ struct ncclConnInfo* recvConn = NULL;
+ volatile uint64_t* recvConnHeadPtr = NULL;
+ uint64_t recvConnHead;
+
+ struct ncclConnInfo* sendConn = NULL;
+ volatile int* sendConnFifoPtr = NULL;
+ volatile uint64_t* sendConnTailPtr = NULL;
+ uint64_t sendConnTail;
+ volatile uint64_t* sendConnHeadPtr = NULL;
+ uint64_t sendConnHead;
+ uint64_t sendConnHeadCache; // Cache last seen value
+
+ uint64_t recvStep[NRECV];
+ uint64_t sendStep[NSEND];
+ uint64_t* recvBuff[NRECV];
+ uint64_t* sendBuff[NSEND];
+ struct ncclDevComm* comm;
+
+ volatile uint64_t* shmem;
+
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; }
+ inline __device__ uint64_t* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
+ inline __device__ uint64_t* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
+ inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; }
+ inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; }
+
+ inline __device__ void barrier() {
+ if (NSEND>NRECV) {
+ asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
+ } else {
+ asm volatile ("bar.sync 3, %0;" :: "r"(nthreads));
+ }
+ }
+
+ uint32_t mismatch = 0;
+ const uint64_t opCount;
+
+ inline __device__ void checkMismatch(struct ncclConnInfo* conn) {
+ if (mismatch > 20) {
+ // The peer has advanced its opCount many times while we are still waiting for credit on the current op, so this is most likely an opCount mismatch.
+ // Note that LL does not use __threadfence_system, so the error cannot be asserted with certainty.
+ *(comm->fatalDevError) = ncclDevSuspectedMismatch;
+ } else if (conn && *conn->opCountRem > opCount) {
+ mismatch += 1;
+ }
+ }
+
+ uint32_t spins = 0;
+ uint32_t abort = 0;
+
+ inline __device__ int checkAbort(int i, int send) {
+ spins++;
+ if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
+ abort = *(comm->abortFlag);
+ if (wid == i) checkMismatch(send ? sendConn : recvConn);
+ spins = 0;
+ }
+ return abort;
+ }
+
+ inline __device__ void waitSend(int nbytes) {
+ spins = 0;
+ mismatch = 0;
+ if (sendConnHeadPtr) {
+ while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
+ sendConnHeadCache = *sendConnHeadPtr;
+ if (checkAbort(wid, 1)) break;
+ }
+ if (sendConnFifoPtr) {
+ sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes;
+ }
+ sendConnHead += 1;
+ }
+ }
+
+ inline __device__ void incRecv(int i) {
+ recvStep[i] += 1;
+ }
+ inline __device__ void postRecv() {
+ if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
+ }
+
+ inline __device__ void incSend(int i) {
+ sendStep[i] += 1;
+ }
+ inline __device__ void postSend() {
+ if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; }
+ }
+
+ template <int ELEMS_PER_THREAD>
+ inline __device__ void loadSrcToShmem128(int maxOffset, const uint64_t* src64Ptr) {
+#if 0
+ uint64_t v[ELEMS_PER_THREAD];
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ if (u*WARP_SIZE < maxOffset) load128(src64Ptr+u*WARP_SIZE, v[u], v[u+1]);
+ }
+ uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ storeShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
+ }
+#else
+ uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ if (u*WARP_SIZE < maxOffset) {
+ uint64_t v0, v1;
+ load128(src64Ptr+u*WARP_SIZE, v0, v1);
+ storeShmem128(shmemAsmPtr+u*WARP_SIZE, v0, v1);
+ }
+ }
+#endif
+ }
+
+ inline __device__ void loadSrcToShmem(int start, int end, const T* srcPtr) {
+ T* shmemPtr = (T*)(shmem-2*wid);
+ for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
+ shmemPtr[offset] = srcPtr[offset];
+ }
+ }
+
+ template <int ELEMS_PER_THREAD>
+ inline __device__ void storeShmemToDst128(int maxOffset, uint64_t* dst64Ptr) {
+ uint64_t v[ELEMS_PER_THREAD];
+ uint64_t* shmemAsmPtr = shmemCvtPtr(shmem);
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ loadShmem128(shmemAsmPtr+u*WARP_SIZE, v[u], v[u+1]);
+ }
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ if (u*WARP_SIZE < maxOffset) store128(dst64Ptr+u*WARP_SIZE, v[u], v[u+1]);
+ }
+ }
+
+ inline __device__ void storeShmemToDst(int start, int end, T* dstPtr) {
+ T* shmemPtr = (T*)(shmem-2*wid);
+ for (int offset = start+wid; offset < end; offset += WARP_SIZE) {
+ dstPtr[offset] = shmemPtr[offset];
+ }
+ }
+
+ #define WARP_MASK 0xffffffff
+
+ template <int ELEMS_PER_THREAD, int RECV, int SEND, int SRC, int DST>
+ __device__ __forceinline__ void recvReduceSendCopy(int ll128Offset) {
+ uint64_t v[ELEMS_PER_THREAD];
+
+ /************* Data Loading : SHMEM -> REG **************/
+ if (SRC) {
+ volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ v[u] = shmem64Ptr[u*(WARP_SIZE-2)];
+ if (!flagThread) v[u+1] = shmem64Ptr[u*(WARP_SIZE-2)+1];
+ }
+ }
+ /*********** End Data Loading : SHMEM -> REG ************/
+
+ /************************ Recv **************************/
+ if (RECV) {
+ uint64_t flag = recvFlag(0);
+ uint64_t* ptr = recvPtr(0)+ll128Offset;
+ bool needReload;
+ uint64_t v0, v1;
+ do {
+ needReload = false;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ load128(ptr+u*WARP_SIZE, v0, v1);
+ needReload |= flagThread && (v1 != flag);
+ }
+ } while (__any_sync(WARP_MASK, needReload) && checkAbort(0, 0) == 0);
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ load128(ptr+u*WARP_SIZE, v0, v1);
+ v[u] = SRC ? MULTI<FUNC, T>()(v0, v[u]) : v0;
+ v[u+1] = SRC ? MULTI<FUNC, T>()(v1, v[u+1]) : v1;
+ }
+
+ for (int i=1; i<NRECV && i<nrecv; i++) {
+ uint64_t flag = recvFlag(i);
+ uint64_t* ptr = recvPtr(i)+ll128Offset;
+ uint64_t v0, v1;
+ do {
+ needReload = false;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ load128(ptr+u*WARP_SIZE, v0, v1);
+ needReload |= flagThread && (v1 != flag);
+ }
+ } while (__any_sync(WARP_MASK, needReload) && checkAbort(i, 0) == 0);
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ load128(ptr+u*WARP_SIZE, v0, v1);
+ v[u] = MULTI<FUNC, T>()(v0, v[u]);
+ v[u+1] = MULTI<FUNC, T>()(v1, v[u+1]);
+ }
+ }
+ }
+ /********************** End Recv ************************/
+
+ /************************ Send **************************/
+ if (SEND) {
+ for (int i=1; i<NSEND && i<nsend; i++) {
+ int flag = sendFlag(i);
+ uint64_t* ptr = sendPtr(i)+ll128Offset;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
+ }
+ }
+ int flag = sendFlag(0);
+ uint64_t* ptr = sendPtr(0)+ll128Offset;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
+ }
+ }
+ /********************** End Send ************************/
+
+ /************* Data Storing : REG -> SHMEM **************/
+ if (DST) {
+ volatile uint64_t* shmem64Ptr = shmem - (2*wid)/NCCL_LL128_LINEELEMS;
+ #pragma unroll
+ for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
+ shmem64Ptr[u*(WARP_SIZE-2)] = v[u];
+ if (!flagThread) shmem64Ptr[u*(WARP_SIZE-2)+1] = v[u+1];
+ }
+ }
+ /*********** End data Storing : REG -> SHMEM ************/
+ }
+
+ #define LL128INC (WARP_SIZE*NCCL_LL128_SHMEM_ELEMS_PER_THREAD)
+ #define ELEMINC (LL128INC-(LL128INC/NCCL_LL128_LINEELEMS))
+
+ template <int RECV, int SEND, int SRC, int DST>
+ __device__ void GenericOp(const T* srcPtr, T* dstPtr, int nelem) {
+ if (nelem <= 0) {
+ // Don't move any data but still increase steps and sync with prev/next
+ if (SEND) waitSend(0);
+ FOR_SEND(incSend); if (SEND) postSend();
+ FOR_RECV(incRecv); if (RECV) postRecv();
+ return;
+ }
+ const int nelem64 = ((nelem*sizeof(T))/(2*sizeof(uint64_t)))*2;
+ const uint64_t* src64Ptr = ((uint64_t*)srcPtr);
+ uint64_t* dst64Ptr = ((uint64_t*)dstPtr);
+
+ int ll128Offset = LL128INC*warp+2*wid;
+ int elemOffset = ELEMINC*warp;
+ const int nwarps = nthreads/WARP_SIZE;
+
+ if (SEND) waitSend(DIVUP(nelem*sizeof(T), ELEMINC*sizeof(uint64_t))*LL128INC*sizeof(uint64_t));
+ barrier();
+
+ while (elemOffset*(sizeof(uint64_t)/sizeof(T)) < nelem) {
+ const int maxOffset128 = min(nelem64-elemOffset, (int)ELEMINC);
+ const int maxOffset = min(nelem-(elemOffset*((int)(sizeof(uint64_t)/sizeof(T)))), (int)(ELEMINC*(sizeof(uint64_t)/sizeof(T))));
+ if (SRC) {
+ int done = 0;
+ if ((((uint64_t)srcPtr)&0xf) == 0) {
+ loadSrcToShmem128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, src64Ptr+elemOffset+2*wid);
+ done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
+ }
+ loadSrcToShmem(done, maxOffset, (T*)(src64Ptr+elemOffset));
+ }
+ __syncwarp();
+ recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SRC, DST>(ll128Offset);
+ __syncwarp();
+ if (DST) {
+ int done = 0;
+ if ((((uint64_t)dstPtr)&0xf) == 0) {
+ storeShmemToDst128<NCCL_LL128_SHMEM_ELEMS_PER_THREAD>(maxOffset128-2*wid, dst64Ptr+elemOffset+2*wid);
+ done = maxOffset128*(sizeof(uint64_t)/sizeof(T));
+ }
+ storeShmemToDst(done, maxOffset, (T*)(dst64Ptr+elemOffset));
+ }
+ __syncwarp();
+ ll128Offset += LL128INC*nwarps;
+ elemOffset += ELEMINC*nwarps;
+ }
+
+ barrier();
+ FOR_SEND(incSend); if (SEND) postSend();
+ FOR_RECV(incRecv); if (RECV) postRecv();
+ }
+
+ __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
+ recvBuff[i] = conn->ll128Buff;
+ recvStep[i] = conn->step;
+ if (wid == i) recvConn = conn;
+ nrecv++;
+ }
+ __device__ __forceinline__ void loadRecvSync() {
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConnHeadPtr = recvConn->head;
+ recvConnHead = recvConn->step;
+ // Update opCount in case we skipped some operations
+ *(recvConn->opCountLoc) = opCount;
+ }
+ }
+
+ __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
+ sendBuff[i] = conn->ll128Buff;
+ sendStep[i] = conn->step;
+ if (wid == i) sendConn = conn;
+ nsend++;
+ }
+ __device__ __forceinline__ void loadSendSync() {
+ if (tid < nsend) {
+ sendConnHeadPtr = sendConn->head;
+ sendConnHeadCache = *sendConnHeadPtr;
+ sendConnHead = sendConn->step;
+ sendConnFifoPtr = sendConn->fifo;
+ *(sendConn->opCountLoc) = opCount;
+ }
+ if (tid >= nthreads-WARP_SIZE && wid<nsend) {
+ if (sendConn->fifo) {
+ sendConnTailPtr = sendConn->tail;
+ sendConnTail = sendConn->step;
+ }
+ }
+ }
+
+ __device__ __forceinline__ void saveRecvSync() {
+ if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
+ recvConn->step = recvConnHead;
+ *(recvConn->opCountLoc) = opCount+1;
+ __threadfence_block();
+ }
+ }
+
+ __device__ __forceinline__ void saveSendSync() {
+ if (tid < nsend) {
+ sendConn->step = sendConnHead;
+ *(sendConn->opCountLoc) = opCount+1;
+ __threadfence_block();
+ }
+ }
+
+ public:
+ __device__ __forceinline__
+ ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
+ // Make sure step is updated before we read it.
+ barrier();
+
+ for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
+ for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
+ loadRecvSync();
+ loadSendSync();
+ }
+
+ __device__ void send(const T* src, int nelem) {
+ return GenericOp<0, 1, 1, 0>(src, NULL, nelem);
+ }
+
+ __device__ void recv(T* dst, int nelem) {
+ return GenericOp<1, 0, 0, 1>(NULL, dst, nelem);
+ }
+
+ __device__ void recvReduceSend(const T* src, int nelem) {
+ return GenericOp<1, 1, 1, 0>(src, NULL, nelem);
+ }
+
+ __device__ void recvReduceCopy(const T* src, T* dst, int nelem) {
+ return GenericOp<1, 0, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ void copySend(const T* src, T* dst, int nelem) {
+ return GenericOp<0, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ void recvCopySend(T* dst, int nelem) {
+ return GenericOp<1, 1, 0, 1>(NULL, dst, nelem);
+ }
+
+ __device__ void recvReduceCopySend(const T* src, T* dst, int nelem) {
+ return GenericOp<1, 1, 1, 1>(src, dst, nelem);
+ }
+
+ __device__ __forceinline__ ~ncclLL128Primitives() {
+ // Save steps for the next operation
+ saveRecvSync();
+ saveSendSync();
+ }
+};
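Because one 64-bit element per line of NCCL_LL128_LINEELEMS is reserved for the flag written by the flag threads, only NCCL_LL128_DATAELEMS of every NCCL_LL128_LINEELEMS elements carry payload, which is exactly what the chunkSize expressions in the LL128 kernels below account for. A small host-side sketch of that arithmetic follows; the numeric constants (16-element lines, 120 elements per thread, 640 threads, 4-byte T) are illustrative assumptions, not values taken from this diff.

  #include <cstdint>
  #include <cstdio>

  int main() {
    const int lineElems      = 16;              // assumed value of NCCL_LL128_LINEELEMS
    const int dataElems      = lineElems - 1;   // NCCL_LL128_DATAELEMS: one flag element per line
    const int elemsPerThread = 120;             // assumed value of NCCL_LL128_ELEMS_PER_THREAD
    const int nthreads       = 640;             // assumed LL128 thread count
    const int sizeofT        = 4;               // e.g. float

    // Mirrors the kernel expression:
    //   chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))
    //               / (NCCL_LL128_LINEELEMS*sizeof(T));
    long long chunkElems = (long long)elemsPerThread * nthreads * dataElems * (long long)sizeof(uint64_t)
                         / (lineElems * sizeofT);
    printf("payload fraction = %d/%d, chunk elements per channel = %lld\n",
           dataElems, lineElems, chunkElems);
    return 0;
  }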
diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h
index d2d5d3b..0680abe 100644
--- a/src/collectives/device/reduce.h
+++ b/src/collectives/device/reduce.h
@@ -11,7 +11,7 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
@@ -30,7 +30,7 @@ __device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->ThisOutput;
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
+ prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
@@ -93,3 +93,48 @@ __device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { }
+
+#include "prims_ll128.h"
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int bid = args->bid;
+ const int nthreads = args->nThreads;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
+ const int rank = comm->rank;
+ const int nranks = comm->nRanks;
+ const int prevRank = ring->devUserRanks[nranks-1];
+ const int root = args->root;
+
+ ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize);
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (prevRank == root) {
+ LLprims.send(thisInput+offset, nelem);
+ } else if (rank == root) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceTreeLL128Kernel(struct CollectiveArgs* args) { }
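The three branches in the per-chunk loop of ncclReduceRingLL128Kernel above express a ring reduce rooted at args->root: the rank whose ring predecessor is the root starts the chain with send(), every intermediate rank folds its contribution in with recvReduceSend(), and the root terminates the chain with recvReduceCopy(). A minimal host-side sketch of that schedule for one chunk follows; all names and values are illustrative.

  #include <cstdio>

  int main() {
    const int nranks = 4, root = 2;
    int input[nranks] = {1, 2, 3, 4};  // one value per rank for a single chunk

    // Ring order is rank -> (rank+1) % nranks, i.e. ring->next in NCCL terms.
    // The rank whose predecessor is the root starts the chain.
    int start = (root + 1) % nranks;
    int inflight = 0;
    for (int step = 0, r = start; step < nranks; step++, r = (r + 1) % nranks) {
      if (r == start)     inflight = input[r];                              // send()
      else if (r == root) printf("root result = %d\n", inflight + input[r]); // recvReduceCopy()
      else                inflight += input[r];                             // recvReduceSend()
    }
    return 0;
  }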
diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h
index 09ba56e..1985148 100644
--- a/src/collectives/device/reduce_scatter.h
+++ b/src/collectives/device/reduce_scatter.h
@@ -11,7 +11,7 @@
template<int UNROLL, class FUNC, typename T>
__device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
- const int nthreads = blockDim.x - 1;
+ const int nthreads = args->nThreads-WARP_SIZE;
const int bid = args->bid;
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
@@ -19,7 +19,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
const ssize_t size = args->N;
const int nranks = comm->nRanks;
const int stepSize = channel->buffSize / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS;
const ssize_t loopSize = args->nChannels*(ssize_t)chunkSize;
// Compute pointers
@@ -27,7 +27,7 @@ __device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
T * __restrict__ thisOutput = (T*)args->ThisOutput;
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
+ prims(tid, args->nThreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,args->nChannels));
@@ -121,3 +121,64 @@ __device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
template<int UNUSED, class FUNC, typename T>
__device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { }
+
+#include "prims_ll128.h"
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
+ const int tid = threadIdx.x;
+ const int bid = args->bid;
+ const int nthreads = args->nThreads;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, channel, comm, args->opCount);
+
+ const ssize_t size = args->N;
+ //const int rank = comm->rank;
+ const int nranks = comm->nRanks;
+ ssize_t chunkSize = (NCCL_LL128_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2, but it makes performance much smoother; there may be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+
+ const ssize_t loopSize = args->nChannels*chunkSize;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->ThisInput;
+ T * __restrict__ thisOutput = (T*)args->ThisOutput;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, args->nChannels*minChunkSize)*minChunkSize, chunkSize);
+
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin ReduceScatter steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[nranks-1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.send(thisInput+offset, nelem);
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+
+ // step k-1: reduce this rank's buffer with the incoming data, producing the
+ // final result, which we store in this rank's output
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
+ }
+}
+
+template<int UNUSED, class FUNC, typename T>
+__device__ void ncclReduceScatterTreeLL128Kernel(struct CollectiveArgs* args) { }
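The loop in ncclReduceScatterRingLL128Kernel above walks devUserRanks[] so that each rank first pushes the slice owned by its ring predecessor, reduce-and-forwards for nranks-2 steps, and finally reduces the slice it owns into thisOutput. Equivalently, the slice owned by rank d enters the ring at d's successor and is completed back at d. A minimal host-side sketch of that per-slice view follows; names and values are illustrative.

  #include <cstdio>

  int main() {
    const int nranks = 4;
    int input[nranks][nranks];  // input[r][c]: rank r's contribution to the slice owned by rank c
    for (int r = 0; r < nranks; r++)
      for (int c = 0; c < nranks; c++) input[r][c] = 10 * r + c;

    // Slice c enters the ring at rank (c+1)%nranks (send), is reduced and forwarded
    // for nranks-2 hops (recvReduceSend), and is finished by its owner c (recvReduceCopy).
    for (int c = 0; c < nranks; c++) {
      int acc = 0;
      for (int hop = 1; hop <= nranks; hop++) {
        int r = (c + hop) % nranks;  // rank touching slice c at this hop
        acc += input[r][c];
      }
      printf("rank %d output slice = %d\n", c, acc);
    }
    return 0;
  }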