Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Jeaugey <sjeaugey@nvidia.com>2020-09-05 00:35:05 +0300
committerSylvain Jeaugey <sjeaugey@nvidia.com>2020-11-17 22:08:52 +0300
commit920dbe5b359fe5817b8ba874476ca4ba2dc5f1ef (patch)
treeda539cb823c9e11e4fa8e7e6de88dd4a662c7128
parent084207e685c4587e7d0aa2f1f7f148d3e0e68da6 (diff)
2.8.3-1
Optimization for Tree allreduce on A100. Improve aggregation performance. Use shared buffers for inter-node send/recv. Add NVTX profiling hooks. Accelerate alltoall connections by merging communication for all channels. Add support for one hop communication through NVLink, for faster send/recv communication on cubemesh topologies like DGX-1. Improve alltoall scheduling to better balance intra/inter node communication. Increase send/recv parallelism by 8x, each warp sending or receiving to a different peer. Net: move to v4. Net: make flush operation asynchronous to accelerate alltoall. Net: define maximum number of requests. Fix hang when using LL128 protocol after 2^31 steps. Fix #379 : topology injection failing when using less GPUs than described in the XML. Fix #394 : protocol mismatch causing hangs or crashes when using one GPU per node.
-rw-r--r--makefiles/common.mk5
-rw-r--r--makefiles/version.mk4
-rw-r--r--pkg/debian/control.in4
-rw-r--r--pkg/redhat/nccl.spec.in6
-rw-r--r--src/bootstrap.cc450
-rw-r--r--src/channel.cc4
-rw-r--r--src/collectives/all_gather.cc5
-rw-r--r--src/collectives/all_reduce.cc5
-rw-r--r--src/collectives/broadcast.cc5
-rw-r--r--src/collectives/device/all_gather.cu4
-rw-r--r--src/collectives/device/all_gather.h376
-rw-r--r--src/collectives/device/all_reduce.cu4
-rw-r--r--src/collectives/device/all_reduce.h953
-rw-r--r--src/collectives/device/broadcast.cu4
-rw-r--r--src/collectives/device/broadcast.h285
-rw-r--r--src/collectives/device/common.h203
-rw-r--r--src/collectives/device/common_kernel.h174
-rw-r--r--src/collectives/device/functions.cu102
-rw-r--r--src/collectives/device/primitives.h332
-rw-r--r--src/collectives/device/prims_ll.h2
-rw-r--r--src/collectives/device/prims_ll128.h10
-rw-r--r--src/collectives/device/reduce.cu4
-rw-r--r--src/collectives/device/reduce.h271
-rw-r--r--src/collectives/device/reduce_scatter.cu4
-rw-r--r--src/collectives/device/reduce_scatter.h359
-rw-r--r--src/collectives/device/sendrecv.cu5
-rw-r--r--src/collectives/device/sendrecv.h139
-rw-r--r--src/collectives/reduce.cc5
-rw-r--r--src/collectives/reduce_scatter.cc5
-rw-r--r--src/collectives/sendrecv.cc6
-rw-r--r--src/debug.cc6
-rw-r--r--src/enqueue.cc393
-rw-r--r--src/graph/connect.cc175
-rw-r--r--src/graph/paths.cc50
-rw-r--r--src/graph/rings.cc2
-rw-r--r--src/graph/search.cc60
-rw-r--r--src/graph/topo.cc32
-rw-r--r--src/graph/topo.h39
-rw-r--r--src/graph/trees.cc53
-rw-r--r--src/graph/tuning.cc99
-rw-r--r--src/graph/xml.cc44
-rw-r--r--src/graph/xml.h45
-rw-r--r--src/group.cc190
-rw-r--r--src/include/bootstrap.h4
-rw-r--r--src/include/coll_net.h2
-rw-r--r--src/include/collectives.h81
-rw-r--r--src/include/comm.h19
-rw-r--r--src/include/core.h1
-rw-r--r--src/include/cpuset.h4
-rw-r--r--src/include/devcomm.h73
-rw-r--r--src/include/enqueue.h2
-rw-r--r--src/include/graph.h18
-rw-r--r--src/include/nccl_net.h27
-rw-r--r--src/include/net.h4
-rw-r--r--src/include/nvmlwrap.h16
-rw-r--r--src/include/nvtx.h14
-rw-r--r--src/include/nvtx3.hpp2268
-rw-r--r--src/include/nvtx3/nvToolsExt.h1470
-rw-r--r--src/include/nvtx3/nvToolsExtCuda.h141
-rw-r--r--src/include/nvtx3/nvToolsExtCudaRt.h117
-rw-r--r--src/include/nvtx3/nvToolsExtOpenCL.h191
-rw-r--r--src/include/nvtx3/nvToolsExtSync.h382
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImpl.h438
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImplCore.h307
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h81
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h102
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h161
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h83
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxInit.h312
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxInitDecls.h81
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxInitDefs.h573
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h83
-rw-r--r--src/include/nvtx3/nvtxDetail/nvtxTypes.h304
-rw-r--r--src/include/p2p.h39
-rw-r--r--src/include/param.h5
-rw-r--r--src/include/proxy.h34
-rw-r--r--src/include/socket.h21
-rw-r--r--src/include/transport.h7
-rw-r--r--src/include/trees.h6
-rw-r--r--src/init.cc190
-rw-r--r--src/misc/argcheck.cc10
-rw-r--r--src/misc/nvmlwrap.cc56
-rw-r--r--src/proxy.cc497
-rw-r--r--src/transport.cc116
-rw-r--r--src/transport/coll_net.cc303
-rw-r--r--src/transport/net.cc421
-rw-r--r--src/transport/net_ib.cc157
-rw-r--r--src/transport/net_socket.cc29
-rw-r--r--src/transport/p2p.cc206
-rw-r--r--src/transport/shm.cc8
90 files changed, 11175 insertions, 3212 deletions
diff --git a/makefiles/common.mk b/makefiles/common.mk
index 8e91a45..d4c353b 100644
--- a/makefiles/common.mk
+++ b/makefiles/common.mk
@@ -11,6 +11,7 @@ KEEP ?= 0
DEBUG ?= 0
TRACE ?= 0
PROFAPI ?= 0
+NVTX ?= 1
NVCC = $(CUDA_HOME)/bin/nvcc
@@ -87,6 +88,10 @@ ifneq ($(TRACE), 0)
CXXFLAGS += -DENABLE_TRACE
endif
+ifeq ($(NVTX), 0)
+CXXFLAGS += -DNVTX_DISABLE
+endif
+
ifneq ($(KEEP), 0)
NVCUFLAGS += -keep
endif
diff --git a/makefiles/version.mk b/makefiles/version.mk
index 8142428..f2539c5 100644
--- a/makefiles/version.mk
+++ b/makefiles/version.mk
@@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
-NCCL_MINOR := 7
-NCCL_PATCH := 8
+NCCL_MINOR := 8
+NCCL_PATCH := 3
NCCL_SUFFIX :=
PKG_REVISION := 1
diff --git a/pkg/debian/control.in b/pkg/debian/control.in
index c8f5266..22c60f8 100644
--- a/pkg/debian/control.in
+++ b/pkg/debian/control.in
@@ -9,7 +9,7 @@ Package: libnccl${nccl:Major}
Section: libs
Architecture: ${pkg:Arch}
Depends: ${misc:Depends}, ${shlibs:Depends}
-Description: NVIDIA Collectives Communication Library (NCCL) Runtime
+Description: NVIDIA Collective Communication Library (NCCL) Runtime
NCCL (pronounced "Nickel") is a stand-alone library of standard collective
communication routines for GPUs, implementing all-reduce, all-gather, reduce,
broadcast, and reduce-scatter.
@@ -21,7 +21,7 @@ Package: libnccl-dev
Section: libdevel
Architecture: ${pkg:Arch}
Depends: ${misc:Depends}, ${shlibs:Depends}, libnccl${nccl:Major} (= ${binary:Version})
-Description: NVIDIA Collectives Communication Library (NCCL) Development Files
+Description: NVIDIA Collective Communication Library (NCCL) Development Files
NCCL (pronounced "Nickel") is a stand-alone library of standard collective
communication routines for GPUs, implementing all-reduce, all-gather, reduce,
broadcast, and reduce-scatter.
diff --git a/pkg/redhat/nccl.spec.in b/pkg/redhat/nccl.spec.in
index 5fad346..f1cce5c 100644
--- a/pkg/redhat/nccl.spec.in
+++ b/pkg/redhat/nccl.spec.in
@@ -1,7 +1,7 @@
Name: libnccl
Version: ${nccl:Major}.${nccl:Minor}.${nccl:Patch}${nccl:Suffix}
Release: ${pkg:Revision}+cuda${cuda:Major}.${cuda:Minor}
-Summary: NVIDIA Collectives Communication Library (NCCL) Runtime
+Summary: NVIDIA Collective Communication Library (NCCL) Runtime
Group: Development/Libraries
License: BSD
@@ -18,13 +18,13 @@ NVLink, NVswitch, as well as networking using InfiniBand Verbs or TCP/IP
sockets.
%package devel
-Summary: NVIDIA Collectives Communication Library (NCCL) Runtime
+Summary: NVIDIA Collective Communication Library (NCCL) Runtime
Group: Development/Libraries
%description devel
NCCL development files
%package static
-Summary: NVIDIA Collectives Communication Library (NCCL) Runtime
+Summary: NVIDIA Collective Communication Library (NCCL) Runtime
Group: Development/Libraries
%description static
NCCL static library
diff --git a/src/bootstrap.cc b/src/bootstrap.cc
index e90dd66..bd6ec99 100644
--- a/src/bootstrap.cc
+++ b/src/bootstrap.cc
@@ -13,144 +13,77 @@
#include <unistd.h>
#include <sys/types.h>
-struct bootstrapNetComm {
- int fd;
-};
-
/* Init functions */
-static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
-static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
-static int bootstrapNetIfs = -1;
+static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1];
+static union socketAddress bootstrapNetIfAddr;
+static int bootstrapNetInitDone = 0;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
ncclResult_t bootstrapNetInit() {
- if (bootstrapNetIfs == -1) {
+ if (bootstrapNetInitDone == 0) {
pthread_mutex_lock(&bootstrapNetLock);
- if (bootstrapNetIfs == -1) {
- bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
- if (bootstrapNetIfs <= 0) {
- WARN("Bootstrap : no socket interface found");
- return ncclInternalError;
+ if (bootstrapNetInitDone == 0) {
+ char* env = getenv("NCCL_COMM_ID");
+ if (env) {
+ union socketAddress remoteAddr;
+ if (GetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) {
+ WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
+ return ncclInvalidArgument;
+ }
+ if (findInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
+ WARN("NET/Socket : No usable listening interface found");
+ return ncclSystemError;
+ }
} else {
- char line[1024];
- char addrline[1024];
- line[0] = '\0';
- for (int i=0; i<bootstrapNetIfs; i++) {
- snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
- socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
+ int nIfs = findInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
+ if (nIfs <= 0) {
+ WARN("Bootstrap : no socket interface found");
+ return ncclInternalError;
}
- line[1023] = '\0';
- INFO(NCCL_INIT, "Bootstrap : Using%s", line);
}
+ char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2];
+ sprintf(line, " %s:", bootstrapNetIfName);
+ socketToString(&bootstrapNetIfAddr.sa, line+strlen(line));
+ INFO(NCCL_INIT, "Bootstrap : Using%s", line);
+ bootstrapNetInitDone = 1;
}
pthread_mutex_unlock(&bootstrapNetLock);
}
return ncclSuccess;
}
-static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
- NCCLCHECK(ncclCalloc(comm, 1));
- (*comm)->fd = -1;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
- if (dev >= bootstrapNetIfs) return ncclInternalError;
- memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr));
- return ncclSuccess;
-}
-
/* Socket Interface Selection type */
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
-static ncclResult_t bootstrapNetListen(int dev, ncclNetHandle_t* netHandle, void** listenComm) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- static_assert(sizeof(union socketAddress) < NCCL_NET_HANDLE_MAXSIZE, "union socketAddress size is too large");
- // if dev >= 0, listen based on dev
- if (dev >= 0) {
- NCCLCHECK(bootstrapNetGetSocketAddr(dev, connectAddr));
- } else if (dev == findSubnetIf) {
- // handle stores a remote address
- // need to find a local addr that is in the same network as the remote addr
- union socketAddress localAddr;
- char ifName[MAX_IF_NAME_SIZE];
- if (findInterfaceMatchSubnet(ifName, &localAddr, connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
- WARN("NET/Socket : No usable listening interface found");
- return ncclSystemError;
- }
- // pass the local address back
- memcpy(connectAddr, &localAddr, sizeof(localAddr));
- } // Otherwise, handle stores a local address
- struct bootstrapNetComm* comm;
- NCCLCHECK(bootstrapNetNewComm(&comm));
- NCCLCHECK(createListenSocket(&comm->fd, connectAddr));
- *listenComm = comm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetConnect(int dev, ncclNetHandle_t* netHandle, void** sendComm) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- struct bootstrapNetComm* comm;
- NCCLCHECK(bootstrapNetNewComm(&comm));
- NCCLCHECK(connectAddress(&comm->fd, connectAddr));
- *sendComm = comm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
- struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
- struct bootstrapNetComm* rComm;
- NCCLCHECK(bootstrapNetNewComm(&rComm));
+static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) {
struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
- SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
- *recvComm = rComm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetClose(void* opaqueComm) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm;
- if (comm) {
- close(comm->fd);
- free(comm);
- }
+ SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd);
return ncclSuccess;
}
-static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; }
-
// Additional sync functions
-static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
- NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
- NCCLCHECK(socketSend(comm->fd, data, size));
+static ncclResult_t bootstrapNetSend(int fd, void* data, int size) {
+ NCCLCHECK(socketSend(fd, &size, sizeof(int)));
+ NCCLCHECK(socketSend(fd, data, size));
return ncclSuccess;
}
-static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
+static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) {
int recvSize;
- NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
+ NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int)));
if (recvSize > size) {
WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
return ncclInternalError;
}
- NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
- return ncclSuccess;
-}
-
-ncclResult_t bootstrapNetCreateHandle(ncclNetHandle_t* netHandle, const char* str) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- NCCLCHECK(GetSocketAddrFromString(connectAddr, str));
+ NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size)));
return ncclSuccess;
}
struct extInfo {
int rank;
int nranks;
- ncclNetHandle_t extHandleListenRoot;
- ncclNetHandle_t extHandleListen;
+ union socketAddress extAddressListenRoot;
+ union socketAddress extAddressListen;
};
#include <sys/resource.h>
@@ -163,27 +96,29 @@ static ncclResult_t setFilesLimit() {
return ncclSuccess;
}
-static void *bootstrapRoot(void* listenComm) {
+static void *bootstrapRoot(void* args) {
+ int listenFd = (uint64_t)args;
+ ncclResult_t res = ncclSuccess;
+ int nranks = 0, c = 0;
struct extInfo info;
- ncclNetHandle_t *rankHandles = NULL;
- ncclNetHandle_t *rankHandlesRoot = NULL; // for initial rank <-> root information exchange
- ncclNetHandle_t zero = { 0 }; // for sanity checking
- void* tmpComm;
- ncclResult_t res;
+ union socketAddress *rankAddresses = NULL;
+ union socketAddress *rankAddressesRoot = NULL; // for initial rank <-> root information exchange
+ union socketAddress *zero = NULL;
+ NCCLCHECKGOTO(ncclCalloc(&zero, 1), res, out);
setFilesLimit();
TRACE(NCCL_INIT, "BEGIN");
/* Receive addresses from all ranks */
- int nranks = 0, c = 0;
do {
- NCCLCHECKGOTO(bootstrapNetAccept(listenComm, &tmpComm), res, out);
- NCCLCHECKGOTO(bootstrapNetRecv(tmpComm, &info, sizeof(info)), res, out);
- NCCLCHECKGOTO(bootstrapNetCloseRecv(tmpComm), res, out);
+ int tmpFd;
+ NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out);
+ NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out);
+ close(tmpFd);
if (c == 0) {
nranks = info.nranks;
- NCCLCHECKGOTO(ncclCalloc(&rankHandles, nranks), res, out);
- NCCLCHECKGOTO(ncclCalloc(&rankHandlesRoot, nranks), res, out);
+ NCCLCHECKGOTO(ncclCalloc(&rankAddresses, nranks), res, out);
+ NCCLCHECKGOTO(ncclCalloc(&rankAddressesRoot, nranks), res, out);
}
if (nranks != info.nranks) {
@@ -191,14 +126,14 @@ static void *bootstrapRoot(void* listenComm) {
goto out;
}
- if (memcmp(&zero, &rankHandlesRoot[info.rank], sizeof(ncclNetHandle_t)) != 0) {
+ if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union socketAddress)) != 0) {
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
goto out;
}
// Save the connection handle for that rank
- memcpy(rankHandlesRoot+info.rank, info.extHandleListenRoot, sizeof(ncclNetHandle_t));
- memcpy(rankHandles+info.rank, info.extHandleListen, sizeof(ncclNetHandle_t));
+ memcpy(rankAddressesRoot+info.rank, &info.extAddressListenRoot, sizeof(union socketAddress));
+ memcpy(rankAddresses+info.rank, &info.extAddressListen, sizeof(union socketAddress));
++c;
TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
@@ -208,44 +143,46 @@ static void *bootstrapRoot(void* listenComm) {
// Send the connect handle for the next rank in the AllGather ring
for (int r=0; r<nranks; ++r) {
int next = (r+1) % nranks;
- void *tmpSendComm;
- NCCLCHECKGOTO(bootstrapNetConnect(0, rankHandlesRoot+r, &tmpSendComm), res, out);
- NCCLCHECKGOTO(bootstrapNetSend(tmpSendComm, rankHandles+next, sizeof(ncclNetHandle_t)), res, out);
- NCCLCHECKGOTO(bootstrapNetCloseSend(tmpSendComm), res, out);
+ int tmpSendFd;
+ NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out);
+ NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out);
+ close(tmpSendFd);
}
TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks);
out:
- bootstrapNetCloseListen(listenComm);
- if (rankHandles) free(rankHandles);
- if (rankHandlesRoot) free(rankHandlesRoot);
+ close(listenFd);
+ if (rankAddresses) free(rankAddresses);
+ if (rankAddressesRoot) free(rankAddressesRoot);
+ if (zero) free(zero);
TRACE(NCCL_INIT, "DONE");
return NULL;
}
ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) {
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
- void* listenComm;
- NCCLCHECK(bootstrapNetListen(idFromEnv ? dontCareIf : 0, netHandle, &listenComm));
+ union socketAddress* connectAddr = (union socketAddress*) id;
+ int listenFd;
+ NCCLCHECK(createListenSocket(&listenFd, connectAddr));
pthread_t thread;
- pthread_create(&thread, NULL, bootstrapRoot, listenComm);
+ pthread_create(&thread, NULL, bootstrapRoot, (void*)(uint64_t)listenFd);
return ncclSuccess;
}
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) {
- static_assert(sizeof(ncclNetHandle_t) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId");
+ static_assert(sizeof(union socketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId");
memset(id, 0, sizeof(ncclUniqueId));
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
+ union socketAddress* connectAddr = (union socketAddress*) id;
char* env = getenv("NCCL_COMM_ID");
if (env) {
INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env);
- if (bootstrapNetCreateHandle(netHandle, env) != 0) {
+ if (GetSocketAddrFromString(connectAddr, env) != ncclSuccess) {
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
return ncclInvalidArgument;
}
} else {
+ memcpy(id, &bootstrapNetIfAddr, sizeof(union socketAddress));
NCCLCHECK(bootstrapCreateRoot(id, false));
}
@@ -254,24 +191,135 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) {
struct unexConn {
int peer;
- void* comm;
+ int fd;
struct unexConn* next;
};
+// Remote allocator state
+struct remAllocState {
+ int cudaDev;
+ int listenFd;
+ int stop;
+};
+
struct extState {
- void* extBstrapListenComm;
- void* extBstrapRingRecvComm;
- void* extBstrapRingSendComm;
- ncclNetHandle_t* peerBstrapHandles;
+ int extListenFd;
+ int extRingRecvFd;
+ int extRingSendFd;
+ union socketAddress* peerCommAddresses;
+ union socketAddress* peerAllocAddresses;
struct unexConn* unexpectedConnections;
+ int cudaDev;
int rank;
int nranks;
- int dev;
+
+ // Intermediate memory allocation service
+ struct remAllocState* allocState;
+ pthread_t allocThread;
};
+#define MAX_SEGMENTS 128
+
+static ncclResult_t remoteAlloc(void** ptr, int fd) {
+ size_t size;
+ NCCLCHECK(socketRecv(fd, &size, sizeof(size_t)));
+ cudaIpcMemHandle_t devIpc;
+ NCCLCHECK(ncclCudaCalloc((char**)ptr, size));
+ cudaError_t res = cudaIpcGetMemHandle(&devIpc, *ptr);
+ if (res != cudaSuccess) {
+ WARN("[Rem Allocator] cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res));
+ cudaFree(*ptr);
+ CUDACHECK(res);
+ }
+ // The CUDA IPC
+ NCCLCHECK(socketSend(fd, &devIpc, sizeof(cudaIpcMemHandle_t)));
+ // And the direct pointer
+ NCCLCHECK(socketSend(fd, ptr, sizeof(void*)));
+ return ncclSuccess;
+}
+
+#include <poll.h>
+
+// Service thread to allocate memory for other GPUs, used as intermediate step.
+void* ncclRemoteMemAllocationService(void* args) {
+ struct remAllocState* state = (struct remAllocState *) args;
+ if (cudaSetDevice(state->cudaDev) != cudaSuccess) {
+ WARN("[Rem Allocator] Failed to set CUDA device %d\n", state->cudaDev);
+ }
+
+ // Prepare poll descriptor
+ void* segments[MAX_SEGMENTS];
+ struct pollfd pollfds[MAX_SEGMENTS+1];
+ for (int s=0; s<MAX_SEGMENTS; s++) segments[s] = NULL;
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ pollfds[s].fd = -1;
+ pollfds[s].events = POLLHUP;
+ }
+ pollfds[MAX_SEGMENTS].fd = state->listenFd;
+ pollfds[MAX_SEGMENTS].events = POLLIN;
+
+ int nbuffers = 0;
+ while (state->stop == 0 || (state->stop == 1 && nbuffers > 0)) {
+ if (int error = poll(pollfds, MAX_SEGMENTS+1, 100/*ms*/) < 0) {
+ WARN("[Rem Allocator] Poll failed with error %d", error);
+ return NULL;
+ }
+ if (pollfds[MAX_SEGMENTS].revents) {
+ int s = 0;
+ while (segments[s] != NULL && s < MAX_SEGMENTS) s++;
+ if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) {
+ pollfds[s].fd = -1;
+ } else {
+ if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) {
+ WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd);
+ close(pollfds[s].fd);
+ pollfds[s].fd = -1;
+ } else {
+ nbuffers++;
+ }
+ }
+ }
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ if (pollfds[s].revents & POLLHUP) {
+ if (cudaFree(segments[s]) != cudaSuccess) {
+ WARN("[Rem Allocator] cudaFree %p failed", segments[s]);
+ }
+ segments[s] = NULL;
+ close(pollfds[s].fd);
+ pollfds[s].fd = -1;
+ nbuffers--;
+ }
+ }
+ }
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ if (segments[s]) cudaFree(segments[s]);
+ close(pollfds[s].fd);
+ }
+ close(state->listenFd);
+ free(state);
+ return NULL;
+}
+
+ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr) {
+ struct extState* state = (struct extState*)commState;
+ int fd;
+ ncclResult_t res;
+ *id = -1;
+ NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank));
+ NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end);
+ NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(cudaIpcMemHandle_t)), res, end);
+ NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end);
+ *id = fd;
+end:
+ return res;
+}
+
+ncclResult_t bootstrapRemFree(int id, int rank, void* commState) {
+ SYSCHECK(close(id), "close");
+ return ncclSuccess;
+}
+
ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
- bool idFromEnv = getenv("NCCL_COMM_ID") != NULL;
struct extState* state;
NCCLCHECK(ncclCalloc(&state, 1));
state->rank = rank;
@@ -283,19 +331,15 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
struct extInfo info = { 0 };
info.rank = rank;
info.nranks = nranks;
- void *tmpSendComm, *tmpRecvComm;
- // Pass the remote address to listen via info
- if (idFromEnv) {
- memcpy(&info.extHandleListen, netHandle, sizeof(ncclNetHandle_t));
- memcpy(&info.extHandleListenRoot, netHandle, sizeof(ncclNetHandle_t));
- }
- // listen will return the local address via info (specify interface type 'findSubnetIf')
- state->dev = idFromEnv ? findSubnetIf : 0;
- void* extBstrapListenCommRoot;
- NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListen, &state->extBstrapListenComm));
- NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListenRoot, &extBstrapListenCommRoot));
+ int tmpSendFd, tmpRecvFd;
- // stagger connection times to avoid an overload of the root at very high rank counts
+ int extListenFdRoot;
+ memcpy(&info.extAddressListen, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ memcpy(&info.extAddressListenRoot, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ NCCLCHECK(createListenSocket(&state->extListenFd, &info.extAddressListen));
+ NCCLCHECK(createListenSocket(&extListenFdRoot, &info.extAddressListenRoot));
+
+ // stagger connection times to avoid an overload of the root
if (nranks > 128) {
long msec = rank;
struct timespec tv;
@@ -306,25 +350,35 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
}
// send info on my listening socket to root
- NCCLCHECK(bootstrapNetConnect(state->dev, netHandle, &tmpSendComm));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, &info, sizeof(info)));
- NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));
+ union socketAddress* rootAddr = (union socketAddress*)id;
+ NCCLCHECK(connectAddress(&tmpSendFd, rootAddr));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info)));
+ close(tmpSendFd);
// get info on my "next" rank in the bootstrap ring from root
- ncclNetHandle_t extHandleNext;
- NCCLCHECK(bootstrapNetAccept(extBstrapListenCommRoot, &tmpRecvComm));
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext)));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
- NCCLCHECK(bootstrapNetCloseListen(extBstrapListenCommRoot));
+ union socketAddress extAddressNext;
+ NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext)));
+ close(tmpRecvFd);
+ close(extListenFdRoot);
- NCCLCHECK(bootstrapNetConnect(state->dev, &extHandleNext, &state->extBstrapRingSendComm));
+ NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext));
// Accept the connect request from the previous rank in the AllGather ring
- NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &state->extBstrapRingRecvComm));
+ NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd));
// AllGather all listen handlers
- NCCLCHECK(ncclCalloc(&state->peerBstrapHandles, nranks));
- memcpy(state->peerBstrapHandles+rank, info.extHandleListen, sizeof(ncclNetHandle_t));
- NCCLCHECK(bootstrapAllGather(state, state->peerBstrapHandles, sizeof(ncclNetHandle_t)));
+ NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks));
+ memcpy(state->peerCommAddresses+rank, &info.extAddressListen, sizeof(union socketAddress));
+ NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union socketAddress)));
+
+ // Create the memory allocation service
+ NCCLCHECK(ncclCalloc(&state->peerAllocAddresses, nranks));
+ memcpy(state->peerAllocAddresses+rank, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ NCCLCHECK(ncclCalloc(&state->allocState, 1));
+ CUDACHECK(cudaGetDevice(&state->allocState->cudaDev));
+ NCCLCHECK(createListenSocket(&state->allocState->listenFd, state->peerAllocAddresses+rank));
+ pthread_create(&state->allocThread, NULL, ncclRemoteMemAllocationService, state->allocState);
+ NCCLCHECK(bootstrapAllGather(state, state->peerAllocAddresses, sizeof(union socketAddress)));
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
@@ -348,9 +402,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
size_t sslice = (rank - i + nranks) % nranks;
// Send slice to the right
- NCCLCHECK(bootstrapNetSend(state->extBstrapRingSendComm, data+sslice*size, size));
+ NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size));
// Recv slice from the left
- NCCLCHECK(bootstrapNetRecv(state->extBstrapRingRecvComm, data+rslice*size, size));
+ NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size));
}
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
@@ -359,20 +413,20 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
ncclResult_t bootstrapSend(void* commState, int peer, void* data, int size) {
struct extState* state = (struct extState*)commState;
- void* tmpSendComm;
- NCCLCHECK(bootstrapNetConnect(state->dev, state->peerBstrapHandles+peer, &tmpSendComm));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, &state->rank, sizeof(int)));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, data, size));
- NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));
+ int tmpSendFd;
+ NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int)));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size));
+ close(tmpSendFd);
return ncclSuccess;
}
-ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) {
+ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int fd) {
// New unex
struct unexConn* unex;
NCCLCHECK(ncclCalloc(&unex, 1));
unex->peer = peer;
- unex->comm = comm;
+ unex->fd = fd;
// Enqueue
struct unexConn* list = state->unexpectedConnections;
@@ -385,7 +439,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) {
return ncclSuccess;
}
-void* unexpectedDequeue(struct extState* state, int peer) {
+int unexpectedDequeue(struct extState* state, int peer) {
struct unexConn* elem = state->unexpectedConnections;
struct unexConn* prev = NULL;
while (elem) {
@@ -395,41 +449,41 @@ void* unexpectedDequeue(struct extState* state, int peer) {
} else {
prev->next = elem->next;
}
- void* comm = elem->comm;
+ int fd = elem->fd;
free(elem);
- return comm;
+ return fd;
}
prev = elem;
elem = elem->next;
}
- return NULL;
+ return -1;
}
// We can't know who we'll receive from, so we need to receive everything at once
ncclResult_t bootstrapRecv(void* commState, int peer, void* data, int size) {
struct extState* state = (struct extState*)commState;
- void* tmpRecvComm;
+ int tmpRecvFd;
// Search unexpected connections first
- if ((tmpRecvComm = unexpectedDequeue(state, peer)) != NULL) {
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
+ if ((tmpRecvFd = unexpectedDequeue(state, peer)) != -1) {
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
+ close(tmpRecvFd);
return ncclSuccess;
}
// Then look for new connections
while (1) {
- NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &tmpRecvComm));
+ NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd));
int newPeer;
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &newPeer, sizeof(int)));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int)));
if (newPeer == peer) {
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
+ close(tmpRecvFd);
return ncclSuccess;
}
// Unexpected connection. Save for later.
- NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvComm));
+ NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvFd));
}
}
@@ -439,11 +493,17 @@ ncclResult_t bootstrapClose(void* commState) {
WARN("Unexpected connections are not empty.\n");
return ncclInternalError;
}
- NCCLCHECK(bootstrapNetCloseListen(state->extBstrapListenComm));
- NCCLCHECK(bootstrapNetCloseSend(state->extBstrapRingSendComm));
- NCCLCHECK(bootstrapNetCloseRecv(state->extBstrapRingRecvComm));
+ close(state->extListenFd);
+ close(state->extRingSendFd);
+ close(state->extRingRecvFd);
+
+ state->allocState->stop = 1;
+
+ // Join the allocThread so we catch resource leaks as being hung here
+ // pthread_join(state->allocThread, nullptr);
- free(state->peerBstrapHandles);
+ free(state->peerCommAddresses);
+ free(state->peerAllocAddresses);
free(state);
return ncclSuccess;
@@ -451,10 +511,12 @@ ncclResult_t bootstrapClose(void* commState) {
ncclResult_t bootstrapAbort(void* commState) {
struct extState* state = (struct extState*)commState;
- bootstrapNetCloseListen(state->extBstrapListenComm);
- bootstrapNetCloseSend(state->extBstrapRingSendComm);
- bootstrapNetCloseRecv(state->extBstrapRingRecvComm);
- free(state->peerBstrapHandles);
+ close(state->extListenFd);
+ close(state->extRingSendFd);
+ close(state->extRingRecvFd);
+ state->allocState->stop = 2;
+ free(state->peerCommAddresses);
+ free(state->peerAllocAddresses);
free(state);
return ncclSuccess;
}
diff --git a/src/channel.cc b/src/channel.cc
index d22ea63..b2cc5b7 100644
--- a/src/channel.cc
+++ b/src/channel.cc
@@ -25,14 +25,14 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelid) {
}
// Per-channel operation list.
- NCCLCHECK(ncclCudaHostCalloc(&channel->collectives, NCCL_MAX_OPS));
+ NCCLCHECK(ncclCudaHostCalloc(&channel->workFifo, NCCL_MAX_OPS));
return ncclSuccess;
}
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks) {
if (channel->id == -1) return ncclSuccess;
// Operation list
- NCCLCHECK(ncclCudaHostFree(channel->collectives));
+ NCCLCHECK(ncclCudaHostFree(channel->workFifo));
// Free Ring index to rank tables
free(channel->ring.userRanks);
diff --git a/src/collectives/all_gather.cc b/src/collectives/all_gather.cc
index 348c176..266fd5a 100644
--- a/src/collectives/all_gather.cc
+++ b/src/collectives/all_gather.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,7 +11,8 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollAllGather, "AllGather",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
return ncclEnqueueCheck(&info);
diff --git a/src/collectives/all_reduce.cc b/src/collectives/all_reduce.cc
index 7796d5b..b67f3be 100644
--- a/src/collectives/all_reduce.cc
+++ b/src/collectives/all_reduce.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -10,7 +10,8 @@ NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollAllReduce, "AllReduce",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
return ncclEnqueueCheck(&info);
diff --git a/src/collectives/broadcast.cc b/src/collectives/broadcast.cc
index 042301b..db0fb49 100644
--- a/src/collectives/broadcast.cc
+++ b/src/collectives/broadcast.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,7 +11,8 @@ NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollBroadcast, "Broadcast",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
return ncclEnqueueCheck(&info);
diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu
index 109c341..4022e2e 100644
--- a/src/collectives/device/all_gather.cu
+++ b/src/collectives/device/all_gather.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -8,4 +8,4 @@
#include "common.h"
#include "collectives.h"
-IMPL_COLL_C(ncclAllGather, ncclCollAllGather);
+IMPL_COLL_C(AllGather);
diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h
index f7556b0..e057dc8 100644
--- a/src/collectives/device/all_gather.h
+++ b/src/collectives/device/all_gather.h
@@ -8,197 +8,201 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllGatherRingKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
- const ssize_t size = args->coll.count;
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*realChunkSize;
-
- /////////////// begin AllGather steps ///////////////
- ssize_t offset;
- int nelem = min(realChunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- if (thisInput + chunkOffset == thisOutput + offset) { // In place
- prims.directSend(thisInput+chunkOffset, offset, nelem);
- } else {
- prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * ALLGATHER_CHUNKSTEPS;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*realChunkSize;
+
+ /////////////// begin AllGather steps ///////////////
+ ssize_t offset;
+ int nelem = min(realChunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ if (thisInput + chunkOffset == thisOutput + offset) { // In place
+ prims.directSend(thisInput+chunkOffset, offset, nelem);
+ } else {
+ prims.directCopySend(thisInput+chunkOffset, thisOutput+offset, offset, nelem);
+ }
+
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
+ }
+
+ // Make final copy from buffer to dest.
+ rankDest = ring->devUserRanks[1];
+ offset = chunkOffset + rankDest * size;
+
+ // Final wait/copy.
+ prims.directRecv(thisOutput+offset, offset, nelem);
+ }
}
-
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- prims.directRecvCopySend(thisOutput+offset, offset, nelem);
- }
-
- // Make final copy from buffer to dest.
- rankDest = ring->devUserRanks[1];
- offset = chunkOffset + rankDest * size;
-
- // Final wait/copy.
- prims.directRecv(thisOutput+offset, offset, nelem);
- }
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllGatherTreeKernel(struct CollectiveArgs* args) { }
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllGatherCollNetKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherRingLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- if (size-gridOffset < loopSize) {
- chunkSize = args->coll.lastChunkSize;
- }
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
-
- /////////////// begin AllGather steps ///////////////
- ssize_t offset;
- int nelem = min(chunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- if (thisInput + chunkOffset == thisOutput + offset) { // In place
- LLprims.send(thisInput+chunkOffset, nelem);
- } else {
- LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ if (size-gridOffset < loopSize) {
+ chunkSize = args->coll.lastChunkSize;
+ }
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin AllGather steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ if (thisInput + chunkOffset == thisOutput + offset) { // In place
+ LLprims.send(thisInput+chunkOffset, nelem);
+ } else {
+ LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
+ }
+
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+
+ // step k-1: final store
+ rankDest = ring->devUserRanks[1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recv(thisOutput+offset, nelem);
+ }
}
-
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvCopySend(thisOutput+offset, nelem);
- }
-
- // step k-1: final store
- rankDest = ring->devUserRanks[1];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recv(thisOutput+offset, nelem);
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherTreeLLKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherCollNetLLKernel(struct CollectiveArgs* args) { }
+};
#include "prims_ll128.h"
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherRingLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
- // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
-
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
-
- /////////////// begin AllGather steps ///////////////
- ssize_t offset;
- int nelem = min(chunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- if (thisInput + chunkOffset == thisOutput + offset) { // In place
- LLprims.send(thisInput+chunkOffset, nelem);
- } else {
- LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
+
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin AllGather steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ if (thisInput + chunkOffset == thisOutput + offset) { // In place
+ LLprims.send(thisInput+chunkOffset, nelem);
+ } else {
+ LLprims.copySend(thisInput+chunkOffset, thisOutput+offset, nelem);
+ }
+
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+
+ // step k-1: final store
+ rankDest = ring->devUserRanks[1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recv(thisOutput+offset, nelem);
+ }
}
+};
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvCopySend(thisOutput+offset, nelem);
- }
-
- // step k-1: final store
- rankDest = ring->devUserRanks[1];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recv(thisOutput+offset, nelem);
- }
-}
+template<int PROTO, class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllGather, NCCL_ALGO_TREE, PROTO, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherTreeLL128Kernel(struct CollectiveArgs* args) { }
+template<int PROTO, class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllGather, NCCL_ALGO_COLLNET, PROTO, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllGatherCollNetLL128Kernel(struct CollectiveArgs* args) { }
diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu
index 85d007e..e7c3c28 100644
--- a/src/collectives/device/all_reduce.cu
+++ b/src/collectives/device/all_reduce.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -8,4 +8,4 @@
#include "common.h"
#include "collectives.h"
-IMPL_COLL_R(ncclAllReduce, ncclCollAllReduce);
+IMPL_COLL_R(AllReduce);
diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h
index d4eee03..fe2e6fc 100644
--- a/src/collectives/device/all_reduce.h
+++ b/src/collectives/device/all_reduce.h
@@ -8,524 +8,597 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
- const ssize_t size = args->coll.count;
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm);
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
- ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
-
- /////////////// begin AllReduce steps ///////////////
- ssize_t offset;
- int nelem;
- int chunk;
-
- // step 0: push data to next GPU
- chunk = ring->devUserRanks[nranks-1];
- offset = chunkOffset + chunk * realChunkSize;
- nelem = min(realChunkSize, size-offset);
-
- prims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * ALLREDUCE_CHUNKSTEPS;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
+ ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*nranks*realChunkSize;
+
+ /////////////// begin AllReduce steps ///////////////
+ ssize_t offset;
+ int nelem;
+ int chunk;
+
+ // step 0: push data to next GPU
+ chunk = ring->devUserRanks[nranks-1];
offset = chunkOffset + chunk * realChunkSize;
nelem = min(realChunkSize, size-offset);
- prims.recvReduceSend(thisInput+offset, nelem);
- }
+ prims.send(thisInput+offset, nelem);
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data and push to the next GPU
- chunk = ring->devUserRanks[0];
- offset = chunkOffset + chunk * realChunkSize;
- nelem = min(realChunkSize, size-offset);
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + chunk * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
- prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data and push to the next GPU
+ chunk = ring->devUserRanks[0];
offset = chunkOffset + chunk * realChunkSize;
nelem = min(realChunkSize, size-offset);
- prims.directRecvCopySend(thisOutput+offset, offset, nelem);
- }
+ prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
- // Make final copy from buffer to dest.
- chunk = ring->devUserRanks[1];
- offset = chunkOffset + chunk * realChunkSize;
- nelem = min(realChunkSize, size-offset);
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + chunk * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
- // Final wait/copy.
- prims.directRecv(thisOutput+offset, offset, nelem);
- }
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- int chunkSize = args->coll.lastChunkSize;
- const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- if (loopSize > size) {
- chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
+ }
+
+ // Make final copy from buffer to dest.
+ chunk = ring->devUserRanks[1];
+ offset = chunkOffset + chunk * realChunkSize;
+ nelem = min(realChunkSize, size-offset);
+
+ // Final wait/copy.
+ prims.directRecv(thisOutput+offset, offset, nelem);
+ }
}
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-2*WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->tree;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ int chunkSize = args->coll.lastChunkSize;
+ const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ }
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- do {
- struct ncclTree* tree = &channel->treeUp;
- // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Up
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- prims.send(thisInput+offset, nelem);
- } else {
- prims.recvReduceSend(thisInput+offset, nelem);
+#if 1
+ if (tid < nthreads+WARP_SIZE) {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
+ prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
}
}
- } while(0);
- do {
- struct ncclTree* tree = &channel->treeDn;
- // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Down
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- prims.directSend(thisOutput+offset, offset, nelem);
- } else if (tree->down[0] == -1) {
- prims.directRecv(thisOutput+offset, offset, nelem);
+ if (tid < nthreads+WARP_SIZE) {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
+ prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.directSend(thisOutput+offset, offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.directRecv(thisOutput+offset, offset, nelem);
+ } else {
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
+ }
+ }
+ }
+#else
+ int nthreadsSplit = nthreads/2;
+ if (nthreadsSplit == 256) nthreadsSplit += 64;
+ if (tree->up == -1) {
+ if (tid < nthreads+WARP_SIZE) {
+ // ReduceAndBroadcast : max number of recv is 3, max number of send is 3
+ ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY, 1, FUNC>
+ prims(tid, nthreads, tree->down, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ prims.directRecvReduceCopySend(thisInput+offset, thisOutput+offset, offset, nelem);
+ }
+ }
+ } else {
+ if (tid < nthreadsSplit + WARP_SIZE) {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_DEV_ARITY, 1, 0, FUNC>
+ prims(tid, nthreadsSplit, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->down[0] == -1) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
} else {
- prims.directRecvCopySend(thisOutput+offset, offset, nelem);
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_DEV_ARITY, 1, FUNC>
+ prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &tree->up, tree->down, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 2);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->down[0] == -1) {
+ prims.directRecv(thisOutput+offset, offset, nelem);
+ } else {
+ prims.directRecvCopySend(thisOutput+offset, offset, nelem);
+ }
+ }
}
}
- } while(0);
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclAllReduceCollNetKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- int chunkSize = args->coll.lastChunkSize;
- const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- if (loopSize > size) {
- chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+#endif
}
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->collTree;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ int chunkSize = args->coll.lastChunkSize;
+ const ssize_t minChunkSize = nthreads*8*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ }
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- if (blockIdx.x < nChannels) { // first half of the channels do reduce
- struct ncclTree* tree = &channel->collTreeUp;
- ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Up
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- prims.send(thisInput+offset, nelem);
- } else {
- prims.recvReduceSend(thisInput+offset, nelem);
+ if (blockIdx.x < nChannels) { // first half of the channels do reduce
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC>
+ prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
}
}
- }
- if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
- struct ncclTree* tree = &channel->collTreeDn;
- ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Down
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- prims.send(thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- prims.recv(thisOutput+offset, nelem);
- } else {
- prims.recvCopySend(thisOutput+offset, nelem);
+ if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 1, 0, FUNC>
+ prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ prims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ prims.recv(thisOutput+offset, nelem);
+ } else {
+ prims.recvCopySend(thisOutput+offset, nelem);
+ }
}
}
}
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceRingLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*nranks*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
-
- /////////////// begin AllReduce steps ///////////////
- ssize_t offset;
- int nelem;
- int chunk;
-
- // step 0: push data to next GPU
- chunk = ring->devUserRanks[nranks-1];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
-
- LLprims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads * (sizeof(uint64_t)) / sizeof(T);
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*nranks*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+
+ /////////////// begin AllReduce steps ///////////////
+ ssize_t offset;
+ int nelem;
+ int chunk;
+
+ // step 0: push data to next GPU
+ chunk = ring->devUserRanks[nranks-1];
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
- LLprims.recvReduceSend(thisInput+offset, nelem);
- }
+ LLprims.send(thisInput+offset, nelem);
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data and push to the next GPU
- chunk = ring->devUserRanks[0];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
- LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data and push to the next GPU
+ chunk = ring->devUserRanks[0];
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
- LLprims.recvCopySend(thisOutput+offset, nelem);
- }
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
- // Make final copy from buffer to dest.
- chunk = ring->devUserRanks[1];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
- // Here we need to copy from buffer to this output.
- LLprims.recv(thisOutput+offset, nelem);
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- if (loopSize > size) {
- chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
+
+ // Make final copy from buffer to dest.
+ chunk = ring->devUserRanks[1];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
+
+ // Here we need to copy from buffer to this output.
+ LLprims.recv(thisOutput+offset, nelem);
+ }
}
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->tree;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ }
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- do {
- struct ncclTree* tree = &channel->treeUp;
- // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Up
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- LLprims.send(thisInput+offset, nelem);
- } else {
- LLprims.recvReduceSend(thisInput+offset, nelem);
+ do {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclLLPrimitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
}
- }
- } while(0);
+ } while(0);
- do {
- struct ncclTree* tree = &channel->treeDn;
- // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Down
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- LLprims.send(thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- LLprims.recv(thisOutput+offset, nelem);
- } else {
- LLprims.recvCopySend(thisOutput+offset, nelem);
+ do {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
}
- }
- } while(0);
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceCollNetLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- if (loopSize > size) {
- chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ } while(0);
}
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->collTree;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const ssize_t minChunkSize = nthreads*sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ }
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
- if (blockIdx.x < nChannels) { // first half of the channels do reduce
- struct ncclTree* tree = &channel->collTreeUp;
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Up
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- LLprims.send(thisInput+offset, nelem);
- } else {
- LLprims.recvReduceSend(thisInput+offset, nelem);
+ if (blockIdx.x < nChannels) { // first half of the channels do reduce
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
}
}
- }
- if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
- struct ncclTree* tree = &channel->collTreeDn;
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Down
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (tree->up == -1) {
- LLprims.send(thisOutput+offset, nelem);
- } else if (tree->down[0] == -1) {
- LLprims.recv(thisOutput+offset, nelem);
- } else {
- LLprims.recvCopySend(thisOutput+offset, nelem);
+ if (blockIdx.x >= nChannels) { // second half of the channels do broadcast
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->up == -1) {
+ LLprims.send(thisOutput+offset, nelem);
+ } else if (tree->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
}
}
}
-}
+};
#include "prims_ll128.h"
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceRingLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
- // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*nranks*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
-
- /////////////// begin AllReduce steps ///////////////
- ssize_t offset;
- int nelem;
- int chunk;
-
- // step 0: push data to next GPU
- chunk = ring->devUserRanks[nranks-1];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
-
- LLprims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*nranks*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize, chunkSize);
+
+ /////////////// begin AllReduce steps ///////////////
+ ssize_t offset;
+ int nelem;
+ int chunk;
+
+ // step 0: push data to next GPU
+ chunk = ring->devUserRanks[nranks-1];
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
- LLprims.recvReduceSend(thisInput+offset, nelem);
- }
+ LLprims.send(thisInput+offset, nelem);
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data and push to the next GPU
- chunk = ring->devUserRanks[0];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
- LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
- // k-2 steps: copy to next GPU
- for (int j=1; j<nranks-1; ++j) {
- chunk = ring->devUserRanks[nranks-j];
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data and push to the next GPU
+ chunk = ring->devUserRanks[0];
offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
nelem = min(chunkSize, size-offset);
- LLprims.recvCopySend(thisOutput+offset, nelem);
- }
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
- // Make final copy from buffer to dest.
- chunk = ring->devUserRanks[1];
- offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
- nelem = min(chunkSize, size-offset);
+ // k-2 steps: copy to next GPU
+ for (int j=1; j<nranks-1; ++j) {
+ chunk = ring->devUserRanks[nranks-j];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
- // Here we need to copy from buffer to this output.
- LLprims.recv(thisOutput+offset, nelem);
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceTreeLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclTree* treeUp = &channel->treeUp;
- struct ncclTree* treeDn = &channel->treeDn;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = args->coll.lastChunkSize;
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
- const ssize_t loopSize = nChannels*chunkSize;
- int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
- const ssize_t size = args->coll.count;
-
- if (loopSize > size) {
- chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
- }
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
+ // Make final copy from buffer to dest.
+ chunk = ring->devUserRanks[1];
+ offset = gridOffset + (chunk*nChannels+bid) * chunkSize;
+ nelem = min(chunkSize, size-offset);
- if (treeUp->up == -1) {
- // ReduceAndBroadcast : max number of recv is 3, max number of send is 3
- ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, treeUp->down, treeDn->down, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
+ // Here we need to copy from buffer to this output.
+ LLprims.recv(thisOutput+offset, nelem);
}
- } else {
- if (tid < nthreadsSplit) {
- // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
- ncclLL128Primitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreadsSplit, treeUp->down, &treeUp->up, stepSize, channel, comm);
+ }
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_TREE, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclTree* tree = &channel->tree;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = args->coll.lastChunkSize;
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/8;
+ const ssize_t loopSize = nChannels*chunkSize;
+ int nthreadsSplit = NCCL_LL128_SPLIT(nthreads);
+ const ssize_t size = args->coll.count;
+
+ if (loopSize > size) {
+ chunkSize = DIVUP(size, nChannels*minChunkSize)*minChunkSize;
+ }
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ if (tree->up == -1) {
+ // ReduceAndBroadcast : max number of recv is 3, max number of send is 3
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, NCCL_MAX_DEV_ARITY> LLprims(tid, nthreads, tree->down, tree->down, stepSize, channel, comm);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Up
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
- if (treeUp->down[0] == -1) {
- LLprims.send(thisInput+offset, nelem);
- } else {
- LLprims.recvReduceSend(thisInput+offset, nelem);
- }
+ LLprims.recvReduceCopySend(thisInput+offset, thisOutput+offset, nelem);
}
} else {
- // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
- ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &treeDn->up, treeDn->down, stepSize, channel, comm);
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- // Down
- ssize_t offset = gridOffset + bid*chunkSize;
- int nelem = min(chunkSize, size-offset);
- if (treeDn->down[0] == -1) {
- LLprims.recv(thisOutput+offset, nelem);
- } else {
- LLprims.recvCopySend(thisOutput+offset, nelem);
+ if (tid < nthreadsSplit) {
+ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
+ ncclLL128Primitives<T, FUNC, NCCL_MAX_DEV_ARITY, 1> LLprims(tid, nthreadsSplit, tree->down, &tree->up, stepSize, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Up
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->down[0] == -1) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
+ } else {
+ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
+ ncclLL128Primitives<T, FUNC, 1, NCCL_MAX_DEV_ARITY> LLprims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, stepSize, channel, comm);
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ // Down
+ ssize_t offset = gridOffset + bid*chunkSize;
+ int nelem = min(chunkSize, size-offset);
+ if (tree->down[0] == -1) {
+ LLprims.recv(thisOutput+offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput+offset, nelem);
+ }
}
}
}
}
-}
+};
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclAllReduceCollNetLL128Kernel(struct CollectiveArgs* args) { }
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+__device__ void run(struct ncclWorkElem* args) { }
+};
diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu
index 8c8dbb6..7759585 100644
--- a/src/collectives/device/broadcast.cu
+++ b/src/collectives/device/broadcast.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -8,4 +8,4 @@
#include "common.h"
#include "collectives.h"
-IMPL_COLL_C(ncclBroadcast, ncclCollBroadcast);
+IMPL_COLL_C(Broadcast);
diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h
index a4c30d2..72216ac 100644
--- a/src/collectives/device/broadcast.h
+++ b/src/collectives/device/broadcast.h
@@ -8,152 +8,155 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclBroadcastRingKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS;
- const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = ring->devUserRanks[0];
- const int nextRank = ring->devUserRanks[1];
- const int root = args->coll.root;
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t offset = gridOffset + bid*realChunkSize;
- int nelem = min(realChunkSize, size-offset);
-
- if (rank == root) {
- if (thisInput == thisOutput) {
- prims.send(thisInput+offset, nelem);
- } else {
- prims.copySend(thisInput+offset, thisOutput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * BROADCAST_CHUNKSTEPS;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = ring->devUserRanks[0];
+ const int nextRank = ring->devUserRanks[1];
+ const int root = args->coll.root;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t offset = gridOffset + bid*realChunkSize;
+ int nelem = min(realChunkSize, size-offset);
+
+ if (rank == root) {
+ if (thisInput == thisOutput) {
+ prims.send(thisInput+offset, nelem);
+ } else {
+ prims.copySend(thisInput+offset, thisOutput+offset, nelem);
+ }
+ } else if (nextRank == root) {
+ prims.recv(thisOutput+offset, nelem);
+ } else {
+ prims.recvCopySend(thisOutput+offset, nelem);
+ }
}
- } else if (nextRank == root) {
- prims.recv(thisOutput+offset, nelem);
- } else {
- prims.recvCopySend(thisOutput+offset, nelem);
}
- }
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclBroadcastTreeKernel(struct CollectiveArgs* args) { }
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclBroadcastCollNetKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastRingLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = ring->devUserRanks[0];
- const int nextRank = ring->devUserRanks[1];
- const int root = args->coll.root;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- if (size-gridOffset < loopSize) {
- chunkSize = args->coll.lastChunkSize;
- }
- ssize_t offset = gridOffset + bid*chunkSize;
-
- int nelem = min(chunkSize, size-offset);
- if (rank == root) {
- if (thisInput == thisOutput) {
- LLprims.send(thisInput+offset, nelem);
- } else {
- LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = ring->devUserRanks[0];
+ const int nextRank = ring->devUserRanks[1];
+ const int root = args->coll.root;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ if (size-gridOffset < loopSize) {
+ chunkSize = args->coll.lastChunkSize;
+ }
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (rank == root) {
+ if (thisInput == thisOutput) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
+ }
+ } else if (nextRank == root) {
+ LLprims.recv(thisOutput + offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput + offset, nelem);
+ }
}
- } else if (nextRank == root) {
- LLprims.recv(thisOutput + offset, nelem);
- } else {
- LLprims.recvCopySend(thisOutput + offset, nelem);
}
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastTreeLLKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastCollNetLLKernel(struct CollectiveArgs* args) { }
+};
#include "prims_ll128.h"
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastRingLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = ring->devUserRanks[0];
- const int nextRank = ring->devUserRanks[1];
- const int root = args->coll.root;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
- ssize_t offset = gridOffset + bid*chunkSize;
-
- int nelem = min(chunkSize, size-offset);
- if (rank == root) {
- if (thisInput == thisOutput) {
- LLprims.send(thisInput+offset, nelem);
- } else {
- LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = ring->devUserRanks[0];
+ const int nextRank = ring->devUserRanks[1];
+ const int root = args->coll.root;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (rank == root) {
+ if (thisInput == thisOutput) {
+ LLprims.send(thisInput+offset, nelem);
+ } else {
+ LLprims.copySend(thisInput + offset, thisOutput + offset, nelem);
+ }
+ } else if (nextRank == root) {
+ LLprims.recv(thisOutput + offset, nelem);
+ } else {
+ LLprims.recvCopySend(thisOutput + offset, nelem);
+ }
}
- } else if (nextRank == root) {
- LLprims.recv(thisOutput + offset, nelem);
- } else {
- LLprims.recvCopySend(thisOutput + offset, nelem);
}
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastTreeLL128Kernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclBroadcastCollNetLL128Kernel(struct CollectiveArgs* args) { }
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
index a76f4e8..265218a 100644
--- a/src/collectives/device/common.h
+++ b/src/collectives/device/common.h
@@ -10,6 +10,15 @@
#include "collectives.h"
#include "devcomm.h"
+
+#if __CUDA_ARCH__ >= 800
+#define COLL_UNROLL 8
+#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
+#else
+#define COLL_UNROLL 4
+#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY
+#endif
+
// Exit If Abort Barrier across CTA: make sure all threads exit consistently
// Each thread sets a predicate to true if abort == 1
// all CTA's threads enter the barrier and do a popc on their predicates being True
@@ -19,12 +28,12 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
asm ("{");
asm volatile (" .reg .pred barr_pred;");
asm volatile (" setp.eq.u32 barr_pred,%0,1;" :: "r"(abort));
- asm volatile (" bar.red.popc.u32 %0, 13, barr_pred;" : "=r"(popc));
+ asm volatile (" bar.red.popc.u32 %0, 0, barr_pred;" : "=r"(popc));
asm ("}");
if (popc) { asm volatile ("exit;"); }
}
-typedef void(*ncclKern_t)(struct CollectiveArgs* args);
+typedef void(*ncclKern_t)(struct ncclWorkElem* args);
extern __device__ ncclKern_t ncclFuncs[];
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
@@ -32,131 +41,143 @@ static __device__ void load_parallel(void* dst, void* src, size_t size, int tid)
int* s = (int*)src;
for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o];
}
-static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid, struct ncclDevComm* comm) {
+static __device__ void load_coll(struct ncclWork* localWork, struct ncclWork* hostWork, int tid, struct ncclDevComm* comm) {
+ __syncthreads();
+ load_parallel(localWork, hostWork, sizeof(struct ncclWork), tid);
// Check whether the last operation was aborted and make sure all threads exit
int abort = tid == 0 ? *(comm->abortFlag) : 0;
exitIfAbortBarrier(abort);
- load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid);
- __syncthreads();
- if (tid == 0) hostColl->active = 0;
+ if (tid == 0) hostWork->elems[0].active = 0;
}
-extern __device__ volatile uint64_t* ncclShmem;
+template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
+
+struct ncclShmemPtrs {
+ void* srcs[NCCL_MAX_DEV_ARITY+1];
+ void* dsts[NCCL_MAX_DEV_ARITY+1];
+};
+
+struct ncclShmemData {
+ union {
+ volatile uint64_t data[NCCL_LL128_SHMEM_SIZE];
+ struct ncclShmemPtrs ptrs[NCCL_MAX_GROUPS];
+ };
+ struct ncclWork localWork;
+};
-/* Functions for aggregation case */
-#define IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
-__device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
- coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(args); \
+extern __device__ struct ncclShmemData *ncclShmem;
+template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL, int FINDEX>
+__device__ void ncclKernel(struct ncclWorkElem first) {
+ int tid = threadIdx.x;
+ int bid = blockIdx.x;
+ __shared__ struct ncclShmemData shmem;
+ ncclShmem = &shmem;
+
+ auto f = ncclFunction<FUNCTION, ALGO, PROTO, REDOP, T, UNROLL>();
+
+ struct ncclDevComm* comm = first.comm;
+ struct ncclChannel* channel = comm->channels+bid;
+ struct ncclWorkElem* w = NULL;
+ uint16_t index = first.index;
+
+ /* To optimize for latency, (only) the first operation is passed as argument.*/
+ if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) w = &first;
+
+ while (1) {
+ if (w == NULL) {
+ w = shmem.localWork.elems;
+ load_coll(&shmem.localWork, channel->workFifo+index, tid, comm);
+ }
+ if (tid < w->nThreads) {
+ if (w->funcIndex == FINDEX) {
+ f.run(w);
+ } else {
+ ncclFuncs[w->funcIndex](w);
+ }
+ }
+ index = (index+1) % NCCL_MAX_OPS;
+ if (w->active == 2) {
+ return;
+ }
+ w = NULL;
+ }
}
+// Only generate kernels for SUM
#if NCCL_OP == 0
-/* Kernels with the first operation inlined */
-#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
-__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
- int tid = threadIdx.x; \
- int bid = blockIdx.x; \
- __shared__ volatile uint64_t shmem[NCCL_LL128_SHMEM_SIZE]; \
- ncclShmem = shmem; \
- __shared__ struct ncclColl localColl; \
- \
- struct ncclDevComm* comm = firstColl.args.comm; \
- struct ncclChannel* channel = comm->channels+bid; \
- struct ncclColl* c; \
- if (bid == 0) { \
- /* To optimize for latency, (only) the first operation is passed as argument.*/ \
- c = &firstColl; \
- } else { \
- c = &localColl; \
- load_coll(c, channel->collectives+channel->collFifoHead, tid, comm); \
- } \
- while (1) { \
- if (tid < c->args.common.nThreads) { \
- if (c->funcIndex == fIndex) { \
- coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
- } else { \
- ncclFuncs[c->funcIndex](&c->args); \
- } \
- } \
- int nextIndex = c->nextIndex; \
- if (tid == 0) channel->collFifoHead = nextIndex; \
- \
- if (c->active == 2) { \
- return; \
- } \
- \
- /* Load next collective operation*/ \
- c = &localColl; /* for bid 0 */ \
- load_coll(c, channel->collectives+nextIndex, tid, comm); \
- } \
+#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \
+__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \
+ ncclKernel<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL, fIndex>(first); \
}
#else
-#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
+#define IMPL_COLL_KERN(func, algo, proto, redop, type, fInded)
#endif
+// Examples : AllReduce, RING, LL, Sum, uint8
+#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \
+__device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \
+ auto f = ncclFunction<ncclFunc##func, NCCL_ALGO_##algo, NCCL_PROTO_##proto, Func##redop<type>, type, COLL_UNROLL>(); \
+ f.run(args); \
+}
+
// Only generate inline kernels for LL
-#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
- IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
- IMPL_COLL_FUNC(coll##LL128, op, ncclFunc, dtype, ctype) \
- IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
- IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, al, NCCL_PROTO_LL)) \
+#define IMPL_COLL4(func, algo, redop, type, ncclType) \
+ IMPL_COLL_FUNC(func, algo, LL, redop, type) \
+ IMPL_COLL_FUNC(func, algo, LL128, redop, type) \
+ IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type) \
+ IMPL_COLL_KERN(func, algo, LL, redop, type, FUNC_INDEX(ncclFunc##func, nccl##redop, ncclType, NCCL_ALGO_##algo, NCCL_PROTO_LL)) \
-#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
- IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_TREE) \
- IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_RING) \
- IMPL_COLL4(coll##CollNet, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, NCCL_ALGO_COLLNET)
+#define IMPL_COLL3(func, redop, type, ncclType) \
+ IMPL_COLL4(func, TREE, redop, type, ncclType) \
+ IMPL_COLL4(func, RING, redop, type, ncclType) \
+ IMPL_COLL4(func, COLLNET, redop, type, ncclType)
#if NCCL_TYPE == 0
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int8_t, ncclInt8)
#elif NCCL_TYPE == 1
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, u8, uint8_t, ncclColl, ncclOp, ncclUint8)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint8_t, ncclUint8)
#elif NCCL_TYPE == 2
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, i32, int32_t, ncclColl, ncclOp, ncclInt32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int32_t, ncclInt32)
#elif NCCL_TYPE == 3
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, u32, uint32_t, ncclColl, ncclOp, ncclUint32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint32_t, ncclUint32)
#elif NCCL_TYPE == 4
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, i64, int64_t, ncclColl, ncclOp, ncclInt64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, int64_t, ncclInt64)
#elif NCCL_TYPE == 5
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, uint64_t, ncclUint64)
#elif NCCL_TYPE == 6
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, half, ncclFloat16)
#elif NCCL_TYPE == 7
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, float, ncclFloat32)
#elif NCCL_TYPE == 8
-#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
- IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64)
+#define IMPL_COLL2(func, redop) IMPL_COLL3(func, redop, double, ncclFloat64)
#endif
// Reduction define all functions
#if NCCL_OP == 0
-#define IMPL_COLL_R(collf, colln) \
- IMPL_COLL2(collf, sum, FuncSum, colln, ncclSum);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Sum);
#elif NCCL_OP == 1
-#define IMPL_COLL_R(collf, colln) \
- IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Prod);
#elif NCCL_OP == 2
-#define IMPL_COLL_R(collf, colln) \
- IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Min);
#elif NCCL_OP == 3
-#define IMPL_COLL_R(collf, colln) \
- IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
+#define IMPL_COLL_R(func) IMPL_COLL2(func, Max);
#endif
-// Copy primitives only define one
#if NCCL_OP == 0 && NCCL_TYPE == 0
-#define IMPL_COLL_C(collf, colln) \
- IMPL_COLL3(collf, copy, FuncSum, i8, int8_t, colln, ncclSum, ncclInt8);
+// Copy primitives only define one function for copy
+#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t, ncclInt8);
+
+// Point-to-point primitives only have one function/kernel.
+#define IMPL_COLL_P(func) \
+ IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \
+ IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, 0);
#else
-#define IMPL_COLL_C(collf, colln)
+#define IMPL_COLL_C(func)
+#define IMPL_COLL_P(func)
#endif
-#define COLL_UNROLL 4
-
#endif
diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h
index aa1e936..ff466a0 100644
--- a/src/collectives/device/common_kernel.h
+++ b/src/collectives/device/common_kernel.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -16,6 +16,12 @@
// Define min for ssize_t
static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }
+template <typename T>
+inline __device__ void loadPtr(void** ptr, T* &v) {
+ asm volatile("ld.volatile.global.u64 %0, [%1];"
+ : "=l"(v) : "l"(ptr));
+}
+
typedef uint64_t PackType;
// unpack x and y to elements of type T and apply FUNC to each element
@@ -245,28 +251,57 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
-template<class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthreads,
- int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
- const int offset, const int N) {
- for (int idx = offset+tid; idx < offset+N; idx += nthreads) {
- T val = vFetch(srcs[0]+idx);
+template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
+ int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
+ const int inc = nw * UNROLL * WARP_SIZE;
+ int offset = w * UNROLL * WARP_SIZE + t;
+
+ const T* srcs[MAXSRCS];
+ for (int i=0; i<MAXSRCS; i++) srcs[i] = s[i]+elemOffset+offset;
+ T* dsts[MAXDSTS];
+ for (int i=0; i<MAXDSTS; i++) dsts[i] = d[i]+elemOffset+offset;
+
+ while (offset < Nelem) {
+ T vals[UNROLL];
+ // Load and reduce
+ for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
+
+ #pragma unroll
+ for (int i=1; i<MINSRCS; i++) {
+ T vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+ }
#pragma unroll
- for (int i=1; i<MINSRCS; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
- #pragma unroll 1
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
+ for (int i=MINSRCS; i<MAXSRCS; i++) {
+ if (i<nsrcs) {
+ T vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+ }
+ }
+ // Store
#pragma unroll
- for (int i=0; i<MINDSTS; i++) vStore(dsts[i]+idx, val);
- #pragma unroll 1
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) vStore(dsts[i]+idx, val);
+ for (int i = 0; i < MINDSTS; i++) {
+ for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ #pragma unroll
+ for (int i=MINDSTS; i<MAXDSTS; i++) {
+ if (i<ndsts) {
+ for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ }
+ for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
+ for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
+ offset += inc;
}
}
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
- int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
- const int elemOffset, const int Npack) {
+__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
+ int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
const int inc = nw * UNROLL * WARP_SIZE;
int offset = w * UNROLL * WARP_SIZE + t;
@@ -280,25 +315,31 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
// Load and reduce
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
+ #pragma unroll
for (int i=1; i<MINSRCS; i++) {
Pack128 vals2[UNROLL];
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
- #pragma unroll 1
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
- Pack128 vals2[UNROLL];
- for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
- for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+ #pragma unroll
+ for (int i=MINSRCS; i<MAXSRCS; i++) {
+ if (i<nsrcs) {
+ Pack128 vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+ }
}
// Store
+ #pragma unroll
for (int i = 0; i < MINDSTS; i++) {
for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
}
- #pragma unroll 1
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) {
- for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ #pragma unroll
+ for (int i=MINDSTS; i<MAXDSTS; i++) {
+ if (i<ndsts) {
+ for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
}
for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
@@ -309,72 +350,65 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
template <typename T>
__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); }
-// Try to limit consecutive load/stores to 8.
-// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
-#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
+#define PACKELEMS (sizeof(Pack128) / sizeof(T))
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
- int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
+ int nsrcs, const T** srcs, int ndsts, T** dsts,
int N) {
int Nrem = N;
if (Nrem <= 0) return;
- int alignDiff = 0;
- int align = ptrAlign128(srcs[0]);
- #pragma unroll
- for (int i=1; i<MINSRCS; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
- #pragma unroll
- for (int i=0; i<MINDSTS; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
-
- int Npreamble = alignDiff ? Nrem :
- N < alignof(Pack128) ? N :
- (alignof(Pack128) - align) % alignof(Pack128);
-
- // stage 1: preamble: handle any elements up to the point of everything coming
- // into alignment
- if (Npreamble) {
- ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, 0, Npreamble);
- Nrem -= Npreamble;
- if (Nrem == 0) return;
- }
- int offset = Npreamble;
-
- // stage 2: fast path: use 128b loads/stores to do the bulk of the work,
- // assuming the pointers we have are all 128-bit alignable.
int w = tid / WARP_SIZE; // Warp number
int nw = nthreads / WARP_SIZE; // Number of warps
int t = tid % WARP_SIZE; // Thread (inside the warp)
- const int packFactor = sizeof(Pack128) / sizeof(T);
+ // Check that all is 16B aligned. If not don't use 16B load/stores.
+ int align = 0;
+ #pragma unroll
+ for (int i=0; i<MINSRCS; i++) align |= ptrAlign128(srcs[i]);
+ for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) align |= ptrAlign128(srcs[i]);
+ #pragma unroll
+ for (int i=0; i<MINDSTS; i++) align |= ptrAlign128(dsts[i]);
+ for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) align |= ptrAlign128(dsts[i]);
- // stage 2a: main loop
- int Npack2a = (Nrem / (packFactor * AUTOUNROLL * WARP_SIZE))
- * (AUTOUNROLL * WARP_SIZE); // round down
- int Nelem2a = Npack2a * packFactor;
+ int offset = 0;
+ if (align == 0) {
+ // fast path: use 128b loads/stores to do the bulk of the work,
+ // assuming the pointers we have are all 128-bit aligned.
- ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2a);
+ // main loop
+ int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down
+ int Nelem = Npack * PACKELEMS;
- Nrem -= Nelem2a;
- if (Nrem == 0) return;
- offset += Nelem2a;
+ ReduceCopy128bMulti<FUNC, T, UNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
- // stage 2b: slightly less optimized for section when we don't have full
- // unrolling
+ Nrem -= Nelem;
+ if (Nrem == 0) return;
+ offset += Nelem;
+
+ // slightly less optimized for section when we don't have full unrolling
+ Npack = Nrem / PACKELEMS;
+ Nelem = Npack * PACKELEMS;
+
+ ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
+
+ Nrem -= Nelem;
+ if (Nrem == 0) return;
+ offset += Nelem;
+ }
- int Npack2b = Nrem / packFactor;
- int Nelem2b = Npack2b * packFactor;
+ // unrolled, by-type (mostly for unaligned buffers)
+ int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down
- ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2b);
+ ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);
- Nrem -= Nelem2b;
+ Nrem -= Nelem;
if (Nrem == 0) return;
- offset += Nelem2b;
+ offset += Nelem;
- // stage 2c: tail
- ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, offset, Nrem);
+ // no unroll, by type. Should finish what's remaining.
+ ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
}
#endif // COMMON_KERNEL_H_
diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu
index 119cd36..553a882 100644
--- a/src/collectives/device/functions.cu
+++ b/src/collectives/device/functions.cu
@@ -8,60 +8,60 @@
#include "collectives.h"
#include "common.h"
-__device__ volatile uint64_t* ncclShmem;
+__device__ struct ncclShmemData* ncclShmem;
-#define NCCL_FUNC5(coll, op, dtype) \
- NCCL_COLL_NAME(coll##LL, op, dtype), \
- NCCL_COLL_NAME(coll##LL128, op, dtype), \
- NCCL_COLL_NAME(coll, op, dtype)
+#define NCCL_FUNC5(func, algo, redop, type) \
+ NCCL_FUNC_NAME(func, algo, LL, redop, type), \
+ NCCL_FUNC_NAME(func, algo, LL128, redop, type), \
+ NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type)
-#define NCCL_FUNC4(coll, op, dtype) \
- NCCL_FUNC5(coll##Tree, op, dtype), \
- NCCL_FUNC5(coll##Ring, op, dtype), \
- NCCL_FUNC5(coll##CollNet, op, dtype)
+#define NCCL_FUNC4(func, redop, type) \
+ NCCL_FUNC5(func, TREE, redop, type), \
+ NCCL_FUNC5(func, RING, redop, type), \
+ NCCL_FUNC5(func, COLLNET, redop, type)
// Must be consistent with ncclDataType_t
-#define NCCL_FUNCS3A(coll, op) \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, u8), \
- NCCL_FUNC4(coll, op, i32), \
- NCCL_FUNC4(coll, op, u32), \
- NCCL_FUNC4(coll, op, i64), \
- NCCL_FUNC4(coll, op, u64), \
- NCCL_FUNC4(coll, op, f16), \
- NCCL_FUNC4(coll, op, f32), \
- NCCL_FUNC4(coll, op, f64)
-#define NCCL_FUNCS3B(coll, op) \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8), \
- NCCL_FUNC4(coll, op, i8)
+#define NCCL_FUNCS3A(func, redop) \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, uint8_t), \
+ NCCL_FUNC4(func, redop, int32_t), \
+ NCCL_FUNC4(func, redop, uint32_t), \
+ NCCL_FUNC4(func, redop, int64_t), \
+ NCCL_FUNC4(func, redop, uint64_t), \
+ NCCL_FUNC4(func, redop, half), \
+ NCCL_FUNC4(func, redop, float), \
+ NCCL_FUNC4(func, redop, double)
+#define NCCL_FUNCS3B(func, redop) \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t), \
+ NCCL_FUNC4(func, redop, int8_t)
// Must be consistent with ncclRedOp_t
-#define NCCL_FUNCS2A(coll) \
- NCCL_FUNCS3A(coll, sum ), \
- NCCL_FUNCS3A(coll, prod), \
- NCCL_FUNCS3A(coll, max ), \
- NCCL_FUNCS3A(coll, min )
-#define NCCL_FUNCS2B(coll) \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy)
+#define NCCL_FUNCS2A(func) \
+ NCCL_FUNCS3A(func, Sum ), \
+ NCCL_FUNCS3A(func, Prod), \
+ NCCL_FUNCS3A(func, Max ), \
+ NCCL_FUNCS3A(func, Min )
+#define NCCL_FUNCS2B(func) \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum)
// Must be consistent with ncclFunc_t
#define NCCL_FUNCS() { \
- NCCL_COLL_NAME(ncclSendRecv, copy, i8),\
- NCCL_FUNCS2B(ncclBroadcast), \
- NCCL_FUNCS2A(ncclReduce), \
- NCCL_FUNCS2B(ncclAllGather), \
- NCCL_FUNCS2A(ncclReduceScatter), \
- NCCL_FUNCS2A(ncclAllReduce) }
+ NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),\
+ NCCL_FUNCS2B(Broadcast), \
+ NCCL_FUNCS2A(Reduce), \
+ NCCL_FUNCS2B(AllGather), \
+ NCCL_FUNCS2A(ReduceScatter), \
+ NCCL_FUNCS2A(AllReduce) }
// Must be consistent with the ncclFuncSet enum
__device__ ncclKern_t ncclFuncs[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
@@ -69,12 +69,12 @@ __device__ ncclKern_t ncclFuncs[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCC
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
#if __CUDA_ARCH__
- NCCL_COLL_NAME(ncclSendRecv, copy, i8),
- NCCL_FUNCS2B(ncclBroadcast),
- NCCL_FUNCS2A(ncclReduce),
- NCCL_FUNCS2B(ncclAllGather),
- NCCL_FUNCS2A(ncclReduceScatter),
- NCCL_FUNCS2A(ncclAllReduce)
+ NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
+ NCCL_FUNCS2B(Broadcast),
+ NCCL_FUNCS2A(Reduce),
+ NCCL_FUNCS2B(AllGather),
+ NCCL_FUNCS2A(ReduceScatter),
+ NCCL_FUNCS2A(AllReduce)
#endif
};
diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h
index 12078a8..69348db 100644
--- a/src/collectives/device/primitives.h
+++ b/src/collectives/device/primitives.h
@@ -31,63 +31,57 @@
} \
} while (0)
+#define ROLE_SRC 0x01
+#define ROLE_DST 0x02
+#define ROLE_WAIT_RECV 0x04
+#define ROLE_WAIT_SEND 0x08
+#define ROLE_POST_SEND 0x10
+#define ROLE_POST_RECV 0x20
+
// Implementation of primitive types
template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, int DIRECT, class FUNC>
class ncclPrimitives {
private:
const int tid;
- const int nthreads;
- const int wid;
+ int nthreads;
+ int nworkers;
const int stepSize;
int nrecv = 0;
int nsend = 0;
- struct ncclConnInfo* recvConn = NULL;
- volatile uint64_t* recvConnHeadPtr = NULL;
- uint64_t recvConnHead;
- volatile uint64_t* recvConnTailPtr = NULL;
- uint64_t recvConnTail;
- uint64_t recvConnTailCache; // Cache last seen value
-
- struct ncclConnInfo* sendConn = NULL;
- volatile int* sendConnFifoPtr = NULL;
- volatile uint64_t* sendConnTailPtr = NULL;
- uint64_t sendConnTail;
- volatile uint64_t* sendConnHeadPtr = NULL;
- uint64_t sendConnHead;
- uint64_t sendConnHeadCache; // Cache last seen value
-
- uint64_t recvStep[NRECV];
- uint64_t sendStep[NSEND];
- const T* recvDirectBuff[NRECV];
- T* sendDirectBuff[NSEND];
- const T* recvBuff[NRECV];
- T* sendBuff[NSEND];
+ struct ncclConnInfo* conn = NULL;
+ volatile int* connSizesFifoPtr = NULL;
+ void** connPtrsFifoPtr = NULL;
+ volatile uint64_t* connHeadPtr = NULL;
+ volatile uint64_t* connTailPtr = NULL;
+ uint64_t connTailCache; // Cache last seen value
+ uint64_t connHeadCache; // Cache last seen value
+
+ int index; // Peer index I'm responsible for
+ int peer = -1;
+ int role = 0;
+ int group;
+ uint64_t step;
+ T* direct = NULL;
+ T* buff;
struct ncclDevComm* comm;
- inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
- inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
- inline __device__ const T* recvPtr(int i) { return ((const T*)recvBuff[i])+recvOffset(i); }
- inline __device__ T* sendPtr(int i) { return ((T*)sendBuff[i])+sendOffset(i); }
+ const T** srcs;
+ T** dsts;
+ // Don't use barrier 0 as it's used by the final sync
inline __device__ void barrier() {
- if (NSEND>NRECV) {
- asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE));
- } else {
- asm volatile ("bar.sync 2, %0;" :: "r"(nthreads+WARP_SIZE));
- }
+ if (nthreads == WARP_SIZE) __syncwarp();
+ else asm volatile ("bar.sync %0, %1;" :: "r"(group+1), "r"(nthreads));
}
inline __device__ void subBarrier() {
- if (NSEND>NRECV) {
- asm volatile ("bar.sync 3, %0;" :: "r"(nthreads));
- } else {
- asm volatile ("bar.sync 4, %0;" :: "r"(nthreads));
- }
+ if (nworkers == nthreads) barrier();
+ else asm volatile ("bar.sync %0, %1;" :: "r"(group+2), "r"(nworkers));
}
uint32_t spins = 0;
uint32_t abort = 0;
- inline __device__ int checkAbort(int i, int send) {
+ inline __device__ int checkAbort() {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = *(comm->abortFlag);
@@ -96,63 +90,45 @@ class ncclPrimitives {
return abort;
}
- inline __device__ void waitSend(int nbytes) {
+ template <int DIRECTPTR>
+ inline __device__ T* directPtr(ssize_t directOffset) {
+ return DIRECTPTR && direct ? direct+directOffset : buff+(step%NCCL_STEPS)*stepSize;
+ }
+
+ template <int DST, int DIRECTSEND>
+ inline __device__ void waitSend(ssize_t directOffset, int nbytes) {
spins = 0;
- if (sendConnHeadPtr) {
- while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
- sendConnHeadCache = *sendConnHeadPtr;
- if (checkAbort(wid, 1)) break;
- }
- if (sendConnFifoPtr) {
- sendConnFifoPtr[sendConnHead%NCCL_STEPS] = nbytes;
- }
- sendConnHead += SLICESTEPS;
+ while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) {
+ connHeadCache = *connHeadPtr;
+ if (checkAbort()) break;
+ }
+ if (connSizesFifoPtr) {
+ connSizesFifoPtr[step%NCCL_STEPS] = nbytes;
}
+
+ if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, dsts[DST+index]);
+ else dsts[DST+index] = directPtr<DIRECTSEND>(directOffset);
+ step += SLICESTEPS;
}
- inline __device__ void waitRecv() {
+ template <int SRC, int DIRECTRECV>
+ inline __device__ void waitRecv(ssize_t directOffset) {
spins = 0;
- if (recvConnTailPtr) {
- while (recvConnTailCache < recvConnTail + SLICESTEPS) {
- recvConnTailCache = *recvConnTailPtr;
- if (checkAbort(wid, 0)) break;
- }
- recvConnTail += SLICESTEPS;
+ while (connTailCache < step + SLICESTEPS) {
+ connTailCache = *connTailPtr;
+ if (checkAbort()) break;
}
+ if (connPtrsFifoPtr) loadPtr(connPtrsFifoPtr+step%NCCL_STEPS, srcs[SRC+index]);
+ else srcs[SRC+index] = directPtr<DIRECTRECV>(directOffset);
+ step += SLICESTEPS;
}
- inline __device__ void incRecv(int i) {
- recvStep[i] += SLICESTEPS;
- }
inline __device__ void postRecv() {
- if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += SLICESTEPS;
+ *connHeadPtr = step += SLICESTEPS;
}
- inline __device__ void incSend(int i) {
- sendStep[i] += SLICESTEPS;
- }
inline __device__ void postSend() {
- if (sendConnTailPtr) *sendConnTailPtr = sendConnTail += SLICESTEPS;
- }
-
- template <int DIRECTRECV>
- inline __device__ const T* directRecvPtr(int i, ssize_t directOffset) {
- return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i);
- }
-
- template <int DIRECTSEND>
- inline __device__ T* directSendPtr(int i, ssize_t directOffset) {
- return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i);
- }
-
- template <int DIRECTRECV>
- inline __device__ int directRecvInc(int i, int directInc, int sliceInc) {
- return DIRECTRECV && recvDirectBuff[i] ? directInc : sliceInc;
- }
-
- template <int DIRECTSEND>
- inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
- return DIRECTSEND && sendDirectBuff[i] ? directInc : sliceInc;
+ *connTailPtr = step += SLICESTEPS;
}
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
@@ -162,135 +138,128 @@ class ncclPrimitives {
int sliceSize = stepSize*SLICESTEPS;
int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32);
- const T* srcs[RECV*NRECV+SRC];
- srcs[0] = SRC ? srcPtr : directRecvPtr<DIRECTRECV>(0, directOffset);
- if (RECV) {
- if (SRC) srcs[1] = recvPtr(0);
- for (int i=1; i<NRECV && i<nrecv; i++) srcs[SRC+i] = recvPtr(i);
- }
-
- T* dsts[SEND*NSEND+DST];
- dsts[0] = DST ? dstPtr : directSendPtr<DIRECTSEND>(0, directOffset);
- if (SEND) {
- if (DST) dsts[1] = directSendPtr<DIRECTSEND>(0, directOffset);
- for (int i=1; i<NSEND && i<nsend; i++) dsts[DST+i] = directSendPtr<DIRECTSEND>(i, directOffset);
- }
-
- bool syncThread = tid >= nthreads;
-
#pragma unroll
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
int realSize = max(0, min(dataSize, nelem-offset));
- if (!syncThread) {
- if (SEND) waitSend(realSize*sizeof(T));
- if (RECV) waitRecv();
+ if (tid < nworkers) {
+ if (SRC && (role & ROLE_SRC)) srcs[0] = srcPtr+offset;
+ if (RECV && (role & ROLE_WAIT_RECV)) waitRecv<SRC, DIRECTRECV>(directOffset+offset);
+ if (DST && (role & ROLE_DST)) dsts[0] = dstPtr+offset;
+ if (SEND && (role & ROLE_WAIT_SEND)) waitSend<DST, DIRECTSEND>(directOffset+offset, realSize*sizeof(T));
if (realSize > 0) {
subBarrier();
- if (DIRECTRECV && recvDirectBuff[0]) {
+ if (DIRECTRECV && srcs[0] == dsts[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (SEND) {
- ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize);
+ // (1-SEND) is only there to avoid compilation errors in case NSEND=0 (and SEND=0).
+ ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, (1-SEND)+NSEND>(tid, nworkers, 1, srcs, nsend, dsts+1, realSize);
}
} else {
- ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
+ ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nworkers, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
}
}
}
barrier();
- FOR_SEND(incSend);
- FOR_RECV(incRecv);
- if (syncThread) {
- if (SEND) {
- if (realSize > 0 && wid == 0) __threadfence_system();
- __syncwarp();
- postSend();
- }
- if (RECV) postRecv();
- }
- srcs[0] += SRC ? realSize : directRecvInc<DIRECTRECV>(0, realSize, sliceSize);
- for (int i=1-SRC; i<RECV*NRECV; i++) srcs[SRC+i] += sliceSize;
- dsts[0] += DST ? realSize : directSendInc<DIRECTSEND>(0, realSize, sliceSize);
- for (int i=1-DST; i<SEND*NSEND; i++) dsts[DST+i] += directSendInc<DIRECTSEND>(i, realSize, sliceSize);
+ if (SEND && (role & ROLE_POST_SEND) && realSize > 0 && index == 0) __threadfence_system();
+ __syncwarp();
+ if (SEND && (role & ROLE_POST_SEND)) postSend();
+ if (RECV && (role & ROLE_POST_RECV)) postRecv();
offset += realSize;
}
}
- __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
- recvBuff[i] = (const T*)conn->buffs[NCCL_PROTO_SIMPLE];
- recvStep[i] = conn->step;
- recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
- recvDirectBuff[i] = NULL;
- if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
- recvDirectBuff[i] = directBuff;
- if (tid == 0) *conn->ptrExchange = directBuff;
- }
- if (wid == i) recvConn = conn;
- if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up
- nrecv++;
- }
- __device__ __forceinline__ void loadRecvSync() {
- if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && wid<nrecv) {
- recvConnTailPtr = recvConn->tail;
- recvConnTailCache = *recvConnTailPtr;
- }
- if (tid >= nthreads && wid < nrecv) {
- recvConnHeadPtr = recvConn->head;
- // Return credits in case we rounded up.
- *recvConnHeadPtr = recvConnHead;
- }
- }
-
- __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
- sendBuff[i] = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
- sendStep[i] = conn->step;
- sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
- sendDirectBuff[i] = NULL;
- if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
- void* volatile* ptr = conn->ptrExchange;
- while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
- barrier();
- if (tid == 0) *ptr = NULL;
- }
- if (wid == i) sendConn = conn;
- if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up
- nsend++;
- }
- __device__ __forceinline__ void loadSendSync() {
- if (tid < nsend) {
- sendConnHeadPtr = sendConn->head;
- sendConnHeadCache = *sendConnHeadPtr;
- sendConnFifoPtr = sendConn->fifo;
- }
- if (tid >= nthreads && wid<nsend) {
- sendConnTailPtr = sendConn->tail;
+ __device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) {
+ if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) {
+ conn = &channel->devPeers[peer].recv.conn;
+ step = conn->step;
+ step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
+ if (role & ROLE_POST_RECV) {
+ connHeadPtr = conn->head;
+ // Return credits in case we rounded up.
+ *connHeadPtr = step;
+ }
+ if (role & ROLE_WAIT_RECV) {
+ buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
+ if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
+ direct = directBuff;
+ *conn->ptrExchange = directBuff;
+ }
+ connTailPtr = conn->tail;
+ connTailCache = *connTailPtr;
+ connPtrsFifoPtr = conn->ptrsFifo;
+ }
}
}
- __device__ __forceinline__ void saveRecvSync() {
- if (tid >= nthreads && wid < nrecv) {
- recvConn->step = recvConnHead;
- __threadfence_system();
+ __device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) {
+ if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) {
+ conn = &channel->devPeers[peer].send.conn;
+ step = conn->step;
+ step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
+ if (role & ROLE_POST_SEND) {
+ connTailPtr = conn->tail;
+ }
+ if (role & ROLE_WAIT_SEND) {
+ buff = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
+ if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
+ void* volatile* ptr = conn->ptrExchange;
+ while ((direct = (T*)(*ptr)) == NULL);
+ *ptr = NULL;
+ }
+ connHeadPtr = conn->head;
+ connHeadCache = *connHeadPtr;
+ connSizesFifoPtr = conn->sizesFifo;
+ connPtrsFifoPtr = conn->ptrsFifo;
+ }
}
}
- __device__ __forceinline__ void saveSendSync() {
- if (tid < nsend) {
- sendConn->step = sendConnHead;
+ __device__ __forceinline__ void saveSync() {
+ if (role & (ROLE_POST_SEND|ROLE_POST_RECV)) {
+ conn->step = step;
__threadfence_system();
}
}
public:
__device__ __forceinline__
- ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
- : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize) {
+ ncclPrimitives(const int tid, const int nworkers, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, struct ncclShmemPtrs* ptrs, int group)
+ : comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[group].srcs), dsts((T**)ptrs[group].dsts), group(group) {
+ nthreads = nworkers;
+ // For send operations, we need an extra warp to overlap the threadfence and the copy
+ int postThreads = NSEND && nworkers >= 64 ? WARP_SIZE : 0;
+ nthreads += postThreads;
+
// Make sure step is updated before we read it.
barrier();
- for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff);
- for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
- loadRecvSync();
- loadSendSync();
+ for (int i=0; i<NRECV; i++) if (recvPeers[i] != -1) nrecv++;
+ for (int i=0; i<NSEND; i++) if (sendPeers[i] != -1) nsend++;
+
+ #define SYNC_GROUP 8
+ static_assert(NSEND < SYNC_GROUP && NRECV < SYNC_GROUP, "Not enough threads to cover all peers");
+
+ int g = tid / SYNC_GROUP;
+ int ng = nthreads / SYNC_GROUP;
+ index = tid % SYNC_GROUP;
+
+ if (g == 0) {
+ if (index < nrecv) role |= ROLE_WAIT_RECV;
+ if (index == nrecv) role |= ROLE_SRC;
+ } else if (g == 1) {
+ if (index < nsend) role |= ROLE_WAIT_SEND;
+ if (index == nsend) role |= ROLE_DST;
+ } else if (g == ng - 2) {
+ if (index < nrecv) role |= ROLE_POST_RECV;
+ } else if (g == ng - 1) {
+ if (index < nsend) role |= ROLE_POST_SEND;
+ }
+
+ if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) peer = recvPeers[index];
+ if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) peer = sendPeers[index];
+
+ loadRecvConn(channel, directBuff);
+ loadSendConn(channel);
}
__device__ __forceinline__ void
@@ -351,8 +320,7 @@ class ncclPrimitives {
__device__ __forceinline__ ~ncclPrimitives() {
// Save steps for the next operation
- saveRecvSync();
- saveSendSync();
+ saveSync();
}
};
diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h
index 93c6db3..9e362f9 100644
--- a/src/collectives/device/prims_ll.h
+++ b/src/collectives/device/prims_ll.h
@@ -178,7 +178,7 @@ class ncclLLPrimitives {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
- sendConnFifoPtr = sendConn->fifo;
+ sendConnFifoPtr = sendConn->sizesFifo;
}
}
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h
index f97b25c..999d0d5 100644
--- a/src/collectives/device/prims_ll128.h
+++ b/src/collectives/device/prims_ll128.h
@@ -211,14 +211,14 @@ class ncclLL128Primitives {
/************************ Send **************************/
if (SEND) {
for (int i=1; i<NSEND && i<nsend; i++) {
- int flag = sendFlag(i);
+ uint64_t flag = sendFlag(i);
uint64_t* ptr = sendPtr(i)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
}
}
- int flag = sendFlag(0);
+ uint64_t flag = sendFlag(0);
uint64_t* ptr = sendPtr(0)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
@@ -318,10 +318,10 @@ class ncclLL128Primitives {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
- sendConnFifoPtr = sendConn->fifo;
+ sendConnFifoPtr = sendConn->sizesFifo;
}
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
- if (sendConn->fifo) {
+ if (sendConn->sizesFifo) {
sendConnTailPtr = sendConn->tail;
sendConnTail = sendConn->step;
}
@@ -345,7 +345,7 @@ class ncclLL128Primitives {
public:
__device__ __forceinline__
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
- : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
// Make sure step is updated before we read it.
barrier();
diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu
index a2caac5..66f1bb2 100644
--- a/src/collectives/device/reduce.cu
+++ b/src/collectives/device/reduce.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -8,4 +8,4 @@
#include "common.h"
#include "collectives.h"
-IMPL_COLL_R(ncclReduce, ncclCollReduce);
+IMPL_COLL_R(Reduce);
diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h
index 115e05e..313209d 100644
--- a/src/collectives/device/reduce.h
+++ b/src/collectives/device/reduce.h
@@ -8,142 +8,145 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceRingKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * REDUCE_CHUNKSTEPS;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = ring->devUserRanks[0];
- const int prevRank = ring->devUserRanks[nranks-1];
- const int root = args->coll.root;
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t offset = gridOffset + bid*realChunkSize;
- int nelem = min(realChunkSize, size-offset);
- if (prevRank == root) {
- prims.send(thisInput+offset, nelem);
- } else if (rank == root) {
- prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else {
- prims.recvReduceSend(thisInput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * REDUCE_CHUNKSTEPS;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = ring->devUserRanks[0];
+ const int prevRank = ring->devUserRanks[nranks-1];
+ const int root = args->coll.root;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t offset = gridOffset + bid*realChunkSize;
+ int nelem = min(realChunkSize, size-offset);
+ if (prevRank == root) {
+ prims.send(thisInput+offset, nelem);
+ } else if (rank == root) {
+ prims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else {
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
}
- }
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceTreeKernel(struct CollectiveArgs* args) { }
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceCollNetKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceRingLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = comm->rank;
- const int prevRank = ring->devUserRanks[nranks-1];
- const int root = args->coll.root;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- if (size-gridOffset < loopSize) {
- chunkSize = args->coll.lastChunkSize;
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = comm->rank;
+ const int prevRank = ring->devUserRanks[nranks-1];
+ const int root = args->coll.root;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ if (size-gridOffset < loopSize) {
+ chunkSize = args->coll.lastChunkSize;
+ }
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (prevRank == root) {
+ LLprims.send(thisInput+offset, nelem);
+ } else if (rank == root) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
}
- ssize_t offset = gridOffset + bid*chunkSize;
-
- int nelem = min(chunkSize, size-offset);
- if (prevRank == root) {
- LLprims.send(thisInput+offset, nelem);
- } else if (rank == root) {
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else {
- LLprims.recvReduceSend(thisInput+offset, nelem);
- }
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceTreeLLKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceCollNetLLKernel(struct CollectiveArgs* args) { }
+};
#include "prims_ll128.h"
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceRingLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
- const int rank = comm->rank;
- const int prevRank = ring->devUserRanks[nranks-1];
- const int root = args->coll.root;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
- ssize_t offset = gridOffset + bid*chunkSize;
-
- int nelem = min(chunkSize, size-offset);
- if (prevRank == root) {
- LLprims.send(thisInput+offset, nelem);
- } else if (rank == root) {
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
- } else {
- LLprims.recvReduceSend(thisInput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T));
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+ const int rank = comm->rank;
+ const int prevRank = ring->devUserRanks[nranks-1];
+ const int root = args->coll.root;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
+ ssize_t offset = gridOffset + bid*chunkSize;
+
+ int nelem = min(chunkSize, size-offset);
+ if (prevRank == root) {
+ LLprims.send(thisInput+offset, nelem);
+ } else if (rank == root) {
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+offset, nelem);
+ } else {
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+ }
}
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceTreeLL128Kernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceCollNetLL128Kernel(struct CollectiveArgs* args) { }
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduce, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduce, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu
index 8b45299..c2c6d42 100644
--- a/src/collectives/device/reduce_scatter.cu
+++ b/src/collectives/device/reduce_scatter.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -8,4 +8,4 @@
#include "common.h"
#include "collectives.h"
-IMPL_COLL_R(ncclReduceScatter, ncclCollReduceScatter);
+IMPL_COLL_R(ReduceScatter);
diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h
index 52c858c..a0d45dc 100644
--- a/src/collectives/device/reduce_scatter.h
+++ b/src/collectives/device/reduce_scatter.h
@@ -8,186 +8,189 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceScatterRingKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads-WARP_SIZE;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
- const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
- const ssize_t size = args->coll.count;
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
- prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm);
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- ssize_t chunkOffset = gridOffset + bid*realChunkSize;
-
- /////////////// begin ReduceScatter steps ///////////////
- ssize_t offset;
- int nelem = min(realChunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[nranks-1];
- offset = chunkOffset + rankDest * size;
-
- prims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- prims.recvReduceSend(thisInput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads-WARP_SIZE;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / (sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize * REDUCESCATTER_CHUNKSTEPS;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*(ssize_t)chunkSize;
+ const ssize_t size = args->coll.count;
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
+ prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+ ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
+ ssize_t chunkOffset = gridOffset + bid*realChunkSize;
+
+ /////////////// begin ReduceScatter steps ///////////////
+ ssize_t offset;
+ int nelem = min(realChunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[nranks-1];
+ offset = chunkOffset + rankDest * size;
+
+ prims.send(thisInput+offset, nelem);
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ prims.recvReduceSend(thisInput+offset, nelem);
+ }
+
+ // step k-1: reduce this buffer and data, which will produce the final result
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
+ }
}
-
- // step k-1: reduce this buffer and data, which will produce the final result
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- prims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
- }
-}
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceScatterTreeKernel(struct CollectiveArgs* args) { }
-
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclReduceScatterCollNetKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterRingLLKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
- ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- if (size-gridOffset < loopSize) {
- chunkSize = args->coll.lastChunkSize;
- }
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
-
- /////////////// begin ReduceScatter steps ///////////////
- ssize_t offset;
- int nelem = min(chunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[nranks-1];
- offset = chunkOffset + rankDest * size;
-
- LLprims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvReduceSend(thisInput+offset, nelem);
+};
+
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepLines = comm->buffSizes[NCCL_PROTO_LL] / (sizeof(union ncclLLFifoLine)*NCCL_STEPS);
+ ssize_t chunkSize = stepLines * sizeof(uint64_t) / sizeof(T);
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLLPrimitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepLines, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ if (size-gridOffset < loopSize) {
+ chunkSize = args->coll.lastChunkSize;
+ }
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin ReduceScatter steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[nranks-1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.send(thisInput+offset, nelem);
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
+ }
}
-
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterTreeLLKernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterCollNetLLKernel(struct CollectiveArgs* args) { }
+};
#include "prims_ll128.h"
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterRingLL128Kernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->coll.nThreads;
- const int bid = args->coll.bid;
- const int nChannels = args->coll.nChannels;
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
- struct ncclRing* ring = &channel->ring;
- const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
- ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
- // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
- const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
- const int nranks = comm->nRanks;
- const ssize_t loopSize = nChannels*chunkSize;
- const ssize_t size = args->coll.count;
-
- ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
-
- // Compute pointers
- const T * __restrict__ thisInput = (const T*)args->sendbuff;
- T * __restrict__ thisOutput = (T*)args->recvbuff;
-
- for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
- chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
-
- ssize_t chunkOffset = gridOffset + bid*chunkSize;
-
- /////////////// begin ReduceScatter steps ///////////////
- ssize_t offset;
- int nelem = min(chunkSize, size-chunkOffset);
- int rankDest;
-
- // step 0: push data to next GPU
- rankDest = ring->devUserRanks[nranks-1];
- offset = chunkOffset + rankDest * size;
-
- LLprims.send(thisInput+offset, nelem);
-
- // k-2 steps: reduce and copy to next GPU
- for (int j=2; j<nranks; ++j) {
- rankDest = ring->devUserRanks[nranks-j];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvReduceSend(thisInput+offset, nelem);
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_LL128, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {
+ const int tid = threadIdx.x;
+ const int nthreads = args->nThreads;
+ const int bid = args->coll.bid;
+ const int nChannels = args->coll.nChannels;
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
+ struct ncclRing* ring = &channel->ring;
+ const int stepSize = comm->buffSizes[NCCL_PROTO_LL128] / (sizeof(uint64_t)*NCCL_STEPS);
+ ssize_t chunkSize = stepSize*NCCL_LL128_DATAELEMS*sizeof(uint64_t) / (NCCL_LL128_LINEELEMS*sizeof(T));
+ // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
+ const ssize_t minChunkSize = (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*nthreads*NCCL_LL128_DATAELEMS*sizeof(uint64_t))/(NCCL_LL128_LINEELEMS*sizeof(T))/2;
+ const int nranks = comm->nRanks;
+ const ssize_t loopSize = nChannels*chunkSize;
+ const ssize_t size = args->coll.count;
+
+ ncclLL128Primitives<T, FUNC, 1, 1> LLprims(tid, nthreads, &ring->prev, &ring->next, stepSize, channel, comm);
+
+ // Compute pointers
+ const T * __restrict__ thisInput = (const T*)args->sendbuff;
+ T * __restrict__ thisOutput = (T*)args->recvbuff;
+
+ for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
+ chunkSize = min(DIVUP(size-gridOffset, nChannels*minChunkSize)*minChunkSize, chunkSize);
+
+ ssize_t chunkOffset = gridOffset + bid*chunkSize;
+
+ /////////////// begin ReduceScatter steps ///////////////
+ ssize_t offset;
+ int nelem = min(chunkSize, size-chunkOffset);
+ int rankDest;
+
+ // step 0: push data to next GPU
+ rankDest = ring->devUserRanks[nranks-1];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.send(thisInput+offset, nelem);
+
+ // k-2 steps: reduce and copy to next GPU
+ for (int j=2; j<nranks; ++j) {
+ rankDest = ring->devUserRanks[nranks-j];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceSend(thisInput+offset, nelem);
+ }
+
+ // step k-1: reduce this buffer and data, which will produce the final
+ // result that we store in this data
+ rankDest = ring->devUserRanks[0];
+ offset = chunkOffset + rankDest * size;
+
+ LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
+ }
}
-
- // step k-1: reduce this buffer and data, which will produce the final
- // result that we store in this data
- rankDest = ring->devUserRanks[0];
- offset = chunkOffset + rankDest * size;
-
- LLprims.recvReduceCopy(thisInput+offset, thisOutput+chunkOffset, nelem);
- }
-}
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterTreeLL128Kernel(struct CollectiveArgs* args) { }
-
-template<int UNUSED, class FUNC, typename T>
-__device__ void ncclReduceScatterCollNetLL128Kernel(struct CollectiveArgs* args) { }
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_TREE, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
+
+template<int PROTO, class REDOP, typename T, int UNROLL>
+class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_COLLNET, PROTO, REDOP, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* args) {}
+};
diff --git a/src/collectives/device/sendrecv.cu b/src/collectives/device/sendrecv.cu
index 34e7adf..59e38b5 100644
--- a/src/collectives/device/sendrecv.cu
+++ b/src/collectives/device/sendrecv.cu
@@ -8,7 +8,4 @@
#include "common.h"
#include "collectives.h"
-#if NCCL_OP == 0 && NCCL_TYPE == 0
-IMPL_COLL_FUNC(ncclSendRecv, copy, FuncSum, i8, int8_t);
-IMPL_COLL_KERN(ncclSendRecv, copy, FuncSum, i8, int8_t, 0);
-#endif
+IMPL_COLL_P(SendRecv);
diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h
index 7998ba6..1cb34f3 100644
--- a/src/collectives/device/sendrecv.h
+++ b/src/collectives/device/sendrecv.h
@@ -8,74 +8,85 @@
#include "primitives.h"
#include "collectives.h"
-template<int UNROLL, class FUNC, typename T>
-__device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
- const int tid = threadIdx.x;
- const int nthreads = args->p2p.nThreads-2*WARP_SIZE;
+template<class FUNC, typename T, int UNROLL>
+class ncclFunction<ncclFuncSendRecv, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, UNROLL> {
+ public:
+ __device__ void run(struct ncclWorkElem* firstArgs) {
+ struct ncclWorkElem* args = firstArgs;
+ int tid = threadIdx.x;
+ int group = 0;
+ for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
+ int nThreadsSegment = args->p2p.nThreads;
+ if (nThreadsSegment == 0) return; // Nothing else to do
+ int groupRecv = group;
+ group += 1;
+ int groupSend = group;
+ group += nThreadsSegment > 128 ? 2 : 1;
+ if (tid < nThreadsSegment) {
+ const int nThreads = nThreadsSegment > 128 ? nThreadsSegment-WARP_SIZE : nThreadsSegment;
- // Compute pointers
- const T* sendbuff = (const T*)args->sendbuff;
- T* recvbuff = (T*)args->recvbuff;
+ // Compute pointers
+ const T* sendbuff = (const T*)args->sendbuff;
+ T* recvbuff = (T*)args->recvbuff;
+ const ssize_t sendCount = args->p2p.sendCount;
+ const ssize_t recvCount = args->p2p.recvCount;
- if (args->p2p.delta < 0 ) return; // No-op
+ const int delta = args->p2p.delta;
+ if (delta == 0) {
+ if (tid < nThreads && sendbuff != recvbuff) {
+ // local copy : ReduceOrCopyMulti takes an int as number of elements,
+ // so we split it in blocks of 1G elements.
+ int blockSize = 1<<30;
+ for (size_t offset=0; offset<sendCount; offset += blockSize) {
+ size_t remaining = sendCount - offset;
+ if (remaining < blockSize) blockSize = remaining;
+ ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nThreads, 1, &sendbuff, 1, &recvbuff, blockSize);
+ sendbuff += blockSize; recvbuff += blockSize;
+ }
+ }
+ } else {
+ struct ncclDevComm* comm = args->comm;
+ struct ncclChannel* channel = comm->channels+blockIdx.x;
- if (args->p2p.delta == 0) {
- if (tid < nthreads && sendbuff != recvbuff) {
- // local copy : ReduceOrCopyMulti takes an int as number of elements,
- // so we split it in blocks of 1G elements.
- int blockSize = 1<<30;
- for (size_t offset=0; offset<args->p2p.sendCount; offset += blockSize) {
- size_t remaining = args->p2p.sendCount - offset;
- if (remaining < blockSize) blockSize = remaining;
- ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nthreads, 1, &sendbuff, 1, &recvbuff, blockSize);
- sendbuff += blockSize; recvbuff += blockSize;
- }
- }
- return;
- }
-
- struct ncclDevComm* comm = args->comm;
- struct ncclChannel* channel = comm->channels+blockIdx.x;
-
- const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS)/SENDRECV_SLICEFACTOR;
-
- int nthreadsSplit = nthreads/2;
- // We set NRECV or NSEND to 2 to use different barriers in primitives for the send threads and
- // receive threads, but then we define all peers to -1 since sender threads don't receive and
- // receive threads don't send.
- int peerNone[2] = {-1,-1};
+ const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS);
+ const int chunkSize = stepSize/SENDRECV_SLICEFACTOR;
- if (tid < nthreadsSplit + WARP_SIZE ) {
- const ssize_t sendSize = args->p2p.sendCount;
- if (sendSize < 0) return;
+ int nThreadsSplit = nThreads/2;
+ if ((tid < nThreadsSplit) && recvCount >= 0) {
+ int peer = (comm->rank-delta+comm->nRanks)%comm->nRanks;
+ int nt = nThreadsSplit;
+ ncclPrimitives<UNROLL, 1, 1, T, 1, 0, 1, FUNC>
+ prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupRecv);
- int peer = (comm->rank+(int)args->p2p.delta)%comm->nRanks;
- ncclPrimitives<UNROLL, 1, 1, T, 2, 1, 1, FUNC>
- prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm);
+ if (recvCount == 0) {
+ prims.recv(recvbuff, 0);
+ } else for (ssize_t offset = 0; offset < recvCount; offset += chunkSize) {
+ int realChunkSize = min(chunkSize, recvCount-offset);
+ ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
+ int nelem = min(realChunkSize, recvCount-offset);
+ prims.directRecv(recvbuff+offset, offset, nelem);
+ }
+ }
+ if ((tid >= nThreadsSplit) && sendCount >= 0) {
+ int peer = (comm->rank+delta)%comm->nRanks;
+ int nt = nThreads-nThreadsSplit;
+ ncclPrimitives<UNROLL, 1, 1, T, 0, 1, 1, FUNC>
+ prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupSend);
- if (sendSize == 0) {
- prims.send(sendbuff, 0);
- } else for (ssize_t offset = 0; offset < sendSize; offset += stepSize) {
- int realChunkSize = min(stepSize, sendSize-offset);
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- int nelem = min(realChunkSize, sendSize-offset);
- prims.directSend(sendbuff+offset, offset, nelem);
- }
- } else {
- const ssize_t recvSize = args->p2p.recvCount;
- if (recvSize < 0) return;
-
- int peer = (comm->rank-(int)args->p2p.delta+comm->nRanks)%comm->nRanks;
- ncclPrimitives<UNROLL, 1, 1, T, 1, 2, 1, FUNC>
- prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm);
-
- if (recvSize == 0) {
- prims.recv(recvbuff, 0);
- } else for (ssize_t offset = 0; offset < recvSize; offset += stepSize) {
- int realChunkSize = min(stepSize, recvSize-offset);
- ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
- int nelem = min(realChunkSize, recvSize-offset);
- prims.directRecv(recvbuff+offset, offset, nelem);
+ if (sendCount == 0) {
+ prims.send(sendbuff, 0);
+ } else for (ssize_t offset = 0; offset < sendCount; offset += chunkSize) {
+ int realChunkSize = min(chunkSize, sendCount-offset);
+ ALIGN_SIZE(realChunkSize, nt*sizeof(uint64_t)/sizeof(T));
+ int nelem = min(realChunkSize, sendCount-offset);
+ prims.directSend(sendbuff+offset, offset, nelem);
+ }
+ }
+ }
+ }
+ tid -= nThreadsSegment;
+ if (tid < 0) return;
+ args++;
+ }
}
- }
-}
+};
diff --git a/src/collectives/reduce.cc b/src/collectives/reduce.cc
index 67f2fae..86388df 100644
--- a/src/collectives/reduce.cc
+++ b/src/collectives/reduce.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,7 +11,8 @@ NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollReduce, "Reduce",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncReduce, "Reduce",
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
return ncclEnqueueCheck(&info);
diff --git a/src/collectives/reduce_scatter.cc b/src/collectives/reduce_scatter.cc
index 5ad7f5f..57c67bf 100644
--- a/src/collectives/reduce_scatter.cc
+++ b/src/collectives/reduce_scatter.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -11,7 +11,8 @@ NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollReduceScatter, "ReduceScatter",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
return ncclEnqueueCheck(&info);
diff --git a/src/collectives/sendrecv.cc b/src/collectives/sendrecv.cc
index 2e32875..65222a5 100644
--- a/src/collectives/sendrecv.cc
+++ b/src/collectives/sendrecv.cc
@@ -12,7 +12,8 @@ NCCL_API(ncclResult_t, ncclSend, const void* sendbuff, size_t count, ncclDataTyp
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
ncclComm_t comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollSendRecv, "Send",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncSendRecv, "Send",
sendbuff, NULL, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
ncclResult_t ret;
@@ -26,7 +27,8 @@ NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t da
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
ncclComm_t comm, cudaStream_t stream) {
- struct ncclInfo info = { ncclCollSendRecv, "Recv",
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
+ struct ncclInfo info = { ncclFuncSendRecv, "Recv",
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
ncclResult_t ret;
diff --git a/src/debug.cc b/src/debug.cc
index 3b99201..25bf37a 100644
--- a/src/debug.cc
+++ b/src/debug.cc
@@ -127,7 +127,7 @@ void ncclDebugInit() {
void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *filefunc, int line, const char *fmt, ...) {
if (ncclDebugLevel == -1) ncclDebugInit();
if (ncclDebugNoWarn != 0 && level == NCCL_LOG_WARN) { level = NCCL_LOG_INFO; flags = ncclDebugNoWarn; }
- if (ncclDebugLevel < level) return;
+ if (ncclDebugLevel < level || ((flags & ncclDebugMask) == 0)) return;
// Gather the rank information. This can take > 1us so we want to make sure
// we only do it when needed.
@@ -144,11 +144,11 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file
if (level == NCCL_LOG_WARN)
len = snprintf(buffer, sizeof(buffer),
"\n%s:%d:%d [%d] %s:%d NCCL WARN ", hostname, pid, tid, cudaDev, filefunc, line);
- else if (level == NCCL_LOG_INFO && (flags & ncclDebugMask))
+ else if (level == NCCL_LOG_INFO)
len = snprintf(buffer, sizeof(buffer),
"%s:%d:%d [%d] NCCL INFO ", hostname, pid, tid, cudaDev);
#ifdef ENABLE_TRACE
- else if (level == NCCL_LOG_TRACE && (flags & ncclDebugMask)) {
+ else if (level == NCCL_LOG_TRACE) {
auto delta = std::chrono::high_resolution_clock::now() - ncclEpoch;
double timestamp = std::chrono::duration_cast<std::chrono::duration<double>>(delta).count()*1000;
len = snprintf(buffer, sizeof(buffer),
diff --git a/src/enqueue.cc b/src/enqueue.cc
index 40e8f57..a427bd9 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -9,58 +9,58 @@
#include "coll_net.h"
// Only generate inline kernels for LL
-#define NCCL_FUNC5(coll, op, dtype) \
- (void*)NCCL_KERN_NAME(coll##LL, op, dtype), \
- (void*)NCCL_KERN_NAME(coll##LL, op, dtype), \
- (void*)NCCL_KERN_NAME(coll##LL, op, dtype)
+#define NCCL_FUNC5(func, algo, redop, dtype) \
+ (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
+ (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
+ (void*)NCCL_KERN_NAME(func, algo, LL, redop, dtype)
-#define NCCL_FUNC4(coll, op, dtype) \
- (void*)NCCL_FUNC5(coll##Tree, op, dtype), \
- (void*)NCCL_FUNC5(coll##Ring, op, dtype), \
- (void*)NCCL_FUNC5(coll##CollNet, op, dtype)
+#define NCCL_FUNC4(func, redop, type) \
+ (void*)NCCL_FUNC5(func, TREE, redop, type), \
+ (void*)NCCL_FUNC5(func, RING, redop, type), \
+ (void*)NCCL_FUNC5(func, COLLNET, redop, type)
// Must be consistent with ncclDataType_t
-#define NCCL_FUNCS3A(coll, op) \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, u8), \
- (void*)NCCL_FUNC4(coll, op, i32), \
- (void*)NCCL_FUNC4(coll, op, u32), \
- (void*)NCCL_FUNC4(coll, op, i64), \
- (void*)NCCL_FUNC4(coll, op, u64), \
- (void*)NCCL_FUNC4(coll, op, f16), \
- (void*)NCCL_FUNC4(coll, op, f32), \
- (void*)NCCL_FUNC4(coll, op, f64)
-#define NCCL_FUNCS3B(coll, op) \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8), \
- (void*)NCCL_FUNC4(coll, op, i8)
+#define NCCL_FUNCS3A(func, redop) \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, uint8_t), \
+ (void*)NCCL_FUNC4(func, redop, int32_t), \
+ (void*)NCCL_FUNC4(func, redop, uint32_t), \
+ (void*)NCCL_FUNC4(func, redop, int64_t), \
+ (void*)NCCL_FUNC4(func, redop, uint64_t), \
+ (void*)NCCL_FUNC4(func, redop, half), \
+ (void*)NCCL_FUNC4(func, redop, float), \
+ (void*)NCCL_FUNC4(func, redop, double)
+#define NCCL_FUNCS3B(func, redop) \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t), \
+ (void*)NCCL_FUNC4(func, redop, int8_t)
// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
-#define NCCL_FUNCS2A(coll) \
- NCCL_FUNCS3A(coll, sum), \
- NCCL_FUNCS3A(coll, sum), \
- NCCL_FUNCS3A(coll, sum), \
- NCCL_FUNCS3A(coll, sum)
-#define NCCL_FUNCS2B(coll) \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy), \
- NCCL_FUNCS3B(coll, copy)
+#define NCCL_FUNCS2A(func) \
+ NCCL_FUNCS3A(func, Sum), \
+ NCCL_FUNCS3A(func, Sum), \
+ NCCL_FUNCS3A(func, Sum), \
+ NCCL_FUNCS3A(func, Sum)
+#define NCCL_FUNCS2B(func) \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum), \
+ NCCL_FUNCS3B(func, Sum)
// Must be consistent with the ncclFuncSet enum
static void* const ncclKerns[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
- (void*)NCCL_KERN_NAME(ncclSendRecv, copy, i8),
- NCCL_FUNCS2B(ncclBroadcast),
- NCCL_FUNCS2A(ncclReduce),
- NCCL_FUNCS2B(ncclAllGather),
- NCCL_FUNCS2A(ncclReduceScatter),
- NCCL_FUNCS2A(ncclAllReduce)
+ (void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
+ NCCL_FUNCS2B(Broadcast),
+ NCCL_FUNCS2A(Reduce),
+ NCCL_FUNCS2B(AllGather),
+ NCCL_FUNCS2A(ReduceScatter),
+ NCCL_FUNCS2A(AllReduce)
};
/*****************************************************************************/
@@ -87,41 +87,57 @@ ncclResult_t ncclLaunchCooperativeKernelMultiDevice(struct cudaLaunchParams *par
return ncclSuccess;
}
-ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams* params) {
+static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** work, struct ncclWorkElem* base) {
+ if (channel->workCount == NCCL_MAX_OPS) {
+ WARN("Too many aggregated operations on channel %d (%d max)", channel->id, NCCL_MAX_OPS);
+ return ncclInvalidUsage;
+ }
+ int opIndex = channel->workFifoTail%NCCL_MAX_OPS;
+ struct ncclWork* w = channel->workFifo+opIndex;
+ struct ncclWorkElem* e = w->elems;
+ volatile uint8_t* activePtr = (volatile uint8_t*)&e->active;
+ while (activePtr[0] != 0) sched_yield();
+ memset(w, 0, sizeof(struct ncclWork));
+ // Initialize with work elem if provided
+ if (base) memcpy(e, base, sizeof(struct ncclWorkElem));
+ e->active = 1;
+ e->index = opIndex;
+ channel->workFifoTail++;
+ channel->workCount++;
+ if (work) *work = w;
+ return ncclSuccess;
+}
+
+static ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams* params) {
// Only launch blocks where we have work to do.
for (int c=0; c<comm->p2pnChannels; c++) {
- if (comm->channels[c].collCount) params->gridDim.x = c+1;
+ if (comm->channels[c].workCount) params->gridDim.x = c+1;
}
// Set active = 2 for the last operation and add a no-op on empty channels (p2p case).
for (int c=0; c<params->gridDim.x; c++) {
struct ncclChannel* channel = comm->channels+c;
- if (channel->collCount == 0) {
- int opIndex = channel->collFifoTail;
- struct ncclColl* c = channel->collectives+opIndex;
- volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
- while (activePtr[0] != 0) sched_yield();
-
- c->args.p2p.delta = -1; // no-op
- c->funcIndex = FUNC_INDEX_P2P;
- c->args.comm = comm->devComm;
- c->active = 1;
- opIndex = (opIndex+1)%NCCL_MAX_OPS;
- c->nextIndex = opIndex;
- channel->collFifoTail = opIndex;
- channel->collCount++;
+ if (channel->workCount == 0) {
+ struct ncclWork* w;
+ NCCLCHECK(getNextOp(channel, &w, NULL));
+ struct ncclWorkElem* e = w->elems;
+ e->comm = comm->devComm;
+ e->funcIndex = FUNC_INDEX_P2P;
+ e->p2p.nThreads = 0;
}
- channel->collectives[(channel->collStart+channel->collCount-1)%NCCL_MAX_OPS].active = 2;
+ channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].elems[0].active = 2;
}
// Find the first operation, choose the kernel accordingly and pass it
// as the first argument.
- struct ncclColl* coll = comm->channels[0].collectives+comm->channels[0].collStart;
- memcpy(&comm->args, coll, sizeof(struct ncclColl));
- // As we pass that coll directly, we can free it immediately.
- coll->active = 0;
-
- params->func = ncclKerns[coll->funcIndex];
+ struct ncclChannel* c0 = comm->channels;
+ struct ncclWork* work = c0->workFifo+((c0->workFifoTail-c0->workCount)%NCCL_MAX_OPS);
+ struct ncclWorkElem* elem = work->elems;
+ memcpy(&comm->args, elem, sizeof(struct ncclWorkElem));
+ // As we inline the first coll directly, we can free it immediately.
+ if (elem->funcIndex != FUNC_INDEX_P2P) elem->active = 0;
+
+ params->func = ncclKerns[elem->funcIndex];
return ncclSuccess;
}
@@ -131,7 +147,7 @@ ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast) {
bool done = false;
while (done == false) {
if (val >= comm->intraRanks) {
- WARN("Trying to launch too many collectives");
+ WARN("Trying to launch too many work elements, max is %d", NCCL_MAX_OPS);
return ncclInvalidUsage;
}
if (val+1 == comm->intraRanks) {
@@ -151,7 +167,7 @@ ncclResult_t ncclCpuBarrierLast(struct ncclComm* comm) {
volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
int val = *ptr;
if (__sync_bool_compare_and_swap(ptr, val, val+1) != true) {
- WARN("Trying to launch too many collectives");
+ WARN("Trying to launch too many work elements, max is %d", NCCL_MAX_OPS);
return ncclInternalError;
}
return ncclSuccess;
@@ -222,13 +238,18 @@ ncclResult_t ncclBarrierEnqueueWait(ncclComm_t comm) {
// launch and the ncclProxyStart call could cause a deadlock.
// Also, starting the proxies after the CUDA launch seems to be better for
// performance (latency).
+ uint64_t max = 0ULL;
for (int r=0; r<params->gridDim.x; r++) {
struct ncclChannel* channel = comm->channels+r;
- channel->collStart = channel->collFifoTail;
- channel->collCount = 0;
+ max = std::max(max, channel->workFifoTail);
+ channel->workCount = 0;
+ }
+ for (int r=0; r<comm->p2pnChannels; r++) {
+ struct ncclChannel* channel = comm->channels+r;
+ channel->workFifoTail = max;
}
params->gridDim.x = params->blockDim.x = 0;
- comm->lastOpCount = comm->opCount;
+ comm->lastOpCount = max;
NCCLCHECK(ncclProxyStart(comm));
return ncclSuccess;
}
@@ -280,7 +301,8 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
//if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
- int nc = (info->algorithm == NCCL_ALGO_COLLNET) ? comm->nChannels/2 : comm->nChannels; // CollNet uses one channel for up and one channel for down
+ int nc = (info->nChannels > 0) ? info->nChannels :
+ (info->algorithm == NCCL_ALGO_COLLNET) ? comm->nChannels/2 : comm->nChannels; // CollNet uses one channel for up and one channel for down
int nt = comm->maxThreads[info->algorithm][info->protocol];
int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol];
while (info->nBytes < nc*nt*threadThreshold) {
@@ -289,6 +311,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
else break;
}
if (info->protocol == NCCL_PROTO_SIMPLE) nt += WARP_SIZE; // Extra warp for sync
+ if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_TREE) nt += WARP_SIZE;
info->nChannels = nc;
info->nThreads = nt;
return ncclSuccess;
@@ -296,14 +319,14 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
static ncclResult_t getPatternInfo(struct ncclInfo* info) {
switch (info->coll) {
- case ncclCollBroadcast:
+ case ncclFuncBroadcast:
info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeDown : ncclPatternPipelineFrom; break;
- case ncclCollReduce:
+ case ncclFuncReduce:
info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break;
- case ncclCollReduceScatter:
- case ncclCollAllGather:
+ case ncclFuncReduceScatter:
+ case ncclFuncAllGather:
info->pattern = ncclPatternRing; break;
- case ncclCollAllReduce:
+ case ncclFuncAllReduce:
info->pattern = info->algorithm == NCCL_ALGO_COLLNET ? ncclPatternCollTreeUp : info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : ncclPatternRingTwice; break;
default:
WARN("Unknown pattern for collective %d algorithm %d", info->coll, info->algorithm);
@@ -333,30 +356,22 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
return ncclSuccess;
}
-static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclColl* coll, struct ncclProxyArgs* proxyArgs /* output */) {
- coll->args.sendbuff = info->sendbuff;
- coll->args.recvbuff = info->recvbuff;
- coll->args.comm = info->comm->devComm;
-
- if (info->coll == ncclCollSendRecv) {
- coll->args.p2p.sendCount = info->sendbytes;
- coll->args.p2p.recvCount = info->recvbytes;
- coll->args.p2p.delta = info->delta;
- coll->funcIndex = FUNC_INDEX_P2P;
- coll->args.p2p.nThreads = info->nThreads = info->comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]+2*WARP_SIZE;
- return ncclSuccess;
- }
+static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) {
+ work->comm = info->comm->devComm;
+
// Set nstepsPerLoop and nchunksPerLoop
NCCLCHECK(getAlgoInfo(info));
NCCLCHECK(getPatternInfo(info));
NCCLCHECK(getLoopInfo(info));
- coll->args.coll.root = info->root;
- coll->args.coll.count = info->count;
- coll->args.coll.nChannels = info->nChannels;
- coll->args.coll.nThreads = info->nThreads;
+ work->sendbuff = info->sendbuff;
+ work->recvbuff = info->recvbuff;
+ work->coll.root = info->root;
+ work->coll.count = info->count;
+ work->coll.nChannels = info->nChannels;
+ work->nThreads = info->nThreads;
- coll->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
+ work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
int stepSize = info->comm->buffSizes[info->protocol]/NCCL_STEPS;
int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1;
@@ -367,25 +382,25 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_SIMPLE) {
if (info->pattern == ncclPatternTreeUpDown) {
// Optimize chunkSize / nSteps
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth*8 && chunkSize > 131072) chunkSize /= 2;
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth*4 && chunkSize > 65536) chunkSize /= 2;
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].treeUp.depth && chunkSize > 32768) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*8 && chunkSize > 131072) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*4 && chunkSize > 65536) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth && chunkSize > 32768) chunkSize /= 2;
}
// Use lastChunkSize as chunkSize
- coll->args.coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
+ work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
} else if (info->algorithm == NCCL_ALGO_COLLNET && info->protocol == NCCL_PROTO_SIMPLE) {
// Optimize chunkSize / nSteps
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth*16 && chunkSize > 131072) chunkSize /= 2;
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth*4 && chunkSize > 65536) chunkSize /= 2;
- while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTreeUp.depth && chunkSize > 32768) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth*16 && chunkSize > 131072) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth*4 && chunkSize > 65536) chunkSize /= 2;
+ while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collTree.depth && chunkSize > 32768) chunkSize /= 2;
// Use lastChunkSize as chunkSize
- coll->args.coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
+ work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
} else if (info->protocol == NCCL_PROTO_LL) {
const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
- coll->args.coll.lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
- ALIGN_SIZE(coll->args.coll.lastChunkSize, info->nThreads*sizeof(uint64_t));
- coll->args.coll.lastChunkSize /= ncclTypeSize(info->datatype);
+ work->coll.lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop);
+ ALIGN_SIZE(work->coll.lastChunkSize, info->nThreads*sizeof(uint64_t));
+ work->coll.lastChunkSize /= ncclTypeSize(info->datatype);
} else if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_LL128) {
int nNodes = info->comm->nNodes;
float ppn = info->comm->nRanks / (float)nNodes;
@@ -393,7 +408,7 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*64/ppn && chunkSize > 131072) chunkSize /= 2;
while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*16/ppn && chunkSize > 32768) chunkSize /= 2;
// Use lastChunkSize as chunkSize
- coll->args.coll.lastChunkSize = chunkSize*NCCL_LL128_DATAELEMS/(NCCL_LL128_LINEELEMS*ncclTypeSize(info->datatype));
+ work->coll.lastChunkSize = chunkSize*NCCL_LL128_DATAELEMS/(NCCL_LL128_LINEELEMS*ncclTypeSize(info->datatype));
}
// Compute nSteps for proxies
@@ -406,9 +421,13 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclCo
proxyArgs->sliceSteps = sliceSteps;
proxyArgs->chunkSteps = chunkSteps;
proxyArgs->protocol = info->protocol;
- proxyArgs->opCount = info->comm->opCount;
proxyArgs->dtype = info->datatype;
proxyArgs->redOp = info->op;
+ // This is used by P2P to reduce the receive buffer size. We don't use it in collectives
+ // because some protocols need to transmit more than the total size, plus they sometimes
+ // round up
+ proxyArgs->recvbytes = stepSize*proxyArgs->sliceSteps;
+
TRACE(NCCL_NET,"opCount %lx slicesteps %d spl %d cpl %d nbytes %zi -> protocol %d nchannels %d nthreads %d, nloops %d nsteps %d comm %p",
proxyArgs->opCount, proxyArgs->sliceSteps, info->nstepsPerLoop, info->nchunksPerLoop, info->nBytes, info->protocol, info->nChannels, info->nThreads,
nLoops, proxyArgs->nsteps, info->comm);
@@ -427,32 +446,26 @@ static ncclResult_t checkSetStream(struct ncclInfo* info) {
}
ncclResult_t ncclSaveKernel(struct ncclInfo* info) {
- if (info->comm->nRanks == 1 && info->coll != ncclCollSendRecv) {
+ if (info->comm->nRanks == 1) {
if (info->sendbuff != info->recvbuff)
CUDACHECK(cudaMemcpyAsync(info->recvbuff, info->sendbuff, info->nBytes, cudaMemcpyDeviceToDevice, info->stream));
return ncclSuccess;
}
- struct ncclColl coll;
+ struct ncclWorkElem work;
struct ncclProxyArgs proxyArgs;
memset(&proxyArgs, 0, sizeof(struct ncclProxyArgs));
- NCCLCHECK(computeColl(info, &coll, &proxyArgs));
+ NCCLCHECK(computeColl(info, &work, &proxyArgs));
info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
- int nChannels = info->coll == ncclCollSendRecv ? 1 : coll.args.coll.nChannels;
+ int nChannels = work.coll.nChannels;
int nSubChannels = (info->pattern == ncclPatternCollTreeUp || info->pattern == ncclPatternCollTreeDown) ? 2 : 1;
for (int bid=0; bid<nChannels*nSubChannels; bid++) {
- int channelId = (info->coll == ncclCollSendRecv) ? info->channelId :
- info->comm->myParams->gridDim.x % info->comm->nChannels;
+ int channelId = info->comm->myParams->gridDim.x % info->comm->nChannels;
struct ncclChannel* channel = info->comm->channels+channelId;
- if (channel->collCount == NCCL_MAX_OPS) {
- WARN("Too many aggregated operations on channel %d (%d max)", channel->id, NCCL_MAX_OPS);
- return ncclInvalidUsage;
- }
-
// Proxy
proxyArgs.channel = channel;
// Adjust pattern for CollNet based on channel index
@@ -460,67 +473,141 @@ ncclResult_t ncclSaveKernel(struct ncclInfo* info) {
info->pattern = (channelId < info->comm->nChannels/nSubChannels) ? ncclPatternCollTreeUp : ncclPatternCollTreeDown;
}
- if (info->coll == ncclCollSendRecv) {
- info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
- NCCLCHECK(ncclProxySaveP2p(info, channel));
- } else {
- NCCLCHECK(ncclProxySaveColl(&proxyArgs, info->pattern, info->root, info->comm->nRanks));
- }
+ if (proxyArgs.nsteps) NCCLCHECK(ncclProxySaveColl(&proxyArgs, info->pattern, info->root, info->comm->nRanks));
+
info->comm->myParams->gridDim.x++;
- int opIndex = channel->collFifoTail;
- struct ncclColl* c = channel->collectives+opIndex;
- volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
- while (activePtr[0] != 0) sched_yield();
-
- memcpy(c, &coll, sizeof(struct ncclColl));
- if (info->coll != ncclCollSendRecv) c->args.coll.bid = bid % coll.args.coll.nChannels;
-
- c->active = 1;
- opIndex = (opIndex+1)%NCCL_MAX_OPS;
- c->nextIndex = opIndex;
- channel->collFifoTail = opIndex;
- channel->collCount++;
+ work.coll.bid = bid % nChannels;
+ NCCLCHECK(getNextOp(channel, NULL, &work));
+ }
+ return ncclSuccess;
+}
+
+#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
+#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
+
+ncclResult_t ncclSaveCommKernels(ncclComm_t comm) {
+ if (comm->asyncOpCount == 0) {
+ return ncclSuccess;
+ } else if (comm->asyncOpCount == 1) {
+ // No aggregation
+ struct ncclInfo* info = comm->asyncOps;
+ info->nChannels = 0;
+ NCCLCHECK(ncclSaveKernel(info));
+ } else {
+ // Aggregation
+ size_t channelSize = NCCL_AGG_CHANNEL_SIZE * comm->nRanks; // scale channel size based on nranks as latency increases
+ // Reduce the per-channel size if we cannot fully utilize the channels
+ while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2;
+ for (int c = 0; c < comm->asyncOpCount; c++) {
+ struct ncclInfo* info = comm->asyncOps+c;
+ info->nChannels = std::min((int)DIVUP(info->nBytes, channelSize), comm->nChannels); // assign number of channels
+ NCCLCHECK(ncclSaveKernel(info));
+ }
+ }
+ // Reset counters
+ comm->asyncOpCount = 0;
+ comm->asyncTotalSize = 0;
+ return ncclSuccess;
+}
+
+static ncclResult_t ncclSaveAsyncColl(struct ncclInfo* info) {
+ ncclComm_t comm = info->comm;
+ if (comm->asyncOpCount >= NCCL_MAX_OPS) {
+ WARN("Too many async operations in progress, max is %d", NCCL_MAX_OPS);
+ return ncclInvalidUsage;
}
- info->comm->opCount++;
+ memcpy(comm->asyncOps+comm->asyncOpCount, info, sizeof(struct ncclInfo));
+ comm->asyncOpCount++;
+ comm->asyncTotalSize += info->nBytes;
return ncclSuccess;
}
-// Save p2p operations in comm->p2plist. Operations will be posted to channels
+// Save p2p operations in comm->p2pSends and p2pRecvs. Operations will be posted to channels
// during ncclGroupEnd()
-ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
+static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
struct ncclComm* comm = info->comm;
- struct ncclP2Plist* p2plist = &comm->p2plist;
int peer = info->root;
- p2plist->count++;
ssize_t nBytes = info->count*ncclTypeSize(info->datatype);
- if (info->recvbuff == NULL) {
+ if (info->opName[0] == 'S') { // Send
if (peer != comm->rank) {
int delta = (comm->nRanks - (comm->rank-peer)) % comm->nRanks;
for (int c=0; c<comm->p2pnChannelsPerPeer; c++) {
int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels;
if (comm->channels[channelId].peers[peer].send.connected == 0) {
- p2plist->connect.send[channelId*comm->nRanks+p2plist->connect.nsend[channelId]++] = peer;
+ comm->connectSend[peer] |= (1<<channelId);
+ comm->connect = 1;
}
}
}
- p2plist->peerlist[info->root].sendbytes = nBytes;
- p2plist->peerlist[info->root].sendbuff = info->sendbuff;
+ NCCLCHECK(enqueueP2pInfo(comm->p2pSends+info->root, (void*)info->sendbuff, nBytes));
+ comm->p2pSendCount++;
} else {
if (peer != comm->rank) {
int delta = (comm->nRanks + (comm->rank-peer)) % comm->nRanks;
for (int c=0; c<comm->p2pnChannelsPerPeer; c++) {
int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels;
if (comm->channels[channelId].peers[peer].recv.connected == 0) {
- p2plist->connect.recv[channelId*comm->nRanks+p2plist->connect.nrecv[channelId]++] = peer;
+ comm->connectRecv[peer] |= (1<<channelId);
+ comm->connect = 1;
}
}
}
- p2plist->peerlist[info->root].recvbytes = nBytes;
- p2plist->peerlist[info->root].recvbuff = info->recvbuff;
+ NCCLCHECK(enqueueP2pInfo(comm->p2pRecvs+info->root, info->recvbuff, nBytes));
+ comm->p2pRecvCount++;
}
return ncclSuccess;
}
+static int getSegment(struct ncclInfo* info, struct ncclWork* work) {
+ for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != info->delta; s++) {
+ if (work->elems[s].p2p.nThreads == 0) return s;
+ }
+ return -1;
+}
+
+static ncclResult_t saveP2pOp(struct ncclInfo* info /* input */, struct ncclWork* work, int s) {
+ struct ncclWorkElem* elem = work->elems+s;
+ elem->comm = info->comm->devComm;
+ elem->funcIndex = FUNC_INDEX_P2P;
+ elem->nThreads = info->nThreads = NCCL_MAX_NTHREADS;
+ elem->sendbuff = info->sendbuff;
+ elem->recvbuff = info->recvbuff;
+ elem->p2p.sendCount = info->sendbytes;
+ elem->p2p.recvCount = info->recvbytes;
+ elem->p2p.delta = info->delta;
+ const int nsegments = s+1;
+ int nThreads = 512;
+ while (nsegments*nThreads > 512) nThreads /= 2;
+ if (nThreads >= 128) nThreads += WARP_SIZE;
+ for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
+ return ncclSuccess;
+}
+
+ncclResult_t ncclSaveP2pKernel(struct ncclInfo* info) {
+ int channelId = info->channelId;
+ struct ncclChannel* channel = info->comm->channels+channelId;
+
+ // Try to reuse last p2p operation if not full yet
+ int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
+ struct ncclWork* w = channel->workFifo+opIndex;
+ int segment = -1;
+ if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].p2p.nThreads == 0) {
+ // Try to pack more segments into a single operation
+ segment = getSegment(info, w);
+ }
+ if (segment == -1) {
+ NCCLCHECK(getNextOp(channel, &w, NULL));
+ segment = 0;
+ }
+
+ NCCLCHECK(ncclProxySaveP2p(info, channel, segment));
+ NCCLCHECK(saveP2pOp(info, w, segment));
+ info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
+ info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
+
+ return ncclSuccess;
+}
+
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
// Launch asynchronously if needed
if (ncclAsyncMode()) {
@@ -542,10 +629,10 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count,
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
- if (info->coll == ncclCollSendRecv) { //p2p stored separately
+ if (info->coll == ncclFuncSendRecv) { //p2p stored separately
NCCLCHECKGOTO(ncclSaveP2p(info), ret, end);
} else {
- NCCLCHECKGOTO(ncclSaveKernel(info), ret, end);
+ NCCLCHECKGOTO(ncclSaveAsyncColl(info), ret, end);
}
end:
if (savedDev != -1) CUDACHECK(cudaSetDevice(savedDev));
diff --git a/src/graph/connect.cc b/src/graph/connect.cc
index dd9f9f0..a0f1265 100644
--- a/src/graph/connect.cc
+++ b/src/graph/connect.cc
@@ -23,14 +23,10 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
for (int c=0; c<nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
channel->ring.prev = channel->ring.next = -1;
- channel->treeUp.up = -1;
- for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->treeUp.down[i] = -1;
- channel->treeDn.up = -1;
- for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->treeDn.down[i] = -1;
- channel->collTreeUp.up = -1;
- for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->collTreeUp.down[i] = -1;
- channel->collTreeDn.up = -1;
- for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->collTreeDn.down[i] = -1;
+ channel->tree.up = -1;
+ for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->tree.down[i] = -1;
+ channel->collTree.up = -1;
+ for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->collTree.down[i] = -1;
int* ringIntra = ringGraph->intra+c*localRanks;
int* treeIntra = treeGraph->intra+c*localRanks;
@@ -44,33 +40,21 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
channel->ring.next = (i == localRanks-1) ? -1 : ringIntra[i+1];
}
if (treeIntra[i] == rank) {
- int recvIndex = 0, sendIndex = treeGraph->pattern == NCCL_TOPO_PATTERN_TREE ? 0 : 1;
- int prev = (i-1+localRanks)%localRanks, next = (i+1)%localRanks;
-
- // Tree loop always flows in the same direction. Other trees are symmetric, i.e.
- // up/down go in reverse directions
- int sym = treeGraph->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE_LOOP ? 0 : 1;
+ int parentIndex = 0;
+ int child0Index = treeGraph->pattern == NCCL_TOPO_PATTERN_TREE ? 0 : 1;
+ int child1Index = treeGraph->pattern == NCCL_TOPO_PATTERN_SPLIT_TREE ? 1 : 0;
- // Down tree is common
- topoRanks->treeDnRecv[c] = treeIntra[recvIndex];
- topoRanks->treeDnSend[c] = treeIntra[sendIndex];
- channel->treeDn.up = treeIntra[prev];
- channel->treeDn.down[0] = treeIntra[next];
- // Up tree depends on the pattern
- topoRanks->treeUpRecv[c] = sym ? topoRanks->treeDnSend[c] : topoRanks->treeDnRecv[c];
- topoRanks->treeUpSend[c] = sym ? topoRanks->treeDnRecv[c] : topoRanks->treeDnSend[c];
- channel->treeUp.down[0] = sym ? channel->treeDn.down[0] : channel->treeDn.up ;
- channel->treeUp.up = sym ? channel->treeDn.up : channel->treeDn.down[0];
+ topoRanks->treeToParent[c] = treeIntra[parentIndex];
+ topoRanks->treeToChild0[c] = treeIntra[child0Index];
+ topoRanks->treeToChild1[c] = treeIntra[child1Index];
+ channel->tree.up = i == 0 ? -1 : treeIntra[i-1];
+ channel->tree.down[0] = i == localRanks-1 ? -1 : treeIntra[i+1];
}
if (collNetIntra[i] == rank) {
int prev = (i-1+localRanks)%localRanks, next = (i+1)%localRanks;
- // CollTrees are always symmetric, i.e.
- // up/down go in reverse directions
- channel->collTreeDn.up = collNetIntra[prev];
- channel->collTreeDn.down[0] = collNetIntra[next];
- channel->collTreeUp.down[0] = channel->collTreeDn.down[0];
- channel->collTreeUp.up = channel->collTreeDn.up;
+ channel->collTree.up = collNetIntra[prev];
+ channel->collTree.down[0] = collNetIntra[next];
}
}
topoRanks->ringPrev[c] = channel->ring.prev;
@@ -120,72 +104,66 @@ static ncclResult_t getIndexes(int* ranks, int* indexes, int nNodes, int* firstR
return ncclSuccess;
}
-static ncclResult_t setTreeUp(struct ncclTree* tree0, struct ncclTree* tree1, int* indexes, int u0, int u1) {
- if (u0 != -1) tree0->up = indexes[u0];
- if (u1 != -1) tree1->up = indexes[u1];
+static ncclResult_t setTreeUp(struct ncclTree* tree, int* indexes, int u) {
+ if (u == -1) return ncclSuccess;
+ tree->up = indexes[u];
return ncclSuccess;
}
-static ncclResult_t addRanksDown(int* down, int* indexes, int r0, int r1) {
+static ncclResult_t setTreeDown(struct ncclTree* tree, int* indexes, int d) {
+ if (d == -1) return ncclSuccess;
int x = 0;
- if (down[x] >= 0) x++;
- if (down[x] >= 0) {
- WARN("Internal error : tree already has more than one child (%d %d %d)\n", down[0], down[1], down[2]);
+ while (x < NCCL_MAX_TREE_ARITY && tree->down[x] >= 0) x++;
+ if (x == NCCL_MAX_TREE_ARITY) {
+ WARN("Internal error : tree already has %d children (%d %d %d)\n", x, tree->down[0], tree->down[1], tree->down[2]);
return ncclInternalError;
}
- if (r0 != -1) down[x++] = indexes[r0];
- if (r1 != -1) down[x++] = indexes[r1];
- return ncclSuccess;
-}
-
-static ncclResult_t setTreeDown(struct ncclTree* tree0, struct ncclTree* tree1, int* indexes, int d0_0, int d0_1, int d1_0, int d1_1) {
- NCCLCHECK(addRanksDown(tree0->down, indexes, d0_0, d0_1));
- NCCLCHECK(addRanksDown(tree1->down, indexes, d1_0, d1_1));
- return ncclSuccess;
-}
-
-static ncclResult_t openRing(struct ncclTree* tree, int rank, int upRank) {
- if (tree->down[0] == upRank) tree->down[0] = -1;
- if (rank == upRank) tree->up = -1;
+ tree->down[x] = indexes[d];
return ncclSuccess;
}
-static ncclResult_t connectTrees(struct ncclComm* comm, int* treeUpRecv, int* treeUpSend, int* treeDnRecv, int* treeDnSend, int* firstRanks) {
+static ncclResult_t connectTrees(struct ncclComm* comm, int* treeToParent, int* treeToChild0, int* treeToChild1, int* firstRanks, int* treePatterns) {
const int nChannels = comm->nChannels, nNodes = comm->nNodes, node = comm->node;
- int* indexesSend, *indexesRecv;
- NCCLCHECK(ncclCalloc(&indexesSend, nNodes));
- NCCLCHECK(ncclCalloc(&indexesRecv, nNodes));
+ int* ranksToParent, *ranksToChild0, *ranksToChild1;
+ NCCLCHECK(ncclCalloc(&ranksToParent, nNodes));
+ NCCLCHECK(ncclCalloc(&ranksToChild0, nNodes));
+ NCCLCHECK(ncclCalloc(&ranksToChild1, nNodes));
// Compute tree depth. Not an exact value but a good approximation in most
// cases
int depth = comm->nRanks/nNodes - 1 + log2i(nNodes);
- int u0, d0_0, d0_1, u1, d1_0, d1_1;
- NCCLCHECK(ncclGetDtree(nNodes, node, &u0, &d0_0, &d0_1, &u1, &d1_0, &d1_1));
+ int t0u, t0d0, t0d1, t0ChildType, t1u, t1d0, t1d1, t1ChildType;
+ NCCLCHECK(ncclGetDtree(nNodes, node, &t0u, &t0d0, &t0d1, &t0ChildType, &t1u, &t1d0, &t1d1, &t1ChildType));
for (int c=0; c<nChannels; c++) {
struct ncclChannel* channel0 = comm->channels+c;
struct ncclChannel* channel1 = channel0+nChannels;
- NCCLCHECK(getIndexes(treeUpSend+c*comm->nRanks, indexesSend, nNodes, firstRanks));
- NCCLCHECK(getIndexes(treeUpRecv+c*comm->nRanks, indexesRecv, nNodes, firstRanks));
- NCCLCHECK(openRing(&channel0->treeUp, comm->rank, indexesSend[node]));
- NCCLCHECK(openRing(&channel1->treeUp, comm->rank, indexesSend[node]));
- int root = indexesSend[node];
- if (indexesSend[node] == comm->rank) NCCLCHECK(setTreeUp(&channel0->treeUp, &channel1->treeUp, indexesRecv, u0, u1));
- if (indexesRecv[node] == comm->rank) NCCLCHECK(setTreeDown(&channel0->treeUp, &channel1->treeUp, indexesSend, d0_0, d0_1, d1_0, d1_1));
- NCCLCHECK(getIndexes(treeDnSend+c*comm->nRanks, indexesSend, nNodes, firstRanks));
- NCCLCHECK(getIndexes(treeDnRecv+c*comm->nRanks, indexesRecv, nNodes, firstRanks));
- NCCLCHECK(openRing(&channel0->treeDn, comm->rank, u0 == -1 ? root : indexesRecv[node]));
- NCCLCHECK(openRing(&channel1->treeDn, comm->rank, u1 == -1 ? root : indexesRecv[node]));
- if (indexesSend[node] == comm->rank) NCCLCHECK(setTreeDown(&channel0->treeDn, &channel1->treeDn, indexesRecv, d0_0, d0_1, d1_0, d1_1));
- if (indexesRecv[node] == comm->rank) NCCLCHECK(setTreeUp(&channel0->treeDn, &channel1->treeDn, indexesSend, u0, u1));
- TRACE(NCCL_GRAPH, "TreeUp %d : %d -> %d/%d/%d", c, channel0->treeUp.up, channel0->treeUp.down[0], channel0->treeUp.down[1], channel0->treeUp.down[2]);
- TRACE(NCCL_GRAPH, "TreeUp %d : %d -> %d/%d/%d", c+nChannels, channel1->treeUp.up, channel1->treeUp.down[0], channel1->treeUp.down[1], channel1->treeUp.down[2]);
- TRACE(NCCL_GRAPH, "TreeDn %d : %d -> %d/%d/%d", c, channel0->treeDn.up, channel0->treeDn.down[0], channel0->treeDn.down[1], channel0->treeDn.down[2]);
- TRACE(NCCL_GRAPH, "TreeDn %d : %d -> %d/%d/%d", c+nChannels, channel1->treeDn.up, channel1->treeDn.down[0], channel1->treeDn.down[1], channel1->treeDn.down[2]);
- channel0->treeUp.depth = channel1->treeUp.depth = depth;
+ NCCLCHECK(getIndexes(treeToParent+c*comm->nRanks, ranksToParent, nNodes, firstRanks));
+ NCCLCHECK(getIndexes(treeToChild0+c*comm->nRanks, ranksToChild0, nNodes, firstRanks));
+ NCCLCHECK(getIndexes(treeToChild1+c*comm->nRanks, ranksToChild1, nNodes, firstRanks));
+ if (comm->rank == ranksToParent[node]) {
+ NCCLCHECK(setTreeUp(&channel0->tree, t0ChildType == 0 ? ranksToChild0 : ranksToChild1, t0u));
+ NCCLCHECK(setTreeUp(&channel1->tree, t1ChildType == 0 ? ranksToChild0 : ranksToChild1, t1u));
+ }
+ if (comm->rank == ranksToChild0[node]) {
+ NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d0));
+ NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d0));
+ }
+ if (comm->rank == ranksToChild1[node]) {
+ NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d1));
+ NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d1));
+ }
+ if (comm->rank == ranksToParent[node] ||
+ comm->rank == ranksToChild0[node] ||
+ comm->rank == ranksToChild1[node]) {
+ INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c, channel0->tree.up, comm->rank, channel0->tree.down[0], channel0->tree.down[1], channel0->tree.down[2]);
+ INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c+nChannels, channel1->tree.up, comm->rank, channel1->tree.down[0], channel1->tree.down[1], channel1->tree.down[2]);
+ }
+ channel0->tree.depth = channel1->tree.depth = depth;
}
- free(indexesSend);
- free(indexesRecv);
+ free(ranksToParent);
+ free(ranksToChild0);
+ free(ranksToChild1);
return ncclSuccess;
}
@@ -198,13 +176,13 @@ ncclResult_t ncclTopoConnectCollNet(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclChannel* channel = comm->channels+c;
// Set root of collTree to id nranks
if (rank == collNetGraph->intra[sendIndex+c*comm->localRanks]) { // is master
- channel->collTreeUp.up = channel->collTreeDn.up = nranks;
+ channel->collTree.up = nranks;
}
if (rank == collNetGraph->intra[sendEndIndex+c*comm->localRanks]) { // is bottom of intra-node chain
- channel->collTreeUp.down[0] = channel->collTreeDn.down[0] = -1;
+ channel->collTree.down[0] = -1;
}
- channel->collTreeUp.depth = channel->collTreeDn.depth = depth;
- INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", c, rank, channel->collTreeUp.up, channel->collTreeUp.down[0]);
+ channel->collTree.depth = depth;
+ INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", c, rank, channel->collTree.up, channel->collTree.down[0]);
}
int recvIndex = 0; // recv GPU index is always 0
int recvEndIndex = (recvIndex+comm->localRanks-1)%comm->localRanks;
@@ -212,13 +190,13 @@ ncclResult_t ncclTopoConnectCollNet(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclChannel* channel = comm->channels+comm->nChannels/2+c;
// Set root of collTree to id nranks
if (rank == collNetGraph->intra[recvIndex+c*comm->localRanks]) { // is master
- channel->collTreeUp.up = channel->collTreeDn.up = nranks;
+ channel->collTree.up = nranks;
}
if (rank == collNetGraph->intra[recvEndIndex+c*comm->localRanks]) { // is bottom of intra-node chain
- channel->collTreeUp.down[0] = channel->collTreeDn.down[0] = -1;
+ channel->collTree.down[0] = -1;
}
- channel->collTreeUp.depth = channel->collTreeDn.depth = depth;
- INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", comm->nChannels/2+c, rank, channel->collTreeDn.up, channel->collTreeDn.down[0]);
+ channel->collTree.depth = depth;
+ INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", comm->nChannels/2+c, rank, channel->collTree.up, channel->collTree.down[0]);
}
return ncclSuccess;
}
@@ -253,35 +231,33 @@ int ncclMaxNchannels() {
return maxNchannels;
}
-ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, struct ncclTopoRanks** allTopoRanks, int* rings) {
+ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, struct ncclTopoRanks** allTopoRanks, int* rings) {
// Gather data from all ranks
- int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeUpRecv, *treeUpSend, *treeDnRecv,*treeDnSend;
+ int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeToParent, *treeToChild0, *treeToChild1;
int nranks = comm->nRanks;
int nChannels = comm->nChannels;
NCCLCHECK(ncclCalloc(&ringRecv, nranks*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringSend, nranks*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringPrev, nranks*MAXCHANNELS));
NCCLCHECK(ncclCalloc(&ringNext, nranks*MAXCHANNELS));
- NCCLCHECK(ncclCalloc(&treeUpRecv, nranks*MAXCHANNELS));
- NCCLCHECK(ncclCalloc(&treeUpSend, nranks*MAXCHANNELS));
- NCCLCHECK(ncclCalloc(&treeDnRecv, nranks*MAXCHANNELS));
- NCCLCHECK(ncclCalloc(&treeDnSend, nranks*MAXCHANNELS));
+ NCCLCHECK(ncclCalloc(&treeToParent, nranks*MAXCHANNELS));
+ NCCLCHECK(ncclCalloc(&treeToChild0, nranks*MAXCHANNELS));
+ NCCLCHECK(ncclCalloc(&treeToChild1, nranks*MAXCHANNELS));
for (int i=0; i<nranks; i++) {
for (int c=0; c<nChannels;c++) {
ringRecv[c*nranks+i] = allTopoRanks[i]->ringRecv[c];
ringSend[c*nranks+i] = allTopoRanks[i]->ringSend[c];
ringPrev[c*nranks+i] = allTopoRanks[i]->ringPrev[c];
ringNext[c*nranks+i] = allTopoRanks[i]->ringNext[c];
- treeUpRecv[c*nranks+i] = allTopoRanks[i]->treeUpRecv[c];
- treeUpSend[c*nranks+i] = allTopoRanks[i]->treeUpSend[c];
- treeDnRecv[c*nranks+i] = allTopoRanks[i]->treeDnRecv[c];
- treeDnSend[c*nranks+i] = allTopoRanks[i]->treeDnSend[c];
+ treeToParent[c*nranks+i] = allTopoRanks[i]->treeToParent[c];
+ treeToChild0[c*nranks+i] = allTopoRanks[i]->treeToChild0[c];
+ treeToChild1[c*nranks+i] = allTopoRanks[i]->treeToChild1[c];
}
}
// Connect rings and trees. This should also duplicate the channels.
NCCLCHECK(connectRings(comm, ringRecv, ringSend, ringPrev, ringNext, firstRanks));
- NCCLCHECK(connectTrees(comm, treeUpRecv, treeUpSend, treeDnRecv, treeDnSend, firstRanks));
+ NCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, firstRanks, treePatterns));
// Duplicate ringPrev/ringNext for ncclBuildRing
memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int));
@@ -308,10 +284,9 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, struct nccl
free(ringSend);
free(ringPrev);
free(ringNext);
- free(treeUpRecv);
- free(treeUpSend);
- free(treeDnRecv);
- free(treeDnSend);
+ free(treeToParent);
+ free(treeToChild0);
+ free(treeToChild1);
return ncclSuccess;
}
diff --git a/src/graph/paths.cc b/src/graph/paths.cc
index b711874..c7525e6 100644
--- a/src/graph/paths.cc
+++ b/src/graph/paths.cc
@@ -60,7 +60,12 @@ static ncclResult_t ncclTopoSetPaths(struct ncclTopoNode* baseNode, struct ncclT
struct ncclTopoLinkList* remPath;
NCCLCHECK(getPath(system, remNode, baseNode->type, baseNode->id, &remPath));
float width = std::min(path->width, link->width);
- if (remPath->width < width) {
+
+ // allow routing through a GPU only as 1 hop
+ if (node != baseNode && node->type == GPU &&
+ (link->type != LINK_NVL || remNode->type != GPU || path->count > 1)) continue;
+
+ if ((remPath->width == 0 || remPath->count > path->count) && remPath->width < width) {
// Find reverse link
for (int l=0; l<remNode->nlinks; l++) {
if (remNode->links[l].remNode == node) {
@@ -80,24 +85,20 @@ static ncclResult_t ncclTopoSetPaths(struct ncclTopoNode* baseNode, struct ncclT
// Start with path type = link type. PATH and LINK types are supposed to match.
// Don't consider LINK_NET as we only care about the NIC->GPU path.
- int type = link->type == LINK_NET ? 0 : link->type;
+ int type = link->type == LINK_NET ? LINK_LOC : link->type;
// Differentiate between one and multiple PCI switches
- if (type == PATH_PIX && (node->type == PCI || link->remNode->type == PCI) && remPath->count > 3) type = PATH_PXB;
+ if (node->type == PCI && remNode->type == PCI) type = PATH_PXB;
// Consider a path going through the CPU as PATH_PHB
if (link->type == LINK_PCI && (node->type == CPU || link->remNode->type == CPU)) type = PATH_PHB;
- // Ignore Power CPU in an NVLink path
- if (path->type == PATH_NVL && type == PATH_SYS && link->remNode->type == CPU &&
- link->remNode->cpu.arch == NCCL_TOPO_CPU_ARCH_POWER) type = 0;
+ // Set 1 hop NVLink as NVB
+ if (node->type == GPU && path->type == PATH_NVL && type == PATH_NVL && remPath->count > 1) type = PATH_NVB;
remPath->type = std::max(path->type, type);
// Add to the list for the next iteration if not already in the list
- // Disallow GPUs as intermediate steps for now
- if (remNode->type != GPU) {
- int i;
- for (i=0; i<nextNodeList.count; i++) if (nextNodeList.list[i] == remNode) break;
- if (i == nextNodeList.count) nextNodeList.list[nextNodeList.count++] = remNode;
- }
+ int i;
+ for (i=0; i<nextNodeList.count; i++) if (nextNodeList.list[i] == remNode) break;
+ if (i == nextNodeList.count) nextNodeList.list[nextNodeList.count++] = remNode;
}
}
}
@@ -217,7 +218,7 @@ ncclResult_t ncclGetLevel(int* level, const char* disableEnv, const char* levelE
if (l == -1) {
char* str = getenv(levelEnv);
if (str) {
- for (int i=0; i<PATH_NET; i++) {
+ for (int i=0; i<=PATH_SYS; i++) {
if (strcmp(str, topoPathTypeStr[i]) == 0) {
l = i;
break;
@@ -239,9 +240,10 @@ ncclResult_t ncclGetLevel(int* level, const char* disableEnv, const char* levelE
}
int ncclTopoUserP2pLevel = -1;
-ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read) {
+ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read, int* intermediateRank) {
*p2p = 0;
- *read = 0;
+ if (read) *read = 0;
+ if (intermediateRank) *intermediateRank = -1;
// Get GPUs from topology
int g1, g2;
@@ -251,7 +253,16 @@ ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_
// GPU not found, we can't use p2p.
return ncclSuccess;
}
+
+
+ // Set intermediate GPU rank, if routing through an intermediate GPU.
struct ncclTopoLinkList* path = gpu1->paths[GPU]+g2;
+ if (path->count == 2) {
+ struct ncclTopoNode* intermediateNode = path->list[0]->remNode;
+ if (intermediateNode->type == GPU && intermediateRank) {
+ *intermediateRank = intermediateNode->gpu.rank;
+ }
+ }
// In general, use P2P whenever we can.
int p2pLevel = PATH_SYS;
@@ -280,7 +291,7 @@ compare:
if (path->type == PATH_NVL) {
struct ncclTopoNode* gpu2 = system->nodes[GPU].nodes+g2;
// Enable P2P Read for Ampere/NVLink only
- if ((gpu1->gpu.cudaCompCap == gpu2->gpu.cudaCompCap) && (gpu1->gpu.cudaCompCap == 80)) *read = 1;
+ if (read && (gpu1->gpu.cudaCompCap == gpu2->gpu.cudaCompCap) && (gpu1->gpu.cudaCompCap == 80)) *read = 1;
}
return ncclSuccess;
@@ -355,8 +366,8 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclPeer
// Update path when we don't want to / can't use GPU Direct P2P
for (int p=0; p<system->nodes[GPU].count; p++) {
- int p2p, read;
- NCCLCHECK(ncclTopoCheckP2p(system, system->nodes[GPU].nodes[p].id, system->nodes[GPU].nodes[g].id, &p2p, &read));
+ int p2p;
+ NCCLCHECK(ncclTopoCheckP2p(system, system->nodes[GPU].nodes[p].id, system->nodes[GPU].nodes[g].id, &p2p, NULL, NULL));
if (p2p == 0) {
// Divert all traffic through the CPU
int cpu;
@@ -464,8 +475,7 @@ static ncclResult_t ncclTopoGetNchannels(struct ncclTopoSystem* system, int g /*
// Local rank
path = system->nodes[GPU].nodes[peer].paths[GPU]+g;
if (path->type == PATH_NVL) {
- int sm = system->nodes[GPU].nodes[g].gpu.cudaCompCap;
- double nvlWidth = sm < 70 ? PASCAL_NVLINK_WIDTH : VOLTA_NVLINK_WIDTH;
+ float nvlWidth = ncclTopoNVLinkSpeed(system->nodes[GPU].nodes[g].gpu.cudaCompCap);
*nChannels = 2*std::max(1, (int)(path->width / nvlWidth));
} else {
*nChannels = 2;
diff --git a/src/graph/rings.cc b/src/graph/rings.cc
index 5aacbb5..53130d1 100644
--- a/src/graph/rings.cc
+++ b/src/graph/rings.cc
@@ -21,7 +21,7 @@ void dumpLine(int* values, int nranks, const char* prefix) {
ncclResult_t ncclBuildRings(int nrings, int* rings, int rank, int nranks, int* prev, int* next) {
for (int r=0; r<nrings; r++) {
- char prefix[30];
+ char prefix[40];
/*sprintf(prefix, "[%d] Channel %d Prev : ", rank, r);
dumpLine(prev+r*nranks, nranks, prefix);
sprintf(prefix, "[%d] Channel %d Next : ", rank, r);
diff --git a/src/graph/search.cc b/src/graph/search.cc
index cb52921..57c66e7 100644
--- a/src/graph/search.cc
+++ b/src/graph/search.cc
@@ -22,8 +22,18 @@ static float getMaxWidth(struct ncclTopoSystem* system, struct ncclTopoNode* gpu
}
return maxWidth;
}
+static float getTotalWidth(struct ncclTopoSystem* system, struct ncclTopoNode* gpu) {
+ float nvlinkWidth = 0.0, pciWidth = 0.0;
+ for (int l=0; l<gpu->nlinks; l++) {
+ struct ncclTopoLink* link = gpu->links+l;
+ if (link->type == LINK_NVL) nvlinkWidth += link->width;
+ if (link->type == LINK_PCI) pciWidth = link->width;
+ }
+ return std::max(pciWidth, nvlinkWidth);
+}
ncclResult_t ncclTopoSearchInit(struct ncclTopoSystem* system) {
system->maxWidth = 0.0;
+ system->totalWidth = 0.0;
int inter = system->nodes[NET].count;
if (inter == 0 && system->nodes[GPU].count == 1) {
system->maxWidth = LOC_WIDTH;
@@ -32,6 +42,7 @@ ncclResult_t ncclTopoSearchInit(struct ncclTopoSystem* system) {
for (int g=0; g<system->nodes[GPU].count; g++) {
struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g;
system->maxWidth = std::max(system->maxWidth, getMaxWidth(system, gpu, inter ? NET : GPU));
+ system->totalWidth = std::max(system->totalWidth, getTotalWidth(system, gpu));
}
return ncclSuccess;
}
@@ -290,7 +301,7 @@ ncclResult_t ncclTopoCompareGraphs(struct ncclTopoGraph* graph, struct ncclTopoG
return ncclSuccess;
}
// 3. Less hops (but not at the price of going cross NICs)
- if (graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1;
+ if (graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1;
return ncclSuccess;
}
@@ -326,11 +337,26 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo
struct ncclTopoNode* net = system->nodes[NET].nodes+n;
if (graph->pattern == NCCL_TOPO_PATTERN_TREE && net->id != startNet->id) continue; // Trees are symmetric
if (graph->crossNic != 1 && (net->net.asic != startNet->net.asic || net->net.port != startNet->net.port)) continue;
+
+ // Balanced Tree : count half of the bandwidth on first two GPUs
+ int nextBackToNet = -1;
+ float speedInterSave = graph->speedInter;
+ if (graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) {
+ // Count half of the bandwidth on each of the first two GPUs
+ if (step == 0) nextBackToNet = 1;
+ else if (net->id != graph->inter[graph->nChannels*2+1]) continue;
+ graph->speedInter /= 2;
+ }
+
NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, g, NET, n, 1, &net));
+ graph->speedInter = speedInterSave;
if (net) {
graph->inter[graph->nChannels*2+1] = net->id;
- NCCLCHECK(ncclTopoSearchRecGpu(system, graph, saveGraph, gpu, step, -1, backToFirstRank, forcedOrder, time));
+ NCCLCHECK(ncclTopoSearchRecGpu(system, graph, saveGraph, gpu, step, nextBackToNet, backToFirstRank, forcedOrder, time));
+
+ if (graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->speedInter /= 2;
NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, g, NET, n, -1, &net));
+ graph->speedInter = speedInterSave;
}
}
}
@@ -460,13 +486,12 @@ ncclResult_t ncclTopoSearchRecNet(struct ncclTopoSystem* system, struct ncclTopo
ncclResult_t ncclTopoSearchParams(struct ncclTopoSystem* system, int pattern, int* backToNet, int* backToFirstRank) {
if (system->nodes[NET].count) {
if (pattern == NCCL_TOPO_PATTERN_RING) *backToNet = system->nodes[GPU].count-1;
- else if (pattern == NCCL_TOPO_PATTERN_TREE) *backToNet = 0;
- else *backToNet = 1;
- if (pattern == NCCL_TOPO_PATTERN_SPLIT_TREE_LOOP) *backToFirstRank = system->nodes[GPU].count-1;
- else *backToFirstRank = -1;
+ else if (pattern == NCCL_TOPO_PATTERN_SPLIT_TREE) *backToNet = 1;
+ else *backToNet = 0;
+ *backToFirstRank = -1;
} else {
*backToNet = -1;
- if (pattern == NCCL_TOPO_PATTERN_RING || pattern == NCCL_TOPO_PATTERN_SPLIT_TREE_LOOP) *backToFirstRank = system->nodes[GPU].count-1;
+ if (pattern == NCCL_TOPO_PATTERN_RING) *backToFirstRank = system->nodes[GPU].count-1;
else *backToFirstRank = -1;
}
return ncclSuccess;
@@ -503,7 +528,7 @@ ncclResult_t ncclTopoSearchRec(struct ncclTopoSystem* system, struct ncclTopoGra
/* User defined graph from XML file */
/************************************/
-struct kvDict kvDictLinkType[] = { { "SYS", PATH_SYS }, { "PHB", PATH_PHB }, { "PIX", PATH_PIX }, { "PXB", PATH_PXB }, { "NVL", PATH_NVL }, { "LOC", PATH_LOC }, { NULL, 0 } };
+struct kvDict kvDictLinkType[] = { { "SYS", PATH_SYS }, { "PHB", PATH_PHB }, { "PIX", PATH_PIX }, { "PXB", PATH_PXB }, { "NVL", PATH_NVL }, { "NVB", PATH_NVB}, { "LOC", PATH_LOC }, { NULL, 0 } };
ncclResult_t ncclTopoGetChannelFromXml(struct ncclXmlNode *xmlChannel, int c, struct ncclTopoSystem* system, struct ncclTopoGraph* graph) {
int ngpus = system->nodes[GPU].count;
int* inter = graph->inter+2*c;
@@ -623,7 +648,7 @@ ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs
return ncclSuccess;
}
-float speedArray[] = { 42.0, 24.0, 21.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
+float speedArray[] = { 42.0, 30.0, 24.0, 21.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12 };
#define NSPEEDS (sizeof(speedArray)/sizeof(float))
ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph) {
@@ -651,11 +676,15 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph
if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE;
+ // SPLIT_TREE works better on older archs.
+ int ccMin;
+ NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL));
+ if (ccMin < 80 && graph->pattern == NCCL_TOPO_PATTERN_BALANCED_TREE) graph->pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
+
struct ncclTopoGraph tmpGraph;
memcpy(&tmpGraph, graph, sizeof(struct ncclTopoGraph));
// First try crossnic, then decrease speed and finally increase speedIntra.
- tmpGraph.pattern = graph->pattern;
int pass = 1;
int speedIndex = 0;
while (speedArray[speedIndex] > system->maxWidth && speedIndex < NSPEEDS-1) speedIndex++;
@@ -670,7 +699,7 @@ search:
NCCLCHECK(ncclTopoSearchRec(system, &tmpGraph, graph, &time));
#if 0
- printf("Pattern %d, crossNic %d, Speed %g/%g, type %d/%d, channels %d-%d sameChannels %d -> nChannels %dx%g/%g %s\n", tmpGraph.pattern, tmpGraph.crossNic, tmpGraph.speedInter, tmpGraph.speedIntra, tmpGraph.typeInter, tmpGraph.typeIntra, tmpGraph.minChannels, tmpGraph.maxChannels, tmpGraph.sameChannels, graph->nChannels, graph->speedInter, graph->speedIntra, time == 0 ? "TIMEOUT" : "");
+ printf("Pattern %d, crossNic %d, Speed %g/%g, type %d/%d, channels %d-%d sameChannels %d -> nChannels %dx%g/%g %s\n", tmpGraph.pattern, tmpGraph.crossNic, tmpGraph.speedInter, tmpGraph.speedIntra, tmpGraph.typeInter, tmpGraph.typeIntra, tmpGraph.minChannels, tmpGraph.maxChannels, tmpGraph.sameChannels, graph->nChannels, graph->speedInter, graph->speedIntra, time == 0 ? "TIMEOUT" : time == -1 ? "PERFECT" : "");
for (int c=0; c<graph->nChannels; c++) {
printf("%2d : ", c);
for (int g=0; g<ngpus; g++) {
@@ -680,7 +709,8 @@ search:
}
#endif
// Optimal solution, stop here
- if (graph->nChannels == graph->maxChannels && graph->speedInter == system->maxWidth) goto done;
+ if (time == -1) goto done;
+ if (graph->nChannels*graph->speedInter >= system->totalWidth) goto done;
if (pass == 1) {
// First pass, we don't have a solution yet ; try other options
@@ -694,7 +724,7 @@ search:
if (time != -1) globalTimeout += time;
else globalTimeout = NCCL_SEARCH_GLOBAL_TIMEOUT;
- if (globalTimeout < 0) goto done;
+ if (globalTimeout < 0 && graph->nChannels) goto done;
int maxTypeIntra = system->nodes[NET].count > 0 ? tmpGraph.typeInter : PATH_SYS;
if (tmpGraph.typeIntra < maxTypeIntra && (graph->nChannels == 0 || tmpGraph.typeIntra < graph->typeIntra)) {
@@ -709,10 +739,6 @@ search:
tmpGraph.typeInter = PATH_PIX;
// Try a simpler tree
- if (tmpGraph.pattern == NCCL_TOPO_PATTERN_SPLIT_TREE_LOOP) {
- tmpGraph.pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
- goto search;
- }
if (tmpGraph.pattern == NCCL_TOPO_PATTERN_SPLIT_TREE) {
tmpGraph.pattern = NCCL_TOPO_PATTERN_TREE;
goto search;
diff --git a/src/graph/topo.cc b/src/graph/topo.cc
index ed4bd23..3e395c5 100644
--- a/src/graph/topo.cc
+++ b/src/graph/topo.cc
@@ -20,8 +20,8 @@
#define BUSID_REDUCED_SIZE (sizeof("0000:00"))
const char* topoNodeTypeStr[] = { "GPU", "PCI", "NVS", "CPU", "NIC", "NET" };
-const char* topoLinkTypeStr[] = { "LOC", "NVL", "PCI", "", "", "SYS", "NET" };
-const char* topoPathTypeStr[] = { "LOC", "NVL", "PIX", "PXB", "PHB", "SYS", "NET" };
+const char* topoLinkTypeStr[] = { "LOC", "NVL", "", "PCI", "", "", "SYS", "NET" };
+const char* topoPathTypeStr[] = { "LOC", "NVL", "NVB", "PIX", "PXB", "PHB", "SYS" };
/******************************************************************/
/******************* Graph Creation Functions *********************/
@@ -215,7 +215,7 @@ static ncclResult_t ncclTopoPrintRec(struct ncclTopoNode* node, struct ncclTopoN
}
ncclResult_t ncclTopoPrint(struct ncclTopoSystem* s) {
- INFO(NCCL_GRAPH, "=== System : maxWidth %2.1f ===", s->maxWidth);
+ INFO(NCCL_GRAPH, "=== System : maxWidth %2.1f totalWidth %2.1f ===", s->maxWidth, s->totalWidth);
char line[1024];
for (int n=0; n<s->nodes[CPU].count; n++) NCCLCHECK(ncclTopoPrintRec(s->nodes[CPU].nodes+n, NULL, line, 0));
INFO(NCCL_GRAPH, "==========================================");
@@ -441,7 +441,7 @@ ncclResult_t ncclTopoAddNvLinks(struct ncclXmlNode* node, struct ncclTopoSystem*
}
}
if (remote) {
- int nvlSpeed = gpu->gpu.cudaCompCap == 60 ? PASCAL_NVLINK_WIDTH : VOLTA_NVLINK_WIDTH;
+ float nvlSpeed = ncclTopoNVLinkSpeed(gpu->gpu.cudaCompCap);
NCCLCHECK(ncclTopoConnectNodes(gpu, remote, LINK_NVL, count*nvlSpeed));
if (remote->type != GPU) {
NCCLCHECK(ncclTopoConnectNodes(remote, gpu, LINK_NVL, count*nvlSpeed));
@@ -521,6 +521,7 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
struct ncclXmlNode* node;
NCCLCHECK(ncclTopoFillGpu(xml, busId, &node));
if (node == NULL) continue;
+ NCCLCHECK(xmlSetAttrInt(node, "keep", 1));
NCCLCHECK(xmlSetAttrInt(node, "rank", r));
NCCLCHECK(xmlInitAttrInt(node, "gdr", comm->peerInfo[r].gdrSupport));
}
@@ -535,6 +536,7 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
NCCLCHECK(collNetGetProperties(n, &props));
struct ncclXmlNode* netNode;
NCCLCHECK(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode));
+ NCCLCHECK(xmlSetAttrInt(netNode, "keep", 1));
NCCLCHECK(xmlSetAttrInt(netNode, "dev", n));
NCCLCHECK(xmlInitAttrInt(netNode, "speed", props.speed));
NCCLCHECK(xmlInitAttrInt(netNode, "port", props.port));
@@ -552,6 +554,7 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
NCCLCHECK(ncclNetGetProperties(n, &props));
struct ncclXmlNode* netNode;
NCCLCHECK(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode));
+ NCCLCHECK(xmlSetAttrInt(netNode, "keep", 1));
NCCLCHECK(xmlSetAttrInt(netNode, "dev", n));
NCCLCHECK(xmlInitAttrInt(netNode, "speed", props.speed));
NCCLCHECK(xmlInitAttrInt(netNode, "port", props.port));
@@ -560,6 +563,9 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy
NCCLCHECK(xmlInitAttrInt(netNode, "gdr", props.ptrSupport & NCCL_PTR_CUDA ? 1 : 0));
}
+ // Remove XML branches which don't have a node with keep="1" (typically when importing a topology)
+ NCCLCHECK(ncclTopoTrimXml(xml));
+
xmlTopoFile = getenv("NCCL_TOPO_DUMP_FILE");
if (xmlTopoFile && comm->rank == ncclParamTopoDumpFileRank()) {
INFO(NCCL_ENV, "NCCL_TOPO_DUMP_FILE set by environment to %s", xmlTopoFile);
@@ -668,3 +674,21 @@ ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank) {
}
return ncclSuccess;
}
+
+ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count) {
+ *count = system->nodes[NET].count;
+ return ncclSuccess;
+}
+
+ncclResult_t ncclTopoGetCompCap(struct ncclTopoSystem* system, int* ccMin, int* ccMax) {
+ if (system->nodes[GPU].count == 0) return ncclInternalError;
+ int min, max;
+ min = max = system->nodes[GPU].nodes[0].gpu.cudaCompCap;
+ for (int g=1; g<system->nodes[GPU].count; g++) {
+ min = std::min(min, system->nodes[GPU].nodes[g].gpu.cudaCompCap);
+ max = std::max(max, system->nodes[GPU].nodes[g].gpu.cudaCompCap);
+ }
+ if (ccMin) *ccMin = min;
+ if (ccMax) *ccMax = max;
+ return ncclSuccess;
+}
diff --git a/src/graph/topo.h b/src/graph/topo.h
index 950cff8..a12bb2d 100644
--- a/src/graph/topo.h
+++ b/src/graph/topo.h
@@ -12,8 +12,10 @@
#include <sched.h>
#define LOC_WIDTH 5000.0
-#define PASCAL_NVLINK_WIDTH 18.0
-#define VOLTA_NVLINK_WIDTH 21.0
+#define SM60_NVLINK_WIDTH 18.0
+#define SM70_NVLINK_WIDTH 21.0
+#define SM80_NVLINK_WIDTH 21.0
+#define SM86_NVLINK_WIDTH 12.0
#define PCI_WIDTH 12.0 // PCI Gen3 x16
#define QPI_WIDTH 6.0
#define SKL_QPI_WIDTH 9.0
@@ -38,20 +40,21 @@ extern const char* topoNodeTypeStr[];
// We want link types and path types to match as much as possible
#define LINK_LOC 0
#define LINK_NVL 1
-#define LINK_PCI 2
-// Skipping 3 for PATH_PXB
-// Skipping 4 for PATH_PHB
-#define LINK_SYS 5
-#define LINK_NET 6
+// Skipping 2 for PATH_NVB
+#define LINK_PCI 3
+// Skipping 4 for PATH_PXB
+// Skipping 5 for PATH_PHB
+#define LINK_SYS 6
+#define LINK_NET 7
extern const char* topoLinkTypeStr[];
#define PATH_LOC 0
#define PATH_NVL 1
-#define PATH_PIX 2
-#define PATH_PXB 3
-#define PATH_PHB 4
-#define PATH_SYS 5
-#define PATH_NET 6
+#define PATH_NVB 2
+#define PATH_PIX 3
+#define PATH_PXB 4
+#define PATH_PHB 5
+#define PATH_SYS 6
extern const char* topoPathTypeStr[];
struct ncclTopoNode;
@@ -117,6 +120,7 @@ struct ncclTopoNodeSet {
struct ncclTopoSystem {
struct ncclTopoNodeSet nodes[NCCL_TOPO_NODE_TYPES];
float maxWidth;
+ float totalWidth;
};
ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id);
@@ -132,6 +136,8 @@ ncclResult_t ncclTopoGetSystemFromXml(struct ncclXml* xml, struct ncclTopoSystem
ncclResult_t ncclTopoGetGraphFromXml(struct ncclXmlNode *xmlGraphs, struct ncclTopoSystem* system, struct ncclTopoGraph* graph, int* nChannels);
ncclResult_t ncclTopoGetXmlFromGraphs(int ngraphs, struct ncclTopoGraph** graphs, struct ncclTopoSystem* system, struct ncclXml *xml);
+ncclResult_t ncclTopoGetCompCap(struct ncclTopoSystem* system, int* ccMin, int* ccMax);
+
static ncclResult_t ncclTopoIdToIndex(struct ncclTopoSystem* system, int type, int64_t id, int* index) {
*index = -1;
for (int i=0; i<system->nodes[type].count; i++) {
@@ -154,4 +160,13 @@ static ncclResult_t ncclTopoRankToIndex(struct ncclTopoSystem* system, int rank,
return ncclInternalError;
}
+// Returns NVLink speed in GB/s
+static float ncclTopoNVLinkSpeed(int cudaCompCap) {
+ return
+ cudaCompCap == 86 ? SM86_NVLINK_WIDTH :
+ cudaCompCap >= 80 ? SM80_NVLINK_WIDTH :
+ cudaCompCap >= 70 ? SM70_NVLINK_WIDTH :
+ cudaCompCap >= 60 ? SM60_NVLINK_WIDTH :
+ SM80_NVLINK_WIDTH;
+}
#endif
diff --git a/src/graph/trees.cc b/src/graph/trees.cc
index 722e61b..8e1e2ae 100644
--- a/src/graph/trees.cc
+++ b/src/graph/trees.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -28,7 +28,7 @@
* / \ / \ / \ \
* 1 3 5 7 9 11 13
*/
-ncclResult_t ncclGetBtree(int nranks, int rank, int* u, int* d0, int* d1) {
+ncclResult_t ncclGetBtree(int nranks, int rank, int* u, int* d0, int* d1, int* parentChildType) {
int up, down0, down1;
int bit;
for (bit=1; bit<nranks; bit<<=1) {
@@ -37,13 +37,16 @@ ncclResult_t ncclGetBtree(int nranks, int rank, int* u, int* d0, int* d1) {
if (rank == 0) {
*u = -1;
- *d0 = nranks > 1 ? bit >> 1 : -1;
- *d1 = -1;
+ *d0 = -1;
+ // Child rank is > 0 so it has to be our child 1, not 0.
+ *d1 = nranks > 1 ? bit >> 1 : -1;
return ncclSuccess;
}
up = (rank ^ bit) | (bit << 1);
+ // if smaller than the parent, we are his first child, otherwise we're his second
if (up >= nranks) up = (rank ^ bit);
+ *parentChildType = (rank < up) ? 0 : 1;
*u = up;
int lowbit = bit >> 1;
@@ -62,42 +65,42 @@ ncclResult_t ncclGetBtree(int nranks, int rank, int* u, int* d0, int* d1) {
}
/* Build a double binary tree. Take the previous tree for the first tree.
- * For the second tree, we use a mirror tree (if nranks is odd)
+ * For the second tree, we use a mirror tree (if nranks is even)
*
- * 8---------0---------5
- * ______/ \______ _____/ \______
- * 4 12 1 9
- * / \ / \ / \
- * 2 6 10 3 7 10
- * / \ / \ / \ / \ / \ / \
- * 1 3 5 7 9 11 2 4 6 8 11 12
+ * 0---------------8 3----------------11
+ * ______/ \ / \______
+ * 4 \ / 7
+ * / \ \ / / \
+ * 2 6 10 1 5 9
+ * / \ / \ / \ / \ / \ / \
+ * 1 3 5 7 9 11 0 2 4 6 8 10
*
- * or shift it by one rank (if nranks is even)
+ * or shift it by one rank (if nranks is odd).
*
- * 8---------0--------------9
- * ______/ \ ______/ \
- * 4 \ 5 \
- * / \ \ / \ \
- * 2 6 10 3 7 11
- * / \ / \ / \ / \ / \ / \
- * 1 3 5 7 9 11 2 4 6 8 10 1
+ * 0---------------8 1---------------9
+ * ______/ \______ ______/ \______
+ * 4 12 5 0
+ * / \ / / \ /
+ * 2 6 10 3 7 11
+ * / \ / \ / \ / \ / \ / \
+ * 1 3 5 7 9 11 2 4 6 8 10 12
*/
-ncclResult_t ncclGetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* s1, int* d1_0, int* d1_1) {
+ncclResult_t ncclGetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* parentChildType0, int* s1, int* d1_0, int* d1_1, int* parentChildType1) {
// First tree ... use a btree
- ncclGetBtree(nranks, rank, s0, d0_0, d0_1);
+ ncclGetBtree(nranks, rank, s0, d0_0, d0_1, parentChildType0);
// Second tree ... mirror or shift
- if (nranks % 2 == 0) {
+ if (nranks % 2 == 1) {
// shift
int shiftrank = (rank-1+nranks) % nranks;
int u, d0, d1;
- ncclGetBtree(nranks, shiftrank, &u, &d0, &d1);
+ ncclGetBtree(nranks, shiftrank, &u, &d0, &d1, parentChildType1);
*s1 = u == -1 ? -1 : (u+1) % nranks;
*d1_0 = d0 == -1 ? -1 : (d0+1) % nranks;
*d1_1 = d1 == -1 ? -1 : (d1+1) % nranks;
} else {
// mirror
int u, d0, d1;
- ncclGetBtree(nranks, nranks-1-rank, &u, &d0, &d1);
+ ncclGetBtree(nranks, nranks-1-rank, &u, &d0, &d1, parentChildType1);
*s1 = u == -1 ? -1 : nranks-1-u;
*d1_0 = d0 == -1 ? -1 : nranks-1-d0;
*d1_1 = d1 == -1 ? -1 : nranks-1-d1;
diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc
index 62f50ef..42a4dde 100644
--- a/src/graph/tuning.cc
+++ b/src/graph/tuning.cc
@@ -62,77 +62,86 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 4.4,
// Tree/Simple is the latency a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
static const float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
{ /* NVLINK */
- { /* Tree (LL/LL128/Simple)*/ { .52, 1.2, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 4.0 } },
+ { /* Tree (LL/LL128/Simple)*/ { .52, 1.25, 28 }, /* Ring (LL/LL128/Simple)*/ { .47, 1.9, 3.4 }, /* CollNet (LL/LL128/Simple)*/ { .5, 1.2, 4.0 } },
/* PCI */
{ /* Tree (LL/LL128/Simple)*/ { 1.0, 1.9, 28 }, /* Ring (LL/LL128/Simple)*/ { 1.0, 2.5, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 1.0, 1.9, 5.5 } },
/* NET */
- { /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 50 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, /* CollNet (LL/LL128/Simple)*/ { 5.0, 5.0, 10.7 } }
+ { /* Tree (LL/LL128/Simple)*/ { 5.0, 8.5, 28 }, /* Ring (LL/LL128/Simple)*/ { 2.7, 4.0, 9.6 }, /* CollNet (LL/LL128/Simple)*/ { 5.0, 5.0, 10.7 } }
};
// LL128 max BW (per channel) for the different collectives
-// ncclCollBroadcast, ncclCollReduce, ncclCollAllGather, ncclCollReduceScatter, ncclCollAllReduce
-static const double ll128MaxBwPerCh[NCCL_NUM_FUNCTIONS] = { 18.8, 12.0, 18.3, 15.2, 16.7 };
+// ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce
+static const double ll128MaxBwPerCh[NCCL_NUM_FUNCTIONS] = { 18.8, 12.0, 18.3, 15.2, 16.9 };
+static const double llMaxBws[2][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0} };
+static const double perChMaxTreeBws[2][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 22.5, 16.0} };
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) {
- int simpleDefaultThreads = (ringGraph->speedIntra*ringGraph->nChannels <= PCI_WIDTH) ? 256 : NCCL_MAX_NTHREADS;
+ int simpleDefaultThreads = (ringGraph->speedIntra*ringGraph->nChannels <= PCI_WIDTH) ? 256 : NCCL_SIMPLE_MAX_NTHREADS;
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] =
- getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_MAX_NTHREADS, simpleDefaultThreads);
+ getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, simpleDefaultThreads);
comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = comm->maxThreads[NCCL_ALGO_COLLNET][NCCL_PROTO_SIMPLE] =
- getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_MAX_NTHREADS, NCCL_MAX_NTHREADS);
+ getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, NCCL_SIMPLE_MAX_NTHREADS);
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_COLLNET][NCCL_PROTO_LL] =
- getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_MAX_NTHREADS, NCCL_MAX_NTHREADS);
+ getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_LL_MAX_NTHREADS, NCCL_LL_MAX_NTHREADS);
comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_COLLNET][NCCL_PROTO_LL128] =
getNthreads("NCCL_LL128_NTHREADS", ncclParamLl128Nthreads(), NCCL_LL128_MAX_NTHREADS/4, NCCL_LL128_MAX_NTHREADS, NCCL_LL128_MAX_NTHREADS);
- if (comm->nRanks <= 1) return ncclSuccess;
+ int nNodes = comm->nNodes;
+ int nRanks = comm->nRanks;
+ if (nRanks <= 1) return ncclSuccess;
int compCap80 = minCompCap == 80 && maxCompCap == 80 ? 1 : 0;
- float ppn = (float)comm->nRanks / comm->nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
+ int cpuArch, cpuVendor, cpuModel;
+ NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel));
+ int index2 = nNodes <= 2 ? nNodes-1 : 2;
+ // LL: for single node, we look at GPU type; for multi-node, we look at CPU type
+ int index1 = nNodes == 1 ? compCap80 : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0;
+ double llMaxBw = llMaxBws[index1][index2];
+ double perChMaxTreeBw = perChMaxTreeBws[compCap80][index2];
+ float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
+
struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph };
int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS];
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) intraHw[a] = graphs[a]->typeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI;
- for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) hw[a] = comm->nNodes == 1 ? intraHw[a] : NCCL_HW_NET;
+ for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) hw[a] = nNodes == 1 ? intraHw[a] : NCCL_HW_NET;
for (int coll=0; coll<NCCL_NUM_FUNCTIONS; coll++) {
- int nsteps = coll == ncclCollAllReduce ? 2*(comm->nRanks-1) :
- coll == ncclCollReduceScatter || coll == ncclCollAllGather ? comm->nRanks-1 :
- comm->nRanks;
- int nInterSteps = coll == ncclCollAllReduce ? 2*(comm->nNodes-1) :
- coll == ncclCollReduceScatter || coll == ncclCollAllGather ? comm->nNodes-1 :
- comm->nNodes;
+ int nsteps = coll == ncclFuncAllReduce ? 2*(nRanks-1) :
+ coll == ncclFuncReduceScatter || coll == ncclFuncAllGather ? nRanks-1 :
+ nRanks;
+ int nInterSteps = coll == ncclFuncAllReduce ? 2*(nNodes-1) :
+ coll == ncclFuncReduceScatter || coll == ncclFuncAllGather ? nNodes-1 :
+ nNodes;
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) {
- if (coll != ncclCollAllReduce && a != NCCL_ALGO_RING) continue;
+ if (coll != ncclFuncAllReduce && a != NCCL_ALGO_RING) continue;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- float speed = comm->nNodes <= 2 || a == NCCL_ALGO_COLLNET ? graphs[a]->speedIntra : graphs[a]->speedInter;
+ float speed = nNodes <= 2 || a == NCCL_ALGO_COLLNET ? graphs[a]->speedIntra : graphs[a]->speedInter;
float busBw = graphs[a]->nChannels * speed;
// Various model refinements
if (compCap80) busBw = std::min(busBw, 235.0f);
- if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) busBw *= (comm->nNodes > 1 || coll == ncclCollAllReduce || coll == ncclCollReduce) ? 1.0/4.0 : 1.0/3.0;
+ if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); }
if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
- double maxTreeBw = comm->nNodes > 2 ?
- compCap80 && p == NCCL_PROTO_LL128 ? 105.0 : 80.0 :
- compCap80 && p == NCCL_PROTO_LL128 ? 130.0 : 110.0;
- if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.9, maxTreeBw);
- if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw *= 1.0/3.8;
- if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (comm->nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels*7.0/9.0);
+ if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw);
+ if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw);
+ if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 0.915 /*120.0/128.0*/), ll128MaxBwPerCh[coll]*graphs[a]->nChannels);
if (a == NCCL_ALGO_COLLNET) busBw *= .9;
if (a == NCCL_ALGO_COLLNET && p == NCCL_PROTO_LL) busBw *= 1.0/6.0; // Take into account that GDR read is disabled on both sides
if (a == NCCL_ALGO_COLLNET && p == NCCL_PROTO_LL128) busBw = 0; // CollNet does not support LL128
// Convert bus BW to algorithm BW
- float ratio = (a != NCCL_ALGO_RING) ? .5 : (1.0 * comm->nRanks) / nsteps;
+ float ratio = (a != NCCL_ALGO_RING) ? .5 : (1.0 * nRanks) / nsteps;
comm->bandwidths[coll][a][p] = busBw * ratio;
comm->latencies[coll][a][p] = baseLat[a][p];
float intraLat = hwLat[intraHw[a]][a][p];
float interLat = hwLat[NCCL_HW_NET][a][p];
- if (comm->nNodes > 1 && p == NCCL_PROTO_LL) intraLat *= 1.8;
+ if (nNodes > 1 && p == NCCL_PROTO_LL) intraLat *= 1.8;
if (a == NCCL_ALGO_RING) {
float lat = hwLat[hw[a]][a][p];
- if ((coll == ncclCollReduce || coll == ncclCollBroadcast)) {
+ if ((coll == ncclFuncReduce || coll == ncclFuncBroadcast)) {
if (ringGraph->sameChannels) {
comm->latencies[coll][a][p] += lat;
} else {
@@ -144,10 +153,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
}
} else if (a == NCCL_ALGO_TREE) {
comm->latencies[coll][a][p] +=
- 2 * ((comm->nRanks/comm->nNodes-1) * intraLat + log2i(comm->nNodes) * interLat);
+ 2 * ((nRanks/nNodes-1) * intraLat + log2i(nNodes) * interLat);
} else {
comm->latencies[coll][a][p] +=
- 2 * (comm->nRanks/comm->nNodes-1) * intraLat + interLat;
+ 2 * (nRanks/nNodes-1) * intraLat + interLat;
}
}
}
@@ -168,6 +177,15 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
INFO(NCCL_ENV, "NCCL_ALGO set by environment to %s", algoStr);
NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable));
}
+ // Disable CollNet if it is not supported
+ if (comm->collNetSupport == 0) {
+ algoEnable[NCCL_ALGO_COLLNET] = 0;
+ // If user has hard set NCCL_ALGO=COLLNET, ignore it
+ if (algoEnable[NCCL_ALGO_RING] == 0 && algoEnable[NCCL_ALGO_TREE] == 0) {
+ algoEnable[NCCL_ALGO_RING] = algoEnable[NCCL_ALGO_TREE] = 1;
+ if (comm->rank == 0) WARN("CollNet is not supported or fails to initialize, ignoring NCCL_ALGO=COLLNET");
+ }
+ }
for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
int pEnable = protoEnable[p];
@@ -178,7 +196,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
}
if (pEnable == 0) comm->bandwidths[c][a][p] = 0;
// Only disable algo for Allreduce since others only have one
- if (c == ncclCollAllReduce && algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0;
+ if (c == ncclFuncAllReduce && algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0;
}
if (comm->rank == 0) {
@@ -214,7 +232,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
comm->threadThresholds[a][NCCL_PROTO_LL128] = NCCL_LL128_THREAD_THRESHOLD;
comm->threadThresholds[a][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_THREAD_THRESHOLD;
}
- comm->threadThresholds[NCCL_ALGO_RING][NCCL_PROTO_LL] *= comm->nRanks;
+ comm->threadThresholds[NCCL_ALGO_RING][NCCL_PROTO_LL] *= nRanks;
// Override defaults with user env
char* str = getenv("NCCL_THREAD_THRESHOLDS");
@@ -243,11 +261,11 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
}
// Trees are not perfectly sticking to the model for medium sizes. Applying a static correction
-// factor is not ideal but works quite well. Powers of two, 64 B to 128MB.
-static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = {
- { 1.0, 1.0, 1.0, 1.0, .9, .8, .7, .7, .7, .7, .6, .5, .4, .4, .5, .6, .7, .8, .9, 1.0, 1.0, 1.0 },
- { 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .5, .6, .6, .7, .7, .8, .9, .9, 1.0 },
- { .9, .9, .9, .9, .9, .9, .9, .8, .7, .6, .6, .5, .5, .5, .5, .5, .5, .6, .6, .7, .8, .9 }
+// factor is not ideal but works quite well. Powers of two, 64 B to 256MB.
+static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][23] = {
+ { 1.0, 1.0, 1.0, 1.0, .9, .8, .7, .7, .7, .7, .6, .5, .4, .4, .5, .6, .7, .8, .9, 1.0, 1.0, 1.0, 1.0 },
+ { 1.0, 1.0, 1.0, 1.0, 1.0, .9, .8, .8, .8, .7, .6, .6, .6, .5, .6, .6, .7, .7, .8, .9, .9, .92, .92 },
+ { .9, .9, .9, .9, .9, .9, .9, .8, .7, .6, .6, .5, .5, .5, .5, .6, .7, .8, .7, .7, .8, .9, .9 }
};
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, float* time) {
@@ -257,9 +275,10 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto
*time = -1.0; return ncclSuccess;
}
int logSize = log2i(info->nBytes>>6);
- if (algorithm == NCCL_ALGO_TREE && logSize < 22) bw *= treeCorrectionFactor[protocol][logSize];
+ if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize];
+ if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels;
if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1
- && info->coll == ncclCollAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring
+ && info->coll == ncclFuncAllReduce && info->nBytes >= info->comm->nRanks/16.0*65536) lat *= 1.9; // Plateau effect of ring
*time = lat + (info->nBytes) / (1000 * bw);
return ncclSuccess;
}
diff --git a/src/graph/xml.cc b/src/graph/xml.cc
index cc91b92..b2232c2 100644
--- a/src/graph/xml.cc
+++ b/src/graph/xml.cc
@@ -559,7 +559,6 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
NCCLCHECK(xmlGetAttrIndex(gpuNode, "dev", &index));
if (index == -1) {
if (nvmlDev == NULL) {
- WARN("No NVML, trying to use CUDA instead");
const char* busId;
NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId));
if (busId == NULL || cudaDeviceGetByPCIBusId(&dev, busId) != cudaSuccess) dev = -1;
@@ -647,6 +646,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
char* path;
NCCLCHECK(getPciPath(busId, &path));
NCCLCHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass"));
+ free(path);
}
}
}
@@ -658,10 +658,14 @@ ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct nccl
struct ncclXmlNode* node;
NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node));
NCCLCHECK(ncclTopoGetXmlFromSys(node, xml));
- NCCLCHECK(wrapNvmlSymbols());
- NCCLCHECK(wrapNvmlInit());
- nvmlDevice_t nvmlDev;
- if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev) != ncclSuccess) nvmlDev = NULL;
+ nvmlDevice_t nvmlDev = NULL;
+ static int nvmlInit = 0;
+ if (nvmlInit == 0) {
+ nvmlInit = (wrapNvmlSymbols() != ncclSuccess || wrapNvmlInit() != ncclSuccess) ? 2 : 1;
+ }
+ if (nvmlInit == 1) {
+ if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev) != ncclSuccess) nvmlDev = NULL;
+ }
NCCLCHECK(ncclTopoGetXmlFromGpu(node, nvmlDev, xml, gpuNode));
return ncclSuccess;
}
@@ -704,12 +708,8 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha
for (offset=strlen(pciSysPath)-1; pciSysPath[offset] != '/'; offset--);
char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
strcpy(busId, pciSysPath+offset+1);
- NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", busId));
- if (parent == NULL) {
- NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent));
- NCCLCHECK(xmlSetAttr(parent, "busid", busId));
- NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml));
- }
+ NCCLCHECK(ncclTopoGetPciNode(xml, busId, &parent));
+ NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml));
} else {
// Virtual NIC, no PCI device, attach to first CPU
NCCLCHECK(xmlFindTag(xml, "cpu", &parent));
@@ -728,6 +728,28 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha
return ncclSuccess;
}
+ncclResult_t ncclTopoTrimXmlRec(struct ncclXmlNode* node) {
+ const char* str;
+ NCCLCHECK(xmlGetAttr(node, "keep", &str));
+ if (str && strcmp(str, "1") == 0) {
+ NCCLCHECK(xmlUnsetAttr(node, "keep"));
+ } else {
+ // Copy nSubs and subs as they could change as we trim recursively.
+ struct ncclXmlNode* subs[MAX_SUBS];
+ int nSubs = node->nSubs;
+ memcpy(subs, node->subs, node->nSubs*sizeof(struct ncclXmlNode*));
+ for (int s=0; s<nSubs; s++) {
+ NCCLCHECK(ncclTopoTrimXmlRec(subs[s]));
+ }
+ if (node->nSubs == 0) NCCLCHECK(xmlRemoveNode(node));
+ }
+ return ncclSuccess;
+}
+ncclResult_t ncclTopoTrimXml(struct ncclXml* xml) {
+ NCCLCHECK(ncclTopoTrimXmlRec(xml->nodes));
+ return ncclSuccess;
+}
+
/**************************************************/
/* Parser rules for the user-defined graph search */
/**************************************************/
diff --git a/src/graph/xml.h b/src/graph/xml.h
index 22e016f..9a617af 100644
--- a/src/graph/xml.h
+++ b/src/graph/xml.h
@@ -8,7 +8,7 @@
#define XML_H_
// A few constraints to make the implementation easy
-#define MAX_STR_LEN 256
+#define MAX_STR_LEN 255
#define MAX_ATTR_COUNT 16
#define MAX_SUBS 32
#define MAX_NODES 1024
@@ -19,10 +19,10 @@
#define NODE_TYPE_SINGLE 3
struct ncclXmlNode {
- char name[MAX_STR_LEN];
+ char name[MAX_STR_LEN+1];
struct {
- char key[MAX_STR_LEN];
- char value[MAX_STR_LEN];
+ char key[MAX_STR_LEN+1];
+ char value[MAX_STR_LEN+1];
} attrs[MAX_ATTR_COUNT+1]; // Need an extra one to consume extra params
int nAttrs;
int type;
@@ -47,6 +47,9 @@ ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXm
ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode);
ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode);
+/* Remove unneeded parts */
+ncclResult_t ncclTopoTrimXml(struct ncclXml* xml);
+
/**************/
/* XML Struct */
/* Functions */
@@ -56,7 +59,7 @@ static ncclResult_t xmlGetAttrIndex(struct ncclXmlNode* node, const char* attrNa
*index = -1;
const int nAttrs = node->nAttrs;
for (int a=0; a<nAttrs; a++) {
- if (strncmp(node->attrs[a].key, attrName, MAX_STR_LEN-1) == 0) {
+ if (strncmp(node->attrs[a].key, attrName, MAX_STR_LEN) == 0) {
*index = a;
return ncclSuccess;
}
@@ -127,8 +130,10 @@ static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, c
if (index == -1) {
index = node->nAttrs++;
strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ node->attrs[index].key[MAX_STR_LEN] = '\0';
}
strncpy(node->attrs[index].value, value, MAX_STR_LEN);
+ node->attrs[index].value[MAX_STR_LEN] = '\0';
return ncclSuccess;
}
@@ -138,8 +143,10 @@ static ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName
if (index == -1) {
index = node->nAttrs++;
strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ node->attrs[index].key[MAX_STR_LEN] = '\0';
}
snprintf(node->attrs[index].value, MAX_STR_LEN, "%d", value);
+ node->attrs[index].value[MAX_STR_LEN] = '\0';
return ncclSuccess;
}
@@ -149,8 +156,22 @@ static ncclResult_t xmlSetAttrFloat(struct ncclXmlNode* node, const char* attrNa
if (index == -1) {
index = node->nAttrs++;
strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ node->attrs[index].key[MAX_STR_LEN] = '\0';
}
snprintf(node->attrs[index].value, MAX_STR_LEN, "%g", value);
+ node->attrs[index].value[MAX_STR_LEN] = '\0';
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlUnsetAttr(struct ncclXmlNode* node, const char* attrName) {
+ int index;
+ NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+ if (index == -1) return ncclSuccess;
+ for (int i=index+1; i<node->nAttrs; i++) {
+ strcpy(node->attrs[i-1].key, node->attrs[i].key);
+ strcpy(node->attrs[i-1].value, node->attrs[i].value);
+ }
+ node->nAttrs--;
return ncclSuccess;
}
@@ -199,6 +220,20 @@ static ncclResult_t xmlAddNode(struct ncclXml* xml, struct ncclXmlNode* parent,
s->parent = parent;
if (parent) parent->subs[parent->nSubs++] = s;
strncpy(s->name, subName, MAX_STR_LEN);
+ s->name[MAX_STR_LEN] = '\0';
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlRemoveNode(struct ncclXmlNode* node) {
+ node->type = NODE_TYPE_NONE;
+ struct ncclXmlNode* parent = node->parent;
+ if (parent == NULL) return ncclSuccess;
+ int shift = 0;
+ for (int s=0; s<parent->nSubs; s++) {
+ if (parent->subs[s] == node) shift = 1;
+ else if (shift) parent->subs[s-1] = parent->subs[s];
+ }
+ parent->nSubs--;
return ncclSuccess;
}
diff --git a/src/group.cc b/src/group.cc
index 5ce4901..78a74b6 100644
--- a/src/group.cc
+++ b/src/group.cc
@@ -34,7 +34,6 @@ struct ncclInitArgs {
};
struct ncclCollArgs {
ncclComm_t comm;
- int connect;
};
enum ncclAsyncFuncType {
@@ -109,6 +108,7 @@ ncclResult_t ncclAsyncColl(ncclComm_t comm) {
NCCL_API(ncclResult_t, ncclGroupStart);
ncclResult_t ncclGroupStart() {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
if (ncclGroupMode == 0) {
memset(ncclGroupArgs, 0, sizeof(struct ncclAsyncArgs)*MAX_ASYNC_OPS);
}
@@ -117,7 +117,7 @@ ncclResult_t ncclGroupStart() {
}
static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int channelId, ssize_t recvbytes, void* recvbuff, ssize_t sendbytes, const void* sendbuff) {
- struct ncclInfo info = { ncclCollSendRecv, "SendRecv",
+ struct ncclInfo info = { ncclFuncSendRecv, "SendRecv",
sendbuff, recvbuff, (size_t)std::max<ssize_t>(sendbytes,recvbytes), ncclInt8, ncclSum, -1, comm, comm->userStream, /* Args */
1, 1 };
info.delta = delta;
@@ -125,26 +125,32 @@ static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int chann
info.sendbytes = sendbytes;
info.recvbytes = recvbytes;
if (delta == 0 && sendbytes != recvbytes) return ncclInvalidUsage;
- NCCLCHECK(ncclSaveKernel(&info));
+ NCCLCHECK(ncclSaveP2pKernel(&info));
return ncclSuccess;
}
void* ncclAsyncThreadPreconnect(void* args_) {
struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_;
- CUDACHECKTHREAD(cudaSetDevice(args->coll.comm->cudaDev));
- for (int c=0; c<args->coll.comm->p2pnChannels; c++) {
- struct ncclComm* comm = args->coll.comm;
- struct ncclChannel* channel = comm->channels+c;
- struct ncclP2PConnect* connect = &comm->p2plist.connect;
- NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, channel, connect->nrecv[c], connect->recv+c*comm->nRanks, connect->nsend[c], connect->send+c*comm->nRanks));
- connect->nrecv[c] = 0;
- connect->nsend[c] = 0;
- }
+ struct ncclComm* comm = args->coll.comm;
+ CUDACHECKTHREAD(cudaSetDevice(comm->cudaDev));
+ NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL));
return args;
}
+static size_t getP2pChunkSize(size_t totalSize, int minChannels, int maxChannels, size_t minSize, size_t maxSize) {
+ size_t size = std::max(minSize, DIVUP(totalSize, minChannels));
+ int nChannels = minChannels;
+ while (size > maxSize && nChannels <= maxChannels/2) {
+ nChannels *= 2;
+ size = DIVUP(totalSize, nChannels);
+ }
+ ALIGN_SIZE(size, minSize);
+ return size;
+}
+
NCCL_API(ncclResult_t, ncclGroupEnd);
ncclResult_t ncclGroupEnd() {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
if (ncclGroupMode == 0) {
WARN("ncclGroupEnd: not in a group call.");
return ncclInvalidUsage;
@@ -185,29 +191,21 @@ ncclResult_t ncclGroupEnd() {
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
- if (args->funcType == ASYNC_FUNC_COLL) {
- struct ncclP2Plist* p2plist = &args->coll.comm->p2plist;
- if (p2plist->count != 0) {
- struct ncclComm* comm = args->coll.comm;
- args->coll.connect = 0;
- for (int c=0; c<comm->p2pnChannels; c++)
- args->coll.connect += comm->p2plist.connect.nsend[c] + comm->p2plist.connect.nrecv[c];
- if (args->coll.connect) {
- pthread_create(ncclGroupThreads+i, NULL, ncclAsyncThreadPreconnect, args);
- }
- }
+ if (args->funcType == ASYNC_FUNC_COLL && args->coll.comm->connect) {
+ pthread_create(ncclGroupThreads+i, NULL, ncclAsyncThreadPreconnect, args);
}
}
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
- if (args->funcType == ASYNC_FUNC_COLL && (args->coll.connect)) {
+ if (args->funcType == ASYNC_FUNC_COLL && args->coll.comm->connect) {
int err = pthread_join(ncclGroupThreads[i], NULL);
if (err != 0) {
WARN("Error waiting for pthread_join : %s\n", strerror(errno));
return ncclSystemError;
}
NCCLCHECKGOTO(args->ret, ret, end);
+ args->coll.comm->connect = 0;
}
}
@@ -217,48 +215,83 @@ ncclResult_t ncclGroupEnd() {
struct ncclComm* comm = args->coll.comm;
int rank = comm->rank;
int nRanks = comm->nRanks;
- struct ncclP2Plist* p2plist = &args->coll.comm->p2plist;
- if (p2plist->count) {
- for (int delta=0; delta<nRanks; delta++) {
+ struct ncclP2Plist* p2pSends = comm->p2pSends;
+ struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
+
+ // Compute how much to split operations
+ // Natural step size matching buffer steps.
+ ssize_t stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
+ // Try to use all channels
+ int nChannelsMax = comm->p2pnChannelsPerPeer;
+ int nChannelsMin = nChannelsMax;
+ // Try to use all channels, but one channel per operation.
+ while (nChannelsMin*comm->nRanks > comm->p2pnChannels && nChannelsMin > 1) nChannelsMin /= 2;
+ // Avoid overloading channels with 8+ operations as we loose the sync warp, hence a bit of bandwidth.
+ while (nChannelsMax*comm->nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2;
+
+ while (comm->p2pSendCount > 0 || comm->p2pRecvCount > 0) {
+ // schedule delta 0, +1, -1, +2, -2, ...
+ // also make sure we don't do 0 twice, nor +n/2 and -n/2 if n is even.
+ for (int d=0; d<=nRanks/4; d++) {
+ int deltas[4] = { d, (nRanks-d)%nRanks, nRanks/2-d, nRanks-(nRanks/2-d) };
+ int index = 0;
+ int delta = deltas[index];
+sched_delta:
uint32_t from = (rank+nRanks-delta)%nRanks;
uint32_t to = (rank+delta)%nRanks;
+ struct ncclP2Pinfo* recv = p2pRecvs[from].head;
+ struct ncclP2Pinfo* send = p2pSends[to].head;
+ if (recv != NULL || send != NULL) {
+ ssize_t totRecvBytes = -1, totSendBytes = -1;
+ if (recv != NULL) totRecvBytes = recv->nbytes;
+ if (send != NULL) totSendBytes = send->nbytes;
+ ssize_t recvChunkSize = getP2pChunkSize(totRecvBytes, nChannelsMin, nChannelsMax, stepSize, SENDRECV_SLICEFACTOR*stepSize);
+ ssize_t sendChunkSize = getP2pChunkSize(totSendBytes, nChannelsMin, nChannelsMax, stepSize, SENDRECV_SLICEFACTOR*stepSize);
- // Compute how much to split operations
- // Natural step size matching buffer steps.
- ssize_t stepSize = 4*comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
- // Split each operation on p2pnChannelsPerPeer max.
- ssize_t recvChunkSize = DIVUP(p2plist->peerlist[from].recvbytes, comm->p2pnChannelsPerPeer);
- ssize_t sendChunkSize = DIVUP(p2plist->peerlist[to].sendbytes, comm->p2pnChannelsPerPeer);
- recvChunkSize = std::max((ssize_t)1, DIVUP(recvChunkSize, stepSize)) * stepSize;
- sendChunkSize = std::max((ssize_t)1, DIVUP(sendChunkSize, stepSize)) * stepSize;
-
- ssize_t sendOffset = 0;
- ssize_t recvOffset = 0;
- int remaining = 1;
- int chunk = 0;
- while (remaining) {
- int channelId = (delta+comm->p2pChannels[chunk%comm->p2pnChannelsPerPeer]) % comm->p2pnChannels;
- remaining = 0;
- ssize_t recvbytes = p2plist->peerlist[from].recvbytes-recvOffset;
- ssize_t sendbytes = p2plist->peerlist[to].sendbytes-sendOffset;
- if (recvbytes > recvChunkSize) { remaining = 1; recvbytes = recvChunkSize; } else p2plist->peerlist[from].recvbytes = -1;
- if (sendbytes > sendChunkSize) { remaining = 1; sendbytes = sendChunkSize; } else p2plist->peerlist[to].sendbytes = -1;
- if (sendbytes >= 0 || recvbytes >= 0) {
- NCCLCHECKGOTO(scheduleSendRecv(comm, delta, channelId,
- recvbytes, ((char*)(p2plist->peerlist[from].recvbuff)) + recvOffset,
- sendbytes, ((const char*)(p2plist->peerlist[to].sendbuff)) + sendOffset), ret, end);
+ ssize_t sendOffset = 0;
+ ssize_t recvOffset = 0;
+ int sendRemaining = 1, recvRemaining = 1;
+ int chunk = 0;
+ do {
+ int channelId = (delta+comm->p2pChannels[chunk%comm->p2pnChannelsPerPeer]) % comm->p2pnChannels;
+ ssize_t recvbytes = totRecvBytes-recvOffset;
+ ssize_t sendbytes = totSendBytes-sendOffset;
+ if (recvbytes > recvChunkSize) { recvbytes = recvChunkSize; } else { recvRemaining = 0; }
+ if (sendbytes > sendChunkSize) { sendbytes = sendChunkSize; } else { sendRemaining = 0; }
+ if (sendbytes >= 0 || recvbytes >= 0) {
+ NCCLCHECKGOTO(scheduleSendRecv(comm, delta, channelId,
+ recvbytes, recv ? ((char*)(recv->buff)) + recvOffset : NULL,
+ sendbytes, send ? ((const char*)(send->buff)) + sendOffset : NULL), ret, group_cleanup);
+ }
+ recvOffset += recvChunkSize;
+ sendOffset += sendChunkSize;
+ chunk++;
+ } while (sendRemaining || recvRemaining);
+ if (recv) {
+ NCCLCHECKGOTO(dequeueP2pInfo(p2pRecvs+from), ret, group_cleanup);
+ comm->p2pRecvCount--;
+ }
+ if (send) {
+ NCCLCHECKGOTO(dequeueP2pInfo(p2pSends+to), ret, group_cleanup);
+ comm->p2pSendCount--;
}
- recvOffset += recvChunkSize;
- sendOffset += sendChunkSize;
- chunk++;
+ }
+ index++;
+ if (index == 1 && deltas[1] == deltas[0]) index++;
+ if (index == 2 && deltas[2] == deltas[0]) index++;
+ if (index == 3 && deltas[3] == deltas[2]) index++;
+ if (index == 3 && deltas[3] == deltas[1]) index++;
+ if (index < 4) {
+ delta = deltas[index];
+ goto sched_delta;
}
}
- p2plist->count = 0;
}
}
}
/* Collectives are done in three steps :
+ * 0. Save kernels previously enqueued. Compute channel, algo, proto, etc.
* 1. Barrier Check In. Only the last call may call cudaLaunchKernel[cooperative]
* 2. Barrier Wait. No CUDA call is permitted
* 3. Enqueue Events. CUDA event wait/enqueue.
@@ -270,6 +303,13 @@ ncclResult_t ncclGroupEnd() {
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
if (args->funcType == ASYNC_FUNC_COLL) {
+ ncclComm_t comm = args->coll.comm;
+ NCCLCHECKGOTO(ncclSaveCommKernels(comm), ret, group_cleanup);
+ }
+ }
+ for (int i=0; i<ncclGroupIndex; i++) {
+ struct ncclAsyncArgs* args = ncclGroupArgs+i;
+ if (args->funcType == ASYNC_FUNC_COLL) {
if (args->coll.comm->userStream == NULL)
CUDACHECKGOTO(cudaSetDevice(args->coll.comm->cudaDev), ret, end);
NCCLCHECKGOTO(ncclBarrierEnqueue(args->coll.comm), ret, end);
@@ -303,32 +343,28 @@ group_cleanup:
*args->init.newcomm = NULL;
} else {
struct ncclComm* comm = args->coll.comm;
- for (int c=0; c<comm->p2pnChannels; c++) {
- struct ncclChannel* channel = comm->channels+c;
- for (int i=0; i<channel->collCount; i++) {
- channel->collectives[(channel->collStart + i)%NCCL_MAX_OPS].active = 0;
+ // Reset aggregation counters
+ comm->asyncOpCount = 0;
+ comm->asyncTotalSize = 0;
+ // Dequeue p2p lists
+ if (comm->p2pSendCount > 0 || comm->p2pRecvCount > 0) {
+ struct ncclP2Plist* p2pSends = comm->p2pSends;
+ struct ncclP2Plist* p2pRecvs = comm->p2pRecvs;
+ for (int peer=0; peer<comm->nRanks; peer++) {
+ while (p2pSends[peer].head != NULL) dequeueP2pInfo(p2pSends+peer);
+ while (p2pRecvs[peer].head != NULL) dequeueP2pInfo(p2pRecvs+peer);
}
- channel->collFifoTail = channel->collStart;
- channel->collCount = 0;
+ comm->p2pSendCount = comm->p2pRecvCount = 0;
}
- /* Cancel all proxy ops : mark them as ncclProxyOpNone and they should be freed later on */
+ /* Free all proxy ops in state->nextOps */
struct ncclProxyState* state = &comm->proxyState;
- struct ncclProxyArgs *op, *start;
- pthread_mutex_lock(&state->mutex);
- op = start = state->ops;
- while (op) {
- if (op->opCount >= comm->lastOpCount) op->state = ncclProxyOpNone;
- struct ncclProxyArgs* peerOp = op->nextPeer;
- while (peerOp) {
- if (peerOp->opCount >= comm->lastOpCount) peerOp->state = ncclProxyOpNone;
- peerOp = peerOp->nextPeer;
- }
- op = op->next;
- if (op == start) break;
+ pthread_mutex_lock(&state->poolMutex);
+ for (struct ncclProxyArgs *op = state->nextOps; op; op = op->next) {
+ op->next = state->pool;
+ state->pool = op;
}
- comm->opCount = comm->lastOpCount;
- pthread_cond_signal(&state->cond);
- pthread_mutex_unlock(&state->mutex);
+ pthread_mutex_unlock(&state->poolMutex);
+ state->nextOps = NULL;
comm->myParams->gridDim.x = comm->myParams->blockDim.x = 0;
comm->userStreamSet = false;
diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h
index a7d6be9..dbe4320 100644
--- a/src/include/bootstrap.h
+++ b/src/include/bootstrap.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -16,6 +16,8 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
ncclResult_t bootstrapSend(void* commState, int peer, void* data, int size);
ncclResult_t bootstrapRecv(void* commState, int peer, void* data, int size);
+ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr);
+ncclResult_t bootstrapRemFree(int id, int rank, void* commState);
ncclResult_t bootstrapClose(void* commState);
ncclResult_t bootstrapAbort(void* commState);
#endif
diff --git a/src/include/coll_net.h b/src/include/coll_net.h
index 3278560..0d17b76 100644
--- a/src/include/coll_net.h
+++ b/src/include/coll_net.h
@@ -24,7 +24,7 @@ static ncclResult_t collNetRegMr(void* comm, void* data, int size, int type, voi
static ncclResult_t collNetDeregMr(void* comm, void* mhandle) { NCCLCHECK(ncclCollNet->deregMr(comm, mhandle)); return ncclSuccess; }
static ncclResult_t collNetIallreduce(void* collComm, void* sendData, void* recvData, int count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) {
NCCLCHECK(ncclCollNet->iallreduce(collComm, sendData, recvData, count, dataType, redOp, sendMhandle, recvMhandle, request)); return ncclSuccess; }
-static ncclResult_t collNetFlush(void* collComm, void* data, int size, void* mhandle) { NCCLCHECK(ncclCollNet->flush(collComm, data, size, mhandle)); return ncclSuccess; }
+static ncclResult_t collNetIflush(void* collComm, void* data, int size, void* mhandle, void** request) { NCCLCHECK(ncclCollNet->iflush(collComm, data, size, mhandle, request)); return ncclSuccess; }
static ncclResult_t collNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclCollNet->test(request, done, size)); return ncclSuccess; }
static ncclResult_t collNetCloseColl(void* collComm) { NCCLCHECK(ncclCollNet->closeColl(collComm)); return ncclSuccess; }
static ncclResult_t collNetCloseListen(void* listenComm) { NCCLCHECK(ncclCollNet->closeListen(listenComm)); return ncclSuccess; }
diff --git a/src/include/collectives.h b/src/include/collectives.h
index f854364..9b9022e 100644
--- a/src/include/collectives.h
+++ b/src/include/collectives.h
@@ -8,55 +8,58 @@
#define NCCL_COLLECTIVES_H_
#define FUNC_INDEX_P2P 0
-#define FUNC_INDEX(coll, redop, dtype, al, pr) (1+(((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
+#define FUNC_INDEX(func, redop, ncclType, al, pr) (1+(((((func)*ncclNumOps + (redop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
-#define NCCL_COLL_NAME(coll, op, dtype) \
- coll##_##op##_##dtype
+#define NCCL_FUNC_NAME(func, algo, proto, redop, type) \
+ ncclFunction_##func##_##algo##_##proto##_##redop##_##type
-#define NCCL_KERN_NAME(coll, op, dtype) \
- coll##Kernel_##op##_##dtype
+#define NCCL_KERN_NAME(func, algo, proto, redop, type) \
+ ncclKernel_##func##_##algo##_##proto##_##redop##_##type
+
+#define NCCL_IMPL_NAME(func, algo, proto) \
+ nccl##func##algo##proto
/* Declare all collective operations */
-#define DECL_COLL5(coll, op, dtype) \
- extern __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args); \
- extern __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl c); \
+#define DECL5(func, algo, proto, redop, type) \
+ extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
+ extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem c); \
-#define DECL_COLL4(coll, op, dtype) \
- DECL_COLL5(coll, op, dtype) \
- DECL_COLL5(coll##LL, op, dtype) \
- DECL_COLL5(coll##LL128, op, dtype)
+#define DECL4(func, algo, redop, type) \
+ DECL5(func, algo, SIMPLE, redop, type) \
+ DECL5(func, algo, LL, redop, type) \
+ DECL5(func, algo, LL128, redop, type)
-#define DECL_COLL3(coll, op, dtype) \
- DECL_COLL4(coll##Ring, op, dtype) \
- DECL_COLL4(coll##Tree, op, dtype) \
- DECL_COLL4(coll##CollNet, op, dtype)
+#define DECL3(func, redop, type) \
+ DECL4(func, RING, redop, type) \
+ DECL4(func, TREE, redop, type) \
+ DECL4(func, COLLNET, redop, type)
-#define DECL_COLL2(coll, op) \
- DECL_COLL3(coll, op, i8) \
- DECL_COLL3(coll, op, u8) \
- DECL_COLL3(coll, op, i32) \
- DECL_COLL3(coll, op, u32) \
- DECL_COLL3(coll, op, i64) \
- DECL_COLL3(coll, op, u64) \
- DECL_COLL3(coll, op, f16) \
- DECL_COLL3(coll, op, f32) \
- DECL_COLL3(coll, op, f64)
+#define DECL2(func, redop) \
+ DECL3(func, redop, int8_t) \
+ DECL3(func, redop, uint8_t) \
+ DECL3(func, redop, int32_t) \
+ DECL3(func, redop, uint32_t) \
+ DECL3(func, redop, int64_t) \
+ DECL3(func, redop, uint64_t) \
+ DECL3(func, redop, half) \
+ DECL3(func, redop, float) \
+ DECL3(func, redop, double)
-#define DECL_COLL(coll) \
- DECL_COLL2(coll, sum) \
- DECL_COLL2(coll, prod) \
- DECL_COLL2(coll, min) \
- DECL_COLL2(coll, max)
+#define DECL(func) \
+ DECL2(func, Sum) \
+ DECL2(func, Prod) \
+ DECL2(func, Min) \
+ DECL2(func, Max)
-#define DECL_ALL_COLLS \
- DECL_COLL2(ncclBroadcast, copy) \
- DECL_COLL(ncclReduce) \
- DECL_COLL2(ncclAllGather, copy) \
- DECL_COLL(ncclReduceScatter) \
- DECL_COLL(ncclAllReduce) \
- DECL_COLL5(ncclSendRecv,copy,i8) \
+#define DECL_ALL \
+ DECL2(Broadcast, Sum) \
+ DECL(Reduce) \
+ DECL2(AllGather, Sum) \
+ DECL(ReduceScatter) \
+ DECL(AllReduce) \
+ DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) \
-DECL_ALL_COLLS
+DECL_ALL
// CHUNKSIZE must be a multiple of SLICESIZE
#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
diff --git a/src/include/comm.h b/src/include/comm.h
index 8a44747..56116e0 100644
--- a/src/include/comm.h
+++ b/src/include/comm.h
@@ -48,8 +48,8 @@ struct ncclRecvMem {
struct {
uint64_t tail;
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
- char pad2[CACHE_LINE_SIZE-sizeof(uint64_t)];
int sizesFifo[NCCL_STEPS];
+ void* ptrsFifo[NCCL_STEPS];
};
char pad4[MEM_ALIGN];
};
@@ -63,6 +63,10 @@ struct ncclComm {
struct ncclTopoSystem* topo;
void* bootstrap;
+ // Bitmasks for ncclTransportP2pSetup
+ int connect;
+ uint32_t* connectSend;
+ uint32_t* connectRecv;
int rank; // my rank in the communicator
int nRanks; // number of GPUs in communicator
@@ -127,7 +131,7 @@ struct ncclComm {
int* intraCudaDevs;
int* intraCGMode; // Whether we can use CUDA9 CGMD or not
int* intraCC; // Only to check all have the same ComputeCap and disable CGMode if not
- struct ncclColl args;
+ struct ncclWorkElem args;
void* argsptr;
// Global proxy thread
@@ -136,8 +140,17 @@ struct ncclComm {
// Whether this communicator uses collNet
int collNetSupport;
+
+ // Store info of async operations
+ struct ncclInfo* asyncOps;
+ int asyncOpCount;
+ size_t asyncTotalSize;
+
//list of async p2p operation queued in a group semantics
- struct ncclP2Plist p2plist;
+ struct ncclP2Plist* p2pSends;
+ struct ncclP2Plist* p2pRecvs;
+ int p2pSendCount;
+ int p2pRecvCount;
};
#endif
diff --git a/src/include/core.h b/src/include/core.h
index 0435d9b..2283134 100644
--- a/src/include/core.h
+++ b/src/include/core.h
@@ -55,5 +55,6 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) {
#include "alloc.h"
#include "utils.h"
#include "param.h"
+#include "nvtx.h"
#endif // end include guard
diff --git a/src/include/cpuset.h b/src/include/cpuset.h
index 40c1594..ec55cbc 100644
--- a/src/include/cpuset.h
+++ b/src/include/cpuset.h
@@ -19,7 +19,7 @@ static int hexToInt(char c) {
#define CPU_SET_N_U32 (sizeof(cpu_set_t)/sizeof(uint32_t))
-ncclResult_t ncclStrToCpuset(const char* str, cpu_set_t* mask) {
+static ncclResult_t ncclStrToCpuset(const char* str, cpu_set_t* mask) {
uint32_t cpumasks[CPU_SET_N_U32];
int m = CPU_SET_N_U32-1;
cpumasks[m] = 0;
@@ -42,7 +42,7 @@ ncclResult_t ncclStrToCpuset(const char* str, cpu_set_t* mask) {
return ncclSuccess;
}
-ncclResult_t ncclCpusetToStr(cpu_set_t* mask, char* str) {
+static ncclResult_t ncclCpusetToStr(cpu_set_t* mask, char* str) {
int c = 0;
uint8_t* m8 = (uint8_t*)mask;
for (int o=sizeof(cpu_set_t)-1; o>=0; o--) {
diff --git a/src/include/devcomm.h b/src/include/devcomm.h
index 2be7bba..9870117 100644
--- a/src/include/devcomm.h
+++ b/src/include/devcomm.h
@@ -12,7 +12,7 @@
#include <stdint.h>
#define NCCL_NUM_FUNCTIONS 5 // SendRecv not included for now
-typedef enum { ncclCollBroadcast, ncclCollReduce, ncclCollAllGather, ncclCollReduceScatter, ncclCollAllReduce, ncclCollSendRecv} ncclFunc_t;
+typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv} ncclFunc_t;
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS];
#define NCCL_NUM_ALGORITHMS 3 // Tree/Ring/CollNet
@@ -47,8 +47,9 @@ union ncclLLFifoLine {
#define WARP_SIZE 32
#define MAXCHANNELS 32
-#define NCCL_MAX_NTHREADS 512
-#define NCCL_LL_MAX_NTHREADS NCCL_MAX_NTHREADS
+#define NCCL_MAX_NTHREADS 640
+#define NCCL_SIMPLE_MAX_NTHREADS 512
+#define NCCL_LL_MAX_NTHREADS 512
#define NCCL_LL_LINES_PER_THREAD 8
#ifdef TEST_LL_CLEANUP
#define NCCL_LL_CLEAN_MASK 0x078 // Set to 0x100 to disable cleanup
@@ -85,9 +86,11 @@ struct ncclConnInfo {
uint64_t *head; // Local for send, remote for recv
int direct; // Direct communication
+ int shared; // Buffers are shared
void **ptrExchange; // Pointer exchange for direct communication
- int *fifo; // Size fifo for proxy
+ int *sizesFifo; // Sizes fifo from GPU to proxy
+ void* *ptrsFifo; // Buffer fifo from proxy to GPU
uint64_t step; // Keep where we are
uint64_t llLastCleaning;
@@ -129,60 +132,52 @@ struct ncclPeer {
struct ncclDevComm;
-/* CollectiveArgs + ncclColl are to be a power of two, currently 64 bytes, */
+#define NCCL_MAX_WORK_ELEMENTS 8
+#define NCCL_MAX_GROUPS (NCCL_MAX_WORK_ELEMENTS*2)
+
+/* ncclWork is to be a power of two, currently 8x64 bytes, */
/* to make sure reads to host from the CUDA kernel are aligned. */
-/* Make sure to adjust padding at the end of ncclColl. */
-struct CollectiveArgs {
+/* Make sure to adjust padding at the end of ncclWorkElem. */
+struct ncclWorkElem {
+ // Header
struct ncclDevComm* comm;
+ uint16_t nThreads;
+ uint16_t funcIndex;
+ uint16_t index;
+ uint16_t active;
- // local and remote input, output, and buffer
const void * sendbuff;
void * recvbuff;
- // Op-specific fields. Make sure the common part stays the
- // same on all structs of the union
+ // Op-specific fields.
union {
struct {
- uint16_t nThreads;
- } common;
- struct {
- uint16_t nThreads;
- uint8_t bid;
- uint8_t nChannels;
- uint32_t root;
size_t count;
size_t lastChunkSize;
+ uint32_t root;
+ uint8_t bid;
+ uint8_t nChannels;
} coll;
struct {
- uint16_t nThreads;
- uint16_t unused;
- int32_t delta;
size_t sendCount;
size_t recvCount;
+ int32_t delta;
+ uint16_t nThreads;
} p2p;
+ uint64_t align[4];
};
};
-struct ncclColl {
- union {
- struct {
- struct CollectiveArgs args;
- uint16_t funcIndex;
- uint16_t nextIndex;
- uint8_t active;
- };
- int data[0x10];
- };
+struct ncclWork {
+ struct ncclWorkElem elems[NCCL_MAX_WORK_ELEMENTS];
};
-static_assert(sizeof(struct ncclColl) == (0x10*sizeof(int)), "ncclColl must have a pow2 size");
+static_assert(sizeof(struct ncclWorkElem) == (0x10*sizeof(int)), "ncclWorkElem must have a pow2 size");
struct ncclChannel {
union {
struct {
struct ncclRing ring;
- struct ncclTree treeUp;
- struct ncclTree treeDn;
- struct ncclTree collTreeUp;
- struct ncclTree collTreeDn;
+ struct ncclTree tree;
+ struct ncclTree collTree;
int id;
@@ -191,11 +186,9 @@ struct ncclChannel {
struct ncclPeer* devPeers;
// Operation list for aggregation
- struct ncclColl* collectives;
- int collStart;
- int collCount;
- int collFifoHead; // Only used by GPU
- int collFifoTail; // Only used by CPU
+ struct ncclWork* workFifo;
+ int workCount;
+ uint64_t workFifoTail; // Only used by CPU
};
int data[0x80];
};
diff --git a/src/include/enqueue.h b/src/include/enqueue.h
index a7e6e50..2c2ab1f 100644
--- a/src/include/enqueue.h
+++ b/src/include/enqueue.h
@@ -19,5 +19,7 @@ ncclResult_t ncclBarrierEnqueue(struct ncclComm* comm);
ncclResult_t ncclBarrierEnqueueWait(struct ncclComm* comm);
ncclResult_t ncclEnqueueEvents(struct ncclComm* comm);
ncclResult_t ncclSaveKernel(struct ncclInfo* info);
+ncclResult_t ncclSaveP2pKernel(struct ncclInfo* info);
+ncclResult_t ncclSaveCommKernels(struct ncclComm* comm);
#endif // End include guard
diff --git a/src/include/graph.h b/src/include/graph.h
index 70117d5..a4dba5c 100644
--- a/src/include/graph.h
+++ b/src/include/graph.h
@@ -29,7 +29,7 @@ ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm);
// Query topology
ncclResult_t ncclTopoGetNetDev(struct ncclTopoSystem* system, int rank, struct ncclTopoGraph* graph, int channelId, int* net);
-ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read);
+ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read, int* intermediateRank);
ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int netDev, int read, int* useGdr);
// Set CPU affinity
@@ -43,15 +43,16 @@ ncclResult_t ncclTopoSetAffinity(struct ncclTopoSystem* system, int rank);
#define NCCL_TOPO_CPU_TYPE_BDW 1
#define NCCL_TOPO_CPU_TYPE_SKL 2
ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vendor, int* model);
+ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count);
#define NCCL_TOPO_MAX_NODES 256
// Init search. Needs to be done before calling ncclTopoCompute
ncclResult_t ncclTopoSearchInit(struct ncclTopoSystem* system);
-#define NCCL_TOPO_PATTERN_SPLIT_TREE_LOOP 1 // Split tree (send/recv from different ranks) always flowing in the same direction
-#define NCCL_TOPO_PATTERN_SPLIT_TREE 2 // Split tree (send/recv from different ranks) flowing in both directions
-#define NCCL_TOPO_PATTERN_TREE 3 // Simple tree (send/recv from same rank) flowing in both directions
+#define NCCL_TOPO_PATTERN_BALANCED_TREE 1 // Spread NIC traffic between two GPUs (Tree parent + one child on first GPU, second child on second GPU)
+#define NCCL_TOPO_PATTERN_SPLIT_TREE 2 // Spread NIC traffic between two GPUs (Tree parent on first GPU, tree children on the second GPU)
+#define NCCL_TOPO_PATTERN_TREE 3 // All NIC traffic going to/from the same GPU
#define NCCL_TOPO_PATTERN_RING 4 // Ring
struct ncclTopoGraph {
// Input / output
@@ -82,17 +83,16 @@ struct ncclTopoRanks {
int ringSend[MAXCHANNELS];
int ringPrev[MAXCHANNELS];
int ringNext[MAXCHANNELS];
- int treeUpRecv[MAXCHANNELS];
- int treeUpSend[MAXCHANNELS];
- int treeDnRecv[MAXCHANNELS];
- int treeDnSend[MAXCHANNELS];
+ int treeToParent[MAXCHANNELS];
+ int treeToChild0[MAXCHANNELS];
+ int treeToChild1[MAXCHANNELS];
};
ncclResult_t ncclTopoPreset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph,
struct ncclTopoRanks* topoRanks);
-ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks,
+ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns,
struct ncclTopoRanks** allTopoRanks, int* rings);
ncclResult_t ncclTopoConnectCollNet(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, int rank);
diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h
index fd19f81..8c016dc 100644
--- a/src/include/nccl_net.h
+++ b/src/include/nccl_net.h
@@ -15,6 +15,9 @@
#define NCCL_PTR_HOST 0x1
#define NCCL_PTR_CUDA 0x2
+// Maximum number of requests per comm object
+#define NCCL_NET_MAX_REQUESTS 8
+
typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel;
typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALL=~0} ncclDebugLogSubSys;
@@ -29,9 +32,9 @@ typedef struct {
int speed; // Port speed in Mbps.
int port; // Port number.
int maxComms; // Maximum number of comms we can create
-}ncclNetProperties_v3_t;
+}ncclNetProperties_v4_t;
-typedef ncclNetProperties_v3_t ncclNetProperties_t;
+typedef ncclNetProperties_v4_t ncclNetProperties_t;
typedef struct {
// Name of the network (mainly for logs)
@@ -41,7 +44,7 @@ typedef struct {
// Return the number of adapters.
ncclResult_t (*devices)(int* ndev);
// Get various device properties.
- ncclResult_t (*getProperties)(int dev, ncclNetProperties_v3_t* props);
+ ncclResult_t (*getProperties)(int dev, ncclNetProperties_v4_t* props);
// Create a receiving object and provide a handle to connect to it. The
// handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
// between ranks to create a connection.
@@ -62,7 +65,7 @@ typedef struct {
ncclResult_t (*irecv)(void* recvComm, void* data, int size, void* mhandle, void** request);
// Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
// visible to the GPU
- ncclResult_t (*flush)(void* recvComm, void* data, int size, void* mhandle);
+ ncclResult_t (*iflush)(void* recvComm, void* data, int size, void* mhandle, void** request);
// Test whether a request is complete. If size is not NULL, it returns the
// number of bytes sent/received.
ncclResult_t (*test)(void* request, int* done, int* size);
@@ -70,11 +73,11 @@ typedef struct {
ncclResult_t (*closeSend)(void* sendComm);
ncclResult_t (*closeRecv)(void* recvComm);
ncclResult_t (*closeListen)(void* listenComm);
-} ncclNet_v3_t;
+} ncclNet_v4_t;
-typedef ncclNet_v3_t ncclNet_t;
+typedef ncclNet_v4_t ncclNet_t;
-#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v3
+#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v4
typedef struct {
// Name of the collective network (mainly for logs)
@@ -85,7 +88,7 @@ typedef struct {
// If ndev returns 0, all other functions might be set to NULL.
ncclResult_t (*devices)(int* ndev);
// Get various device properties.
- ncclResult_t (*getProperties)(int dev, ncclNetProperties_v3_t* props);
+ ncclResult_t (*getProperties)(int dev, ncclNetProperties_v4_t* props);
// Create a receiving object and provide a handle to connect to it. The
// handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
// between ranks to create connections.
@@ -105,17 +108,17 @@ typedef struct {
ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request);
// Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
// visible to the GPU
- ncclResult_t (*flush)(void* collComm, void* data, int size, void* mhandle);
+ ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request);
// Test whether a request is complete. If size is not NULL, it returns the
// number of bytes sent/received.
ncclResult_t (*test)(void* request, int* done, int* size);
// Close and free collective comm objects
ncclResult_t (*closeColl)(void* collComm);
ncclResult_t (*closeListen)(void* listenComm);
-} ncclCollNet_v3_t;
+} ncclCollNet_v4_t;
-typedef ncclCollNet_v3_t ncclCollNet_t;
+typedef ncclCollNet_v4_t ncclCollNet_t;
-#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v3
+#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v4
#endif // end include guard
diff --git a/src/include/net.h b/src/include/net.h
index bc81965..244215e 100644
--- a/src/include/net.h
+++ b/src/include/net.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -24,7 +24,7 @@ static ncclResult_t ncclNetRegMr(void* comm, void* data, int size, int type, voi
static ncclResult_t ncclNetDeregMr(void* comm, void* mhandle) { NCCLCHECK(ncclNet->deregMr(comm, mhandle)); return ncclSuccess; }
static ncclResult_t ncclNetIsend(void* sendComm, void* data, int size, void* mhandle, void** request) { NCCLCHECK(ncclNet->isend(sendComm, data, size, mhandle, request)); return ncclSuccess; }
static ncclResult_t ncclNetIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) { NCCLCHECK(ncclNet->irecv(recvComm, data, size, mhandle, request)); return ncclSuccess; }
-static ncclResult_t ncclNetFlush(void* recvComm, void* data, int size, void* mhandle) { NCCLCHECK(ncclNet->flush(recvComm, data, size, mhandle)); return ncclSuccess; }
+static ncclResult_t ncclNetIflush(void* recvComm, void* data, int size, void* mhandle, void** request) { NCCLCHECK(ncclNet->iflush(recvComm, data, size, mhandle, request)); return ncclSuccess; }
static ncclResult_t ncclNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclNet->test(request, done, size)); return ncclSuccess; }
static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeSend(sendComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
diff --git a/src/include/nvmlwrap.h b/src/include/nvmlwrap.h
index 01bbb7f..21ee82e 100644
--- a/src/include/nvmlwrap.h
+++ b/src/include/nvmlwrap.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -45,14 +45,6 @@ static ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index)
NVMLCHECK(nvmlDeviceGetIndex(device, index));
return ncclSuccess;
}
-static ncclResult_t wrapNvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t *device) {
- NVMLCHECK(nvmlDeviceGetHandleByIndex(index,device));
- return ncclSuccess;
-}
-static ncclResult_t wrapNvmlDeviceGetHandleByPciInfo(nvmlDevice_t device, nvmlPciInfo_t* pci) {
- NVMLCHECK(nvmlDeviceGetPciInfo(device, pci));
- return ncclSuccess;
-}
static ncclResult_t wrapNvmlDeviceGetNvLinkState(nvmlDevice_t device, unsigned int link, nvmlEnableState_t *isActive) {
NVMLCHECK(nvmlDeviceGetNvLinkState(device, link, isActive));
return ncclSuccess;
@@ -66,10 +58,6 @@ static ncclResult_t wrapNvmlDeviceGetNvLinkCapability(nvmlDevice_t device, unsig
NVMLCHECK(nvmlDeviceGetNvLinkCapability(device, link, capability, capResult));
return ncclSuccess;
}
-static ncclResult_t wrapNvmlDeviceGetMinorNumber(nvmlDevice_t device, unsigned int* minorNumber) {
- NVMLCHECK(nvmlDeviceGetMinorNumber(device, minorNumber));
- return ncclSuccess;
-}
static ncclResult_t wrapNvmlDeviceGetCudaComputeCapability(nvmlDevice_t device, int* major, int* minor) {
NVMLCHECK(nvmlDeviceGetCudaComputeCapability(device, major, minor));
return ncclSuccess;
@@ -150,12 +138,10 @@ ncclResult_t wrapNvmlShutdown(void);
ncclResult_t wrapNvmlDeviceGetHandleByPciBusId(const char* pciBusId, nvmlDevice_t* device);
ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index);
ncclResult_t wrapNvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t *device);
-ncclResult_t wrapNvmlDeviceGetPciInfo(nvmlDevice_t device, nvmlPciInfo_t* pci);
ncclResult_t wrapNvmlDeviceGetNvLinkState(nvmlDevice_t device, unsigned int link, nvmlEnableState_t *isActive);
ncclResult_t wrapNvmlDeviceGetNvLinkRemotePciInfo(nvmlDevice_t device, unsigned int link, nvmlPciInfo_t *pci);
ncclResult_t wrapNvmlDeviceGetNvLinkCapability(nvmlDevice_t device, unsigned int link,
nvmlNvLinkCapability_t capability, unsigned int *capResult);
-ncclResult_t wrapNvmlDeviceGetMinorNumber(nvmlDevice_t device, unsigned int* minorNumber);
ncclResult_t wrapNvmlDeviceGetCudaComputeCapability(nvmlDevice_t device, int* major, int* minor);
#endif // NVML_DIRECT
diff --git a/src/include/nvtx.h b/src/include/nvtx.h
new file mode 100644
index 0000000..7796126
--- /dev/null
+++ b/src/include/nvtx.h
@@ -0,0 +1,14 @@
+/*************************************************************************
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef NCCL_NVTX_H_
+#define NCCL_NVTX_H_
+
+#include "nvtx3.hpp"
+
+struct nccl_domain{static constexpr char const* name{"NCCL"};};
+
+#endif
diff --git a/src/include/nvtx3.hpp b/src/include/nvtx3.hpp
new file mode 100644
index 0000000..1e99373
--- /dev/null
+++ b/src/include/nvtx3.hpp
@@ -0,0 +1,2268 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Temporary helper #defines, #undef'ed at end of header */
+#define NVTX3_CPP_VERSION_MAJOR 1
+#define NVTX3_CPP_VERSION_MINOR 0
+
+/* This section handles the decision of whether to provide unversioned symbols.
+ * If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is #defined, unversioned symbols are
+ * not provided, and explicit-version symbols such as nvtx3::v1::thread_range
+ * and NVTX3_V1_FUNC_RANGE must be used. By default, the first #include of this
+ * header will define the unversioned symbols such as nvtx3::thread_range and
+ * NVTX3_FUNC_RANGE. Subsequently including a different major version of this
+ * header without #defining NVTX3_CPP_REQUIRE_EXPLICIT_VERSION triggers an error
+ * since the symbols would conflict. Subsequently including of a different
+ * minor version within the same major version is allowed. Functionality of
+ * minor versions is cumulative, regardless of include order.
+ *
+ * Since NVTX3_CPP_REQUIRE_EXPLICIT_VERSION allows all combinations of versions
+ * to coexist without problems within a translation unit, the recommended best
+ * practice for instrumenting header-based libraries with NVTX C++ Wrappers is
+ * is to #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before including nvtx3.hpp,
+ * #undef it afterward, and only use explicit-version symbols. This is not
+ * necessary in common cases, such as instrumenting a standalone application, or
+ * static/shared libraries in .cpp files or headers private to those projects.
+ */
+/* clang-format off */
+#if !defined(NVTX3_CPP_REQUIRE_EXPLICIT_VERSION)
+ /* Define macro used by all definitions in this header to indicate the
+ * unversioned symbols should be defined in addition to the versioned ones.
+ */
+ #define NVTX3_INLINE_THIS_VERSION
+
+ #if !defined(NVTX3_CPP_INLINED_VERSION_MAJOR)
+ /* First occurrence of this header in the translation unit. Define macros
+ * indicating which version shall be used for unversioned symbols.
+ */
+
+ /**
+ * @brief Semantic major version number for NVTX C++ wrappers of unversioned symbols
+ *
+ * Breaking changes may occur between major versions, and different major versions
+ * cannot provide unversioned symbols in the same translation unit (.cpp file).
+ *
+ * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined.
+ *
+ * Not to be confused with the version number of the NVTX core library.
+ */
+ #define NVTX3_CPP_INLINED_VERSION_MAJOR 1 // NVTX3_CPP_VERSION_MAJOR
+
+ /**
+ * @brief Semantic minor version number for NVTX C++ wrappers of unversioned symbols
+ *
+ * No breaking changes occur between minor versions -- minor version changes within
+ * a major version are purely additive.
+ *
+ * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined.
+ *
+ * Not to be confused with the version number of the NVTX core library.
+ */
+ #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR
+ #elif NVTX3_CPP_INLINED_VERSION_MAJOR != NVTX3_CPP_VERSION_MAJOR
+ /* Unsupported case -- cannot define unversioned symbols for different major versions
+ * in the same translation unit.
+ */
+ #error \
+ "Two different major versions of the NVTX C++ Wrappers are being included in a single .cpp file, with unversioned symbols enabled in both. Only one major version can enable unversioned symbols in a .cpp file. To disable unversioned symbols, #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before #including nvtx3.hpp, and use the explicit-version symbols instead -- this is the preferred way to use nvtx3.hpp from a header file."
+ #elif (NVTX3_CPP_INLINED_VERSION_MAJOR == NVTX3_CPP_VERSION_MAJOR) && \
+ (NVTX3_CPP_INLINED_VERSION_MINOR < NVTX3_CPP_VERSION_MINOR)
+ /* An older minor version of the same major version already defined unversioned
+ * symbols. The new features provided in this header will be inlined
+ * redefine the minor version macro to this header's version.
+ */
+ #undef NVTX3_CPP_INLINED_VERSION_MINOR
+ #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR
+ // else, already have this version or newer, nothing to do
+ #endif
+#endif
+/* clang-format on */
+
+#include <nvtx3/nvToolsExt.h>
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+/**
+ * @file nvtx3.hpp
+ *
+ * @brief Provides C++ constructs making the NVTX library safer and easier to
+ * use with zero overhead.
+ */
+
+/**
+ * \mainpage
+ * \tableofcontents
+ *
+ * \section QUICK_START Quick Start
+ *
+ * To add NVTX ranges to your code, use the `nvtx3::thread_range` RAII object. A
+ * range begins when the object is created, and ends when the object is
+ * destroyed.
+ *
+ * \code{.cpp}
+ * #include "nvtx3.hpp"
+ * void some_function(){
+ * // Begins a NVTX range with the messsage "some_function"
+ * // The range ends when some_function() returns and `r` is destroyed
+ * nvtx3::thread_range r{"some_function"};
+ *
+ * for(int i = 0; i < 6; ++i){
+ * nvtx3::thread_range loop{"loop range"};
+ * std::this_thread::sleep_for(std::chrono::seconds{1});
+ * }
+ * } // Range ends when `r` is destroyed
+ * \endcode
+ *
+ * The example code above generates the following timeline view in Nsight
+ * Systems:
+ *
+ * \image html
+ * https://raw.githubusercontent.com/jrhemstad/nvtx_wrappers/master/docs/example_range.png
+ *
+ * Alternatively, use the \ref MACROS like `NVTX3_FUNC_RANGE()` to add
+ * ranges to your code that automatically use the name of the enclosing function
+ * as the range's message.
+ *
+ * \code{.cpp}
+ * #include "nvtx3.hpp"
+ * void some_function(){
+ * // Creates a range with a message "some_function" that ends when the
+ * enclosing
+ * // function returns
+ * NVTX3_FUNC_RANGE();
+ * ...
+ * }
+ * \endcode
+ *
+ *
+ * \section Overview
+ *
+ * The NVTX library provides a set of functions for users to annotate their code
+ * to aid in performance profiling and optimization. These annotations provide
+ * information to tools like Nsight Systems to improve visualization of
+ * application timelines.
+ *
+ * \ref RANGES are one of the most commonly used NVTX constructs for annotating
+ * a span of time. For example, imagine a user wanted to see every time a
+ * function, `my_function`, is called and how long it takes to execute. This can
+ * be accomplished with an NVTX range created on the entry to the function and
+ * terminated on return from `my_function` using the push/pop C APIs:
+ *
+ * ```
+ * void my_function(...){
+ * nvtxRangePushA("my_function"); // Begins NVTX range
+ * // do work
+ * nvtxRangePop(); // Ends NVTX range
+ * }
+ * ```
+ *
+ * One of the challenges with using the NVTX C API is that it requires manually
+ * terminating the end of the range with `nvtxRangePop`. This can be challenging
+ * if `my_function()` has multiple returns or can throw exceptions as it
+ * requires calling `nvtxRangePop()` before all possible return points.
+ *
+ * NVTX++ solves this inconvenience through the "RAII" technique by providing a
+ * `nvtx3::thread_range` class that begins a range at construction and ends the
+ * range on destruction. The above example then becomes:
+ *
+ * ```
+ * void my_function(...){
+ * nvtx3::thread_range r{"my_function"}; // Begins NVTX range
+ * // do work
+ * } // Range ends on exit from `my_function` when `r` is destroyed
+ * ```
+ *
+ * The range object `r` is deterministically destroyed whenever `my_function`
+ * returns---ending the NVTX range without manual intervention. For more
+ * information, see \ref RANGES and `nvtx3::domain_thread_range`.
+ *
+ * Another inconvenience of the NVTX C APIs are the several constructs where the
+ * user is expected to initialize an object at the beginning of an application
+ * and reuse that object throughout the lifetime of the application. For example
+ * Domains, Categories, and Registered messages.
+ *
+ * Example:
+ * ```
+ * nvtxDomainHandle_t D = nvtxDomainCreateA("my domain");
+ * // Reuse `D` throughout the rest of the application
+ * ```
+ *
+ * This can be problematic if the user application or library does not have an
+ * explicit initialization function called before all other functions to
+ * ensure that these long-lived objects are initialized before being used.
+ *
+ * NVTX++ makes use of the "construct on first use" technique to alleviate this
+ * inconvenience. In short, a function local static object is constructed upon
+ * the first invocation of a function and returns a reference to that object on
+ * all future invocations. See the documentation for
+ * `nvtx3::registered_string`, `nvtx3::domain`, `nvtx3::named_category`, and
+ * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use for more
+ * information.
+ *
+ * Using construct on first use, the above example becomes:
+ * ```
+ * struct my_domain{ static constexpr char const* name{"my domain"}; };
+ *
+ * // The first invocation of `domain::get` for the type `my_domain` will
+ * // construct a `nvtx3::domain` object and return a reference to it. Future
+ * // invocations simply return a reference.
+ * nvtx3::domain const& D = nvtx3::domain::get<my_domain>();
+ * ```
+ * For more information about NVTX and how it can be used, see
+ * https://docs.nvidia.com/cuda/profiler-users-guide/index.html#nvtx and
+ * https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx/
+ * for more information.
+ *
+ * \section RANGES Ranges
+ *
+ * Ranges are used to describe a span of time during the execution of an
+ * application. Common examples are using ranges to annotate the time it takes
+ * to execute a function or an iteration of a loop.
+ *
+ * NVTX++ uses RAII to automate the generation of ranges that are tied to the
+ * lifetime of objects. Similar to `std::lock_guard` in the C++ Standard
+ * Template Library.
+ *
+ * \subsection THREAD_RANGE Thread Range
+ *
+ * `nvtx3::domain_thread_range` is a class that begins a range upon construction
+ * and ends the range at destruction. This is one of the most commonly used
+ * constructs in NVTX++ and is useful for annotating spans of time on a
+ * particular thread. These ranges can be nested to arbitrary depths.
+ *
+ * `nvtx3::thread_range` is an alias for a `nvtx3::domain_thread_range` in the
+ * global NVTX domain. For more information about Domains, see \ref DOMAINS.
+ *
+ * Various attributes of a range can be configured constructing a
+ * `nvtx3::domain_thread_range` with a `nvtx3::event_attributes` object. For
+ * more information, see \ref ATTRIBUTES.
+ *
+ * Example:
+ *
+ * \code{.cpp}
+ * void some_function(){
+ * // Creates a range for the duration of `some_function`
+ * nvtx3::thread_range r{};
+ *
+ * while(true){
+ * // Creates a range for every loop iteration
+ * // `loop_range` is nested inside `r`
+ * nvtx3::thread_range loop_range{};
+ * }
+ * }
+ * \endcode
+ *
+ * \subsection PROCESS_RANGE Process Range
+ *
+ * `nvtx3::domain_process_range` is identical to `nvtx3::domain_thread_range`
+ * with the exception that a `domain_process_range` can be created and destroyed
+ * on different threads. This is useful to annotate spans of time that can
+ * bridge multiple threads.
+ *
+ * `nvtx3::domain_thread_range`s should be preferred unless one needs the
+ * ability to begin and end a range on different threads.
+ *
+ * \section MARKS Marks
+ *
+ * `nvtx3::mark` allows annotating an instantaneous event in an application's
+ * timeline. For example, indicating when a mutex is locked or unlocked.
+ *
+ * \code{.cpp}
+ * std::mutex global_lock;
+ * void lock_mutex(){
+ * global_lock.lock();
+ * // Marks an event immediately after the mutex is locked
+ * nvtx3::mark<my_domain>("lock_mutex");
+ * }
+ * \endcode
+ *
+ * \section DOMAINS Domains
+ *
+ * Similar to C++ namespaces, Domains allow for scoping NVTX events. By default,
+ * all NVTX events belong to the "global" domain. Libraries and applications
+ * should scope their events to use a custom domain to differentiate where the
+ * events originate from.
+ *
+ * It is common for a library or application to have only a single domain and
+ * for the name of that domain to be known at compile time. Therefore, Domains
+ * in NVTX++ are represented by _tag types_.
+ *
+ * For example, to define a custom domain, simply define a new concrete type
+ * (a `class` or `struct`) with a `static` member called `name` that contains
+ * the desired name of the domain.
+ *
+ * ```
+ * struct my_domain{ static constexpr char const* name{"my domain"}; };
+ * ```
+ *
+ * For any NVTX++ construct that can be scoped to a domain, the type `my_domain`
+ * can be passed as an explicit template argument to scope it to the custom
+ * domain.
+ *
+ * The tag type `nvtx3::domain::global` represents the global NVTX domain.
+ *
+ * \code{.cpp}
+ * // By default, `domain_thread_range` belongs to the global domain
+ * nvtx3::domain_thread_range<> r0{};
+ *
+ * // Alias for a `domain_thread_range` in the global domain
+ * nvtx3::thread_range r1{};
+ *
+ * // `r` belongs to the custom domain
+ * nvtx3::domain_thread_range<my_domain> r{};
+ * \endcode
+ *
+ * When using a custom domain, it is reccomended to define type aliases for NVTX
+ * constructs in the custom domain.
+ * ```
+ * using my_thread_range = nvtx3::domain_thread_range<my_domain>;
+ * using my_registered_string = nvtx3::registered_string<my_domain>;
+ * using my_named_category = nvtx3::named_category<my_domain>;
+ * ```
+ *
+ * See `nvtx3::domain` for more information.
+ *
+ * \section ATTRIBUTES Event Attributes
+ *
+ * NVTX events can be customized with various attributes to provide additional
+ * information (such as a custom message) or to control visualization of the
+ * event (such as the color used). These attributes can be specified per-event
+ * via arguments to a `nvtx3::event_attributes` object.
+ *
+ * NVTX events can be customized via four "attributes":
+ * - \ref COLOR : color used to visualize the event in tools.
+ * - \ref MESSAGES : Custom message string.
+ * - \ref PAYLOAD : User-defined numerical value.
+ * - \ref CATEGORY : Intra-domain grouping.
+ *
+ * It is possible to construct a `nvtx3::event_attributes` from any number of
+ * attribute objects (nvtx3::color, nvtx3::message, nvtx3::payload,
+ * nvtx3::category) in any order. If an attribute is not specified, a tool
+ * specific default value is used. See `nvtx3::event_attributes` for more
+ * information.
+ *
+ * \code{.cpp}
+ * // Custom color, message
+ * event_attributes attr{nvtx3::rgb{127, 255, 0},
+ * "message"};
+ *
+ * // Custom color, message, payload, category
+ * event_attributes attr{nvtx3::rgb{127, 255, 0},
+ * nvtx3::payload{42},
+ * "message",
+ * nvtx3::category{1}};
+ *
+ * // Arguments can be in any order
+ * event_attributes attr{nvtx3::payload{42},
+ * nvtx3::category{1},
+ * "message",
+ * nvtx3::rgb{127, 255, 0}};
+ *
+ * // "First wins" with multiple arguments of the same type
+ * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload is
+ * 42 \endcode
+ *
+ * \subsection MESSAGES message
+ *
+ * A `nvtx3::message` allows associating a custom message string with an NVTX
+ * event.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Create an `event_attributes` with the custom message "my message"
+ * nvtx3::event_attributes attr{nvtx3::Mesage{"my message"}};
+ *
+ * // strings and string literals implicitly assumed to be a `nvtx3::message`
+ * nvtx3::event_attributes attr{"my message"};
+ * \endcode
+ *
+ * \subsubsection REGISTERED_MESSAGE Registered Messages
+ *
+ * Associating a `nvtx3::message` with an event requires copying the contents of
+ * the message every time the message is used, i.e., copying the entire message
+ * string. This may cause non-trivial overhead in performance sensitive code.
+ *
+ * To eliminate this overhead, NVTX allows registering a message string,
+ * yielding a "handle" that is inexpensive to copy that may be used in place of
+ * a message string. When visualizing the events, tools such as Nsight Systems
+ * will take care of mapping the message handle to its string.
+ *
+ * A message should be registered once and the handle reused throughout the rest
+ * of the application. This can be done by either explicitly creating static
+ * `nvtx3::registered_string` objects, or using the
+ * `nvtx3::registered_string::get` construct on first use helper (recommended).
+ *
+ * Similar to \ref DOMAINS, `nvtx3::registered_string::get` requires defining a
+ * custom tag type with a static `message` member whose value will be the
+ * contents of the registered string.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Explicitly constructed, static `registered_string`
+ * static registered_string<my_domain> static_message{"my message"};
+ *
+ * // Or use construct on first use:
+ * // Define a tag type with a `message` member string to register
+ * struct my_message{ static constexpr char const* message{ "my message" }; };
+ *
+ * // Uses construct on first use to register the contents of
+ * // `my_message::message`
+ * nvtx3::registered_string<my_domain> const& msg =
+ * nvtx3::registered_string<my_domain>::get<my_message>(); \endcode
+ *
+ * \subsection COLOR color
+ *
+ * Associating a `nvtx3::color` with an event allows controlling how the event
+ * is visualized in a tool such as Nsight Systems. This is a convenient way to
+ * visually differentiate among different events.
+ *
+ * \code{.cpp}
+ * // Define a color via rgb color values
+ * nvtx3::color c{nvtx3::rgb{127, 255, 0}};
+ * nvtx3::event_attributes attr{c};
+ *
+ * // rgb color values can be passed directly to an `event_attributes`
+ * nvtx3::event_attributes attr1{nvtx3::rgb{127,255,0}};
+ * \endcode
+ *
+ * \subsection CATEGORY category
+ *
+ * A `nvtx3::category` is simply an integer id that allows for fine-grain
+ * grouping of NVTX events. For example, one might use separate categories for
+ * IO, memory allocation, compute, etc.
+ *
+ * \code{.cpp}
+ * nvtx3::event_attributes{nvtx3::category{1}};
+ * \endcode
+ *
+ * \subsubsection NAMED_CATEGORIES Named Categories
+ *
+ * Associates a `name` string with a category `id` to help differentiate among
+ * categories.
+ *
+ * For any given category id `Id`, a `named_category{Id, "name"}` should only
+ * be constructed once and reused throughout an application. This can be done by
+ * either explicitly creating static `nvtx3::named_category` objects, or using
+ * the `nvtx3::named_category::get` construct on first use helper (recommended).
+ *
+ * Similar to \ref DOMAINS, `nvtx3::named_category::get` requires defining a
+ * custom tag type with static `name` and `id` members.
+ *
+ * \code{.cpp}
+ * // Explicitly constructed, static `named_category`
+ * static nvtx3::named_category static_category{42, "my category"};
+ *
+ * // OR use construct on first use:
+ * // Define a tag type with `name` and `id` members
+ * struct my_category{
+ * static constexpr char const* name{"my category"}; // category name
+ * static constexpr category::id_type id{42}; // category id
+ * };
+ *
+ * // Use construct on first use to name the category id `42`
+ * // with name "my category"
+ * nvtx3::named_category const& my_category =
+ * named_category<my_domain>::get<my_category>();
+ *
+ * // Range `r` associated with category id `42`
+ * nvtx3::event_attributes attr{my_category};
+ * \endcode
+ *
+ * \subsection PAYLOAD payload
+ *
+ * Allows associating a user-defined numerical value with an event.
+ *
+ * ```
+ * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload
+ * from
+ * // the `int32_t` value 42
+ * ```
+ *
+ *
+ * \section EXAMPLE Example
+ *
+ * Putting it all together:
+ * \code{.cpp}
+ * // Define a custom domain tag type
+ * struct my_domain{ static constexpr char const* name{"my domain"}; };
+ *
+ * // Define a named category tag type
+ * struct my_category{
+ * static constexpr char const* name{"my category"};
+ * static constexpr uint32_t id{42};
+ * };
+ *
+ * // Define a registered string tag type
+ * struct my_message{ static constexpr char const* message{"my message"}; };
+ *
+ * // For convenience, use aliases for domain scoped objects
+ * using my_thread_range = nvtx3::domain_thread_range<my_domain>;
+ * using my_registered_string = nvtx3::registered_string<my_domain>;
+ * using my_named_category = nvtx3::named_category<my_domain>;
+ *
+ * // Default values for all attributes
+ * nvtx3::event_attributes attr{};
+ * my_thread_range r0{attr};
+ *
+ * // Custom (unregistered) message, and unnamed category
+ * nvtx3::event_attributes attr1{"message", nvtx3::category{2}};
+ * my_thread_range r1{attr1};
+ *
+ * // Alternatively, pass arguments of `event_attributes` ctor directly to
+ * // `my_thread_range`
+ * my_thread_range r2{"message", nvtx3::category{2}};
+ *
+ * // construct on first use a registered string
+ * auto msg = my_registered_string::get<my_message>();
+ *
+ * // construct on first use a named category
+ * auto category = my_named_category::get<my_category>();
+ *
+ * // Use registered string and named category
+ * my_thread_range r3{msg, category, nvtx3::rgb{127, 255, 0},
+ * nvtx3::payload{42}};
+ *
+ * // Any number of arguments in any order
+ * my_thread_range r{nvtx3::rgb{127, 255,0}, msg};
+ *
+ * \endcode
+ * \section MACROS Convenience Macros
+ *
+ * Oftentimes users want to quickly and easily add NVTX ranges to their library
+ * or application to aid in profiling and optimization.
+ *
+ * A convenient way to do this is to use the \ref NVTX3_FUNC_RANGE and
+ * \ref NVTX3_FUNC_RANGE_IN macros. These macros take care of constructing an
+ * `nvtx3::domain_thread_range` with the name of the enclosing function as the
+ * range's message.
+ *
+ * \code{.cpp}
+ * void some_function(){
+ * // Automatically generates an NVTX range for the duration of the function
+ * // using "some_function" as the event's message.
+ * NVTX3_FUNC_RANGE();
+ * }
+ * \endcode
+ *
+ */
+
+/* Temporary helper #defines, removed with #undef at end of header */
+
+/* Within this header, nvtx3::NVTX3_VERSION_NAMESPACE resolves to nvtx3::vX,
+ * where "X" is the major version number. */
+#define NVTX3_CONCAT(A, B) A##B
+#define NVTX3_NAMESPACE_FOR(VERSION) NVTX3_CONCAT(v, VERSION)
+#define NVTX3_VERSION_NAMESPACE NVTX3_NAMESPACE_FOR(NVTX3_CPP_VERSION_MAJOR)
+
+/* Avoid duplicating #if defined(NVTX3_INLINE_THIS_VERSION) for namespaces
+ * in each minor version by making a macro to use unconditionally, which
+ * resolves to "inline" or nothing as appropriate. */
+#if defined(NVTX3_INLINE_THIS_VERSION)
+#define NVTX3_INLINE_IF_REQUESTED inline
+#else
+#define NVTX3_INLINE_IF_REQUESTED
+#endif
+
+/* Enables the use of constexpr when support for C++14 relaxed constexpr
+ * is present.
+ *
+ * Initializing a legacy-C (i.e., no constructor) union member requires
+ * initializing in the constructor body. Non-empty constexpr constructors
+ * require C++14 relaxed constexpr. In strict C++11 compilation, fall back
+ * to using non-constexpr constructors for classes with union members.
+ */
+#if __cpp_constexpr >= 201304L
+#define NVTX3_RELAXED_CONSTEXPR constexpr
+#else
+#define NVTX3_RELAXED_CONSTEXPR
+#endif
+
+/* Implementation sections, enclosed in guard macros for each minor version */
+
+#ifndef NVTX3_CPP_DEFINITIONS_V1_0
+#define NVTX3_CPP_DEFINITIONS_V1_0
+
+namespace nvtx3 {
+
+NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE
+{
+
+namespace detail {
+
+/**
+ * @brief Verifies if a type `T` contains a member `T::name` of type `const
+ * char*` or `const wchar_t*`.
+ *
+ * @tparam T The type to verify
+ * @return True if `T` contains a member `T::name` of type `const char*` or
+ * `const wchar_t*`.
+ */
+template <typename T>
+constexpr auto has_name_member() noexcept -> decltype(T::name, bool())
+{
+ return (std::is_same<char const*, typename std::decay<decltype(T::name)>::type>::value ||
+ std::is_same<wchar_t const*, typename std::decay<decltype(T::name)>::type>::value);
+}
+} // namespace detail
+
+/**
+ * @brief `domain`s allow for grouping NVTX events into a single scope to
+ * differentiate them from events in other `domain`s.
+ *
+ * By default, all NVTX constructs are placed in the "global" NVTX domain.
+ *
+ * A custom `domain` may be used in order to differentiate a library's or
+ * application's NVTX events from other events.
+ *
+ * `domain`s are expected to be long-lived and unique to a library or
+ * application. As such, it is assumed a domain's name is known at compile
+ * time. Therefore, all NVTX constructs that can be associated with a domain
+ * require the domain to be specified via a *type* `DomainName` passed as an
+ * explicit template parameter.
+ *
+ * The type `domain::global` may be used to indicate that the global NVTX
+ * domain should be used.
+ *
+ * None of the C++ NVTX constructs require the user to manually construct a
+ * `domain` object. Instead, if a custom domain is desired, the user is
+ * expected to define a type `DomainName` that contains a member
+ * `DomainName::name` which resolves to either a `char const*` or `wchar_t
+ * const*`. The value of `DomainName::name` is used to name and uniquely
+ * identify the custom domain.
+ *
+ * Upon the first use of an NVTX construct associated with the type
+ * `DomainName`, the "construct on first use" pattern is used to construct a
+ * function local static `domain` object. All future NVTX constructs
+ * associated with `DomainType` will use a reference to the previously
+ * constructed `domain` object. See `domain::get`.
+ *
+ * Example:
+ * ```
+ * // The type `my_domain` defines a `name` member used to name and identify
+ * the
+ * // `domain` object identified by `my_domain`.
+ * struct my_domain{ static constexpr char const* name{"my_domain"}; };
+ *
+ * // The NVTX range `r` will be grouped with all other NVTX constructs
+ * // associated with `my_domain`.
+ * nvtx3::domain_thread_range<my_domain> r{};
+ *
+ * // An alias can be created for a `domain_thread_range` in the custom domain
+ * using my_thread_range = nvtx3::domain_thread_range<my_domain>;
+ * my_thread_range my_range{};
+ *
+ * // `domain::global` indicates that the global NVTX domain is used
+ * nvtx3::domain_thread_range<domain::global> r2{};
+ *
+ * // For convenience, `nvtx3::thread_range` is an alias for a range in the
+ * // global domain
+ * nvtx3::thread_range r3{};
+ * ```
+ */
+class domain {
+ public:
+ domain(domain const&) = delete;
+ domain& operator=(domain const&) = delete;
+ domain(domain&&) = delete;
+ domain& operator=(domain&&) = delete;
+
+ /**
+ * @brief Returns reference to an instance of a function local static
+ * `domain` object.
+ *
+ * Uses the "construct on first use" idiom to safely ensure the `domain`
+ * object is initialized exactly once upon first invocation of
+ * `domain::get<DomainName>()`. All following invocations will return a
+ * reference to the previously constructed `domain` object. See
+ * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use
+ *
+ * None of the constructs in this header require the user to directly invoke
+ * `domain::get`. It is automatically invoked when constructing objects like
+ * a `domain_thread_range` or `category`. Advanced users may wish to use
+ * `domain::get` for the convenience of the "construct on first use" idiom
+ * when using domains with their own use of the NVTX C API.
+ *
+ * This function is threadsafe as of C++11. If two or more threads call
+ * `domain::get<DomainName>` concurrently, exactly one of them is guaranteed
+ * to construct the `domain` object and the other(s) will receive a
+ * reference to the object after it is fully constructed.
+ *
+ * The domain's name is specified via the type `DomainName` pass as an
+ * explicit template parameter. `DomainName` is required to contain a
+ * member `DomainName::name` that resolves to either a `char const*` or
+ * `wchar_t const*`. The value of `DomainName::name` is used to name and
+ * uniquely identify the `domain`.
+ *
+ * Example:
+ * ```
+ * // The type `my_domain` defines a `name` member used to name and identify
+ * // the `domain` object identified by `my_domain`.
+ * struct my_domain{ static constexpr char const* name{"my domain"}; };
+ *
+ * auto D = domain::get<my_domain>(); // First invocation constructs a
+ * // `domain` with the name "my domain"
+ *
+ * auto D1 = domain::get<my_domain>(); // Simply returns reference to
+ * // previously constructed `domain`.
+ * ```
+ *
+ * @tparam DomainName Type that contains a `DomainName::name` member used to
+ * name the `domain` object.
+ * @return Reference to the `domain` corresponding to the type `DomainName`.
+ */
+ template <typename DomainName>
+ static domain const& get()
+ {
+ static_assert(detail::has_name_member<DomainName>(),
+ "Type used to identify a domain must contain a name member of"
+ "type const char* or const wchar_t*");
+ static domain const d{DomainName::name};
+ return d;
+ }
+
+ /**
+ * @brief Conversion operator to `nvtxDomainHandle_t`.
+ *
+ * Allows transparently passing a domain object into an API expecting a
+ * native `nvtxDomainHandle_t` object.
+ */
+ operator nvtxDomainHandle_t() const noexcept { return _domain; }
+
+ /**
+ * @brief Tag type for the "global" NVTX domain.
+ *
+ * This type may be passed as a template argument to any function/class
+ * expecting a type to identify a domain to indicate that the global domain
+ * should be used.
+ *
+ * All NVTX events in the global domain across all libraries and
+ * applications will be grouped together.
+ *
+ */
+ struct global {
+ };
+
+ private:
+ /**
+ * @brief Construct a new domain with the specified `name`.
+ *
+ * This constructor is private as it is intended that `domain` objects only
+ * be created through the `domain::get` function.
+ *
+ * @param name A unique name identifying the domain
+ */
+ explicit domain(char const* name) noexcept : _domain{nvtxDomainCreateA(name)} {}
+
+ /**
+ * @brief Construct a new domain with the specified `name`.
+ *
+ * This constructor is private as it is intended that `domain` objects only
+ * be created through the `domain::get` function.
+ *
+ * @param name A unique name identifying the domain
+ */
+ explicit domain(wchar_t const* name) noexcept : _domain{nvtxDomainCreateW(name)} {}
+
+ /**
+ * @brief Construct a new domain with the specified `name`.
+ *
+ * This constructor is private as it is intended that `domain` objects only
+ * be created through the `domain::get` function.
+ *
+ * @param name A unique name identifying the domain
+ */
+ explicit domain(std::string const& name) noexcept : domain{name.c_str()} {}
+
+ /**
+ * @brief Construct a new domain with the specified `name`.
+ *
+ * This constructor is private as it is intended that `domain` objects only
+ * be created through the `domain::get` function.
+ *
+ * @param name A unique name identifying the domain
+ */
+ explicit domain(std::wstring const& name) noexcept : domain{name.c_str()} {}
+
+ /**
+ * @brief Default constructor creates a `domain` representing the
+ * "global" NVTX domain.
+ *
+ * All events not associated with a custom `domain` are grouped in the
+ * "global" NVTX domain.
+ *
+ */
+ domain() = default;
+
+ /**
+ * @brief Intentionally avoid calling nvtxDomainDestroy on the `domain` object.
+ *
+ * No currently-available tools attempt to free domain resources when the
+ * nvtxDomainDestroy function is called, due to the thread-safety and
+ * efficiency challenges of freeing thread-local storage for other threads.
+ * Since libraries may be disallowed from introducing static destructors,
+ * and destroying the domain is likely to have no effect, the destructor
+ * for `domain` intentionally chooses to not destroy the domain.
+ *
+ * In a situation where domain destruction is necessary, either manually
+ * call nvtxDomainDestroy on the domain's handle, or make a class that
+ * derives from `domain` and calls nvtxDomainDestroy in its destructor.
+ */
+ ~domain() = default;
+
+ private:
+ nvtxDomainHandle_t const _domain{}; ///< The `domain`s NVTX handle
+};
+
+/**
+ * @brief Returns reference to the `domain` object that represents the global
+ * NVTX domain.
+ *
+ * This specialization for `domain::global` returns a default constructed,
+ * `domain` object for use when the "global" domain is desired.
+ *
+ * All NVTX events in the global domain across all libraries and applications
+ * will be grouped together.
+ *
+ * @return Reference to the `domain` corresponding to the global NVTX domain.
+ *
+ */
+template <>
+inline domain const& domain::get<domain::global>()
+{
+ static domain const d{};
+ return d;
+}
+
+/**
+ * @brief Indicates the values of the red, green, blue color channels for
+ * a rgb color code.
+ *
+ */
+struct rgb {
+ /// Type used for component values
+ using component_type = uint8_t;
+
+ /**
+ * @brief Construct a rgb with red, green, and blue channels
+ * specified by `red_`, `green_`, and `blue_`, respectively.
+ *
+ * Valid values are in the range `[0,255]`.
+ *
+ * @param red_ Value of the red channel
+ * @param green_ Value of the green channel
+ * @param blue_ Value of the blue channel
+ */
+ constexpr rgb(component_type red_, component_type green_, component_type blue_) noexcept
+ : red{red_}, green{green_}, blue{blue_}
+ {
+ }
+
+ component_type const red{}; ///< Red channel value
+ component_type const green{}; ///< Green channel value
+ component_type const blue{}; ///< Blue channel value
+};
+
+/**
+ * @brief Indicates the value of the alpha, red, green, and blue color
+ * channels for an argb color code.
+ *
+ */
+struct argb final : rgb {
+ /**
+ * @brief Construct an argb with alpha, red, green, and blue channels
+ * specified by `alpha_`, `red_`, `green_`, and `blue_`, respectively.
+ *
+ * Valid values are in the range `[0,255]`.
+ *
+ * @param alpha_ Value of the alpha channel (opacity)
+ * @param red_ Value of the red channel
+ * @param green_ Value of the green channel
+ * @param blue_ Value of the blue channel
+ *
+ */
+ constexpr argb(component_type alpha_,
+ component_type red_,
+ component_type green_,
+ component_type blue_) noexcept
+ : rgb{red_, green_, blue_}, alpha{alpha_}
+ {
+ }
+
+ component_type const alpha{}; ///< Alpha channel value
+};
+
+/**
+ * @brief Represents a custom color that can be associated with an NVTX event
+ * via it's `event_attributes`.
+ *
+ * Specifying colors for NVTX events is a convenient way to visually
+ * differentiate among different events in a visualization tool such as Nsight
+ * Systems.
+ *
+ */
+class color {
+ public:
+ /// Type used for the color's value
+ using value_type = uint32_t;
+
+ /**
+ * @brief Constructs a `color` using the value provided by `hex_code`.
+ *
+ * `hex_code` is expected to be a 4 byte argb hex code.
+ *
+ * The most significant byte indicates the value of the alpha channel
+ * (opacity) (0-255)
+ *
+ * The next byte indicates the value of the red channel (0-255)
+ *
+ * The next byte indicates the value of the green channel (0-255)
+ *
+ * The least significant byte indicates the value of the blue channel
+ * (0-255)
+ *
+ * @param hex_code The hex code used to construct the `color`
+ */
+ constexpr explicit color(value_type hex_code) noexcept : _value{hex_code} {}
+
+ /**
+ * @brief Construct a `color` using the alpha, red, green, blue components
+ * in `argb`.
+ *
+ * @param argb The alpha, red, green, blue components of the desired `color`
+ */
+ constexpr color(argb argb) noexcept
+ : color{from_bytes_msb_to_lsb(argb.alpha, argb.red, argb.green, argb.blue)}
+ {
+ }
+
+ /**
+ * @brief Construct a `color` using the red, green, blue components in
+ * `rgb`.
+ *
+ * Uses maximum value for the alpha channel (opacity) of the `color`.
+ *
+ * @param rgb The red, green, blue components of the desired `color`
+ */
+ constexpr color(rgb rgb) noexcept
+ : color{from_bytes_msb_to_lsb(0xFF, rgb.red, rgb.green, rgb.blue)}
+ {
+ }
+
+ /**
+ * @brief Returns the `color`s argb hex code
+ *
+ */
+ constexpr value_type get_value() const noexcept { return _value; }
+
+ /**
+ * @brief Return the NVTX color type of the color.
+ *
+ */
+ constexpr nvtxColorType_t get_type() const noexcept { return _type; }
+
+ color() = delete;
+ ~color() = default;
+ color(color const&) = default;
+ color& operator=(color const&) = default;
+ color(color&&) = default;
+ color& operator=(color&&) = default;
+
+ private:
+ /**
+ * @brief Constructs an unsigned, 4B integer from the component bytes in
+ * most to least significant byte order.
+ *
+ */
+ constexpr static value_type from_bytes_msb_to_lsb(uint8_t byte3,
+ uint8_t byte2,
+ uint8_t byte1,
+ uint8_t byte0) noexcept
+ {
+ return uint32_t{byte3} << 24 | uint32_t{byte2} << 16 | uint32_t{byte1} << 8 | uint32_t{byte0};
+ }
+
+ value_type const _value{}; ///< color's argb color code
+ nvtxColorType_t const _type{NVTX_COLOR_ARGB}; ///< NVTX color type code
+};
+
+/**
+ * @brief Object for intra-domain grouping of NVTX events.
+ *
+ * A `category` is simply an integer id that allows for fine-grain grouping of
+ * NVTX events. For example, one might use separate categories for IO, memory
+ * allocation, compute, etc.
+ *
+ * Example:
+ * \code{.cpp}
+ * nvtx3::category cat1{1};
+ *
+ * // Range `r1` belongs to the category identified by the value `1`.
+ * nvtx3::thread_range r1{cat1};
+ *
+ * // Range `r2` belongs to the same category as `r1`
+ * nvtx3::thread_range r2{nvtx3::category{1}};
+ * \endcode
+ *
+ * To associate a name string with a category id, see `named_category`.
+ *
+ */
+class category {
+ public:
+ /// Type used for `category`s integer id.
+ using id_type = uint32_t;
+
+ /**
+ * @brief Construct a `category` with the specified `id`.
+ *
+ * The `category` will be unnamed and identified only by its `id` value.
+ *
+ * All `category` objects sharing the same `id` are equivalent.
+ *
+ * @param[in] id The `category`'s identifying value
+ */
+ constexpr explicit category(id_type id) noexcept : id_{id} {}
+
+ /**
+ * @brief Returns the id of the category.
+ *
+ */
+ constexpr id_type get_id() const noexcept { return id_; }
+
+ category() = delete;
+ ~category() = default;
+ category(category const&) = default;
+ category& operator=(category const&) = default;
+ category(category&&) = default;
+ category& operator=(category&&) = default;
+
+ private:
+ id_type const id_{}; ///< category's unique identifier
+};
+
+/**
+ * @brief A `category` with an associated name string.
+ *
+ * Associates a `name` string with a category `id` to help differentiate among
+ * categories.
+ *
+ * For any given category id `Id`, a `named_category(Id, "name")` should only
+ * be constructed once and reused throughout an application. This can be done
+ * by either explicitly creating static `named_category` objects, or using the
+ * `named_category::get` construct on first use helper (recommended).
+ *
+ * Creating two or more `named_category` objects with the same value for `id`
+ * in the same domain results in undefined behavior.
+ *
+ * Similarly, behavior is undefined when a `named_category` and `category`
+ * share the same value of `id`.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Explicitly constructed, static `named_category`
+ * static nvtx3::named_category static_category{42, "my category"};
+ *
+ * // Range `r` associated with category id `42`
+ * nvtx3::thread_range r{static_category};
+ *
+ * // OR use construct on first use:
+ *
+ * // Define a type with `name` and `id` members
+ * struct my_category{
+ * static constexpr char const* name{"my category"}; // category name
+ * static constexpr category::id_type id{42}; // category id
+ * };
+ *
+ * // Use construct on first use to name the category id `42`
+ * // with name "my category"
+ * auto my_category = named_category<my_domain>::get<my_category>();
+ *
+ * // Range `r` associated with category id `42`
+ * nvtx3::thread_range r{my_category};
+ * \endcode
+ *
+ * `named_category`'s association of a name to a category id is local to the
+ * domain specified by the type `D`. An id may have a different name in
+ * another domain.
+ *
+ * @tparam D Type containing `name` member used to identify the `domain` to
+ * which the `named_category` belongs. Else, `domain::global` to indicate
+ * that the global NVTX domain should be used.
+ */
+template <typename D = domain::global>
+class named_category final : public category {
+ public:
+ /**
+ * @brief Returns a global instance of a `named_category` as a
+ * function-local static.
+ *
+ * Creates a `named_category` with name and id specified by the contents of
+ * a type `C`. `C::name` determines the name and `C::id` determines the
+ * category id.
+ *
+ * This function is useful for constructing a named `category` exactly once
+ * and reusing the same instance throughout an application.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Define a type with `name` and `id` members
+ * struct my_category{
+ * static constexpr char const* name{"my category"}; // category name
+ * static constexpr uint32_t id{42}; // category id
+ * };
+ *
+ * // Use construct on first use to name the category id `42`
+ * // with name "my category"
+ * auto cat = named_category<my_domain>::get<my_category>();
+ *
+ * // Range `r` associated with category id `42`
+ * nvtx3::thread_range r{cat};
+ * \endcode
+ *
+ * Uses the "construct on first use" idiom to safely ensure the `category`
+ * object is initialized exactly once. See
+ * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use
+ *
+ * @tparam C Type containing a member `C::name` that resolves to either a
+ * `char const*` or `wchar_t const*` and `C::id`.
+ */
+ template <typename C>
+ static named_category<D> const& get() noexcept
+ {
+ static_assert(detail::has_name_member<C>(),
+ "Type used to name a category must contain a name member.");
+ static named_category<D> const category{C::id, C::name};
+ return category;
+ }
+ /**
+ * @brief Construct a `category` with the specified `id` and `name`.
+ *
+ * The name `name` will be registered with `id`.
+ *
+ * Every unique value of `id` should only be named once.
+ *
+ * @param[in] id The category id to name
+ * @param[in] name The name to associated with `id`
+ */
+ named_category(id_type id, char const* name) noexcept : category{id}
+ {
+#ifndef NVTX_DISABLE
+ nvtxDomainNameCategoryA(domain::get<D>(), get_id(), name);
+#else
+ (void)id;
+ (void)name;
+#endif
+ };
+
+ /**
+ * @brief Construct a `category` with the specified `id` and `name`.
+ *
+ * The name `name` will be registered with `id`.
+ *
+ * Every unique value of `id` should only be named once.
+ *
+ * @param[in] id The category id to name
+ * @param[in] name The name to associated with `id`
+ */
+ named_category(id_type id, wchar_t const* name) noexcept : category{id}
+ {
+#ifndef NVTX_DISABLE
+ nvtxDomainNameCategoryW(domain::get<D>(), get_id(), name);
+#else
+ (void)id;
+ (void)name;
+#endif
+ };
+};
+
+/**
+ * @brief A message registered with NVTX.
+ *
+ * Normally, associating a `message` with an NVTX event requires copying the
+ * contents of the message string. This may cause non-trivial overhead in
+ * highly performance sensitive regions of code.
+ *
+ * message registration is an optimization to lower the overhead of
+ * associating a message with an NVTX event. Registering a message yields a
+ * handle that is inexpensive to copy that may be used in place of a message
+ * string.
+ *
+ * A particular message should only be registered once and the handle
+ * reused throughout the rest of the application. This can be done by either
+ * explicitly creating static `registered_string` objects, or using the
+ * `registered_string::get` construct on first use helper (recommended).
+ *
+ * Example:
+ * \code{.cpp}
+ * // Explicitly constructed, static `registered_string`
+ * static registered_string<my_domain> static_message{"message"};
+ *
+ * // "message" is associated with the range `r`
+ * nvtx3::thread_range r{static_message};
+ *
+ * // Or use construct on first use:
+ *
+ * // Define a type with a `message` member that defines the contents of the
+ * // registered string
+ * struct my_message{ static constexpr char const* message{ "my message" }; };
+ *
+ * // Uses construct on first use to register the contents of
+ * // `my_message::message`
+ * auto msg = registered_string<my_domain>::get<my_message>();
+ *
+ * // "my message" is associated with the range `r`
+ * nvtx3::thread_range r{msg};
+ * \endcode
+ *
+ * `registered_string`s are local to a particular domain specified via
+ * the type `D`.
+ *
+ * @tparam D Type containing `name` member used to identify the `domain` to
+ * which the `registered_string` belongs. Else, `domain::global` to indicate
+ * that the global NVTX domain should be used.
+ */
+template <typename D = domain::global>
+class registered_string {
+ public:
+ /**
+ * @brief Returns a global instance of a `registered_string` as a function
+ * local static.
+ *
+ * Provides a convenient way to register a message with NVTX without having
+ * to explicitly register the message.
+ *
+ * Upon first invocation, constructs a `registered_string` whose contents
+ * are specified by `message::message`.
+ *
+ * All future invocations will return a reference to the object constructed
+ * in the first invocation.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Define a type with a `message` member that defines the contents of the
+ * // registered string
+ * struct my_message{ static constexpr char const* message{ "my message" };
+ * };
+ *
+ * // Uses construct on first use to register the contents of
+ * // `my_message::message`
+ * auto msg = registered_string<my_domain>::get<my_message>();
+ *
+ * // "my message" is associated with the range `r`
+ * nvtx3::thread_range r{msg};
+ * \endcode
+ *
+ * @tparam M Type required to contain a member `M::message` that
+ * resolves to either a `char const*` or `wchar_t const*` used as the
+ * registered string's contents.
+ * @return Reference to a `registered_string` associated with the type `M`.
+ */
+ template <typename M>
+ static registered_string<D> const& get() noexcept
+ {
+ static registered_string<D> const registered_string{M::message};
+ return registered_string;
+ }
+
+ /**
+ * @brief Constructs a `registered_string` from the specified `msg` string.
+ *
+ * Registers `msg` with NVTX and associates a handle with the registered
+ * message.
+ *
+ * A particular message should should only be registered once and the handle
+ * reused throughout the rest of the application.
+ *
+ * @param msg The contents of the message
+ */
+ explicit registered_string(char const* msg) noexcept
+ : handle_{nvtxDomainRegisterStringA(domain::get<D>(), msg)}
+ {
+ }
+
+ /**
+ * @brief Constructs a `registered_string` from the specified `msg` string.
+ *
+ * Registers `msg` with NVTX and associates a handle with the registered
+ * message.
+ *
+ * A particular message should should only be registered once and the handle
+ * reused throughout the rest of the application.
+ *
+ * @param msg The contents of the message
+ */
+ explicit registered_string(std::string const& msg) noexcept : registered_string{msg.c_str()} {}
+
+ /**
+ * @brief Constructs a `registered_string` from the specified `msg` string.
+ *
+ * Registers `msg` with NVTX and associates a handle with the registered
+ * message.
+ *
+ * A particular message should should only be registered once and the handle
+ * reused throughout the rest of the application.
+ *
+ * @param msg The contents of the message
+ */
+ explicit registered_string(wchar_t const* msg) noexcept
+ : handle_{nvtxDomainRegisterStringW(domain::get<D>(), msg)}
+ {
+ }
+
+ /**
+ * @brief Constructs a `registered_string` from the specified `msg` string.
+ *
+ * Registers `msg` with NVTX and associates a handle with the registered
+ * message.
+ *
+ * A particular message should only be registered once and the handle
+ * reused throughout the rest of the application.
+ *
+ * @param msg The contents of the message
+ */
+ explicit registered_string(std::wstring const& msg) noexcept : registered_string{msg.c_str()} {}
+
+ /**
+ * @brief Returns the registered string's handle
+ *
+ */
+ nvtxStringHandle_t get_handle() const noexcept { return handle_; }
+
+ registered_string() = delete;
+ ~registered_string() = default;
+ registered_string(registered_string const&) = default;
+ registered_string& operator=(registered_string const&) = default;
+ registered_string(registered_string&&) = default;
+ registered_string& operator=(registered_string&&) = default;
+
+ private:
+ nvtxStringHandle_t const handle_{}; ///< The handle returned from
+ ///< registering the message with NVTX
+};
+
+/**
+ * @brief Allows associating a message string with an NVTX event via
+ * its `EventAttribute`s.
+ *
+ * Associating a `message` with an NVTX event through its `event_attributes`
+ * allows for naming events to easily differentiate them from other events.
+ *
+ * Every time an NVTX event is created with an associated `message`, the
+ * contents of the message string must be copied. This may cause non-trivial
+ * overhead in highly performance sensitive sections of code. Use of a
+ * `nvtx3::registered_string` is recommended in these situations.
+ *
+ * Example:
+ * \code{.cpp}
+ * // Creates an `event_attributes` with message "message 0"
+ * nvtx3::event_attributes attr0{nvtx3::message{"message 0"}};
+ *
+ * // `range0` contains message "message 0"
+ * nvtx3::thread_range range0{attr0};
+ *
+ * // `std::string` and string literals are implicitly assumed to be
+ * // the contents of an `nvtx3::message`
+ * // Creates an `event_attributes` with message "message 1"
+ * nvtx3::event_attributes attr1{"message 1"};
+ *
+ * // `range1` contains message "message 1"
+ * nvtx3::thread_range range1{attr1};
+ *
+ * // `range2` contains message "message 2"
+ * nvtx3::thread_range range2{nvtx3::Mesage{"message 2"}};
+ *
+ * // `std::string` and string literals are implicitly assumed to be
+ * // the contents of an `nvtx3::message`
+ * // `range3` contains message "message 3"
+ * nvtx3::thread_range range3{"message 3"};
+ * \endcode
+ */
+class message {
+ public:
+ using value_type = nvtxMessageValue_t;
+
+ /**
+ * @brief Construct a `message` whose contents are specified by `msg`.
+ *
+ * @param msg The contents of the message
+ */
+ NVTX3_RELAXED_CONSTEXPR message(char const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_ASCII}
+ {
+ value_.ascii = msg;
+ }
+
+ /**
+ * @brief Construct a `message` whose contents are specified by `msg`.
+ *
+ * @param msg The contents of the message
+ */
+ message(std::string const& msg) noexcept : message{msg.c_str()} {}
+
+ /**
+ * @brief Disallow construction for `std::string` r-value
+ *
+ * `message` is a non-owning type and therefore cannot take ownership of an
+ * r-value. Therefore, constructing from an r-value is disallowed to prevent
+ * a dangling pointer.
+ *
+ */
+ message(std::string&&) = delete;
+
+ /**
+ * @brief Construct a `message` whose contents are specified by `msg`.
+ *
+ * @param msg The contents of the message
+ */
+ NVTX3_RELAXED_CONSTEXPR message(wchar_t const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_UNICODE}
+ {
+ value_.unicode = msg;
+ }
+
+ /**
+ * @brief Construct a `message` whose contents are specified by `msg`.
+ *
+ * @param msg The contents of the message
+ */
+ message(std::wstring const& msg) noexcept : message{msg.c_str()} {}
+
+ /**
+ * @brief Disallow construction for `std::wstring` r-value
+ *
+ * `message` is a non-owning type and therefore cannot take ownership of an
+ * r-value. Therefore, constructing from an r-value is disallowed to prevent
+ * a dangling pointer.
+ *
+ */
+ message(std::wstring&&) = delete;
+
+ /**
+ * @brief Construct a `message` from a `registered_string`.
+ *
+ * @tparam D Type containing `name` member used to identify the `domain`
+ * to which the `registered_string` belongs. Else, `domain::global` to
+ * indicate that the global NVTX domain should be used.
+ * @param msg The message that has already been registered with NVTX.
+ */
+ template <typename D>
+ NVTX3_RELAXED_CONSTEXPR message(registered_string<D> const& msg) noexcept
+ : type_{NVTX_MESSAGE_TYPE_REGISTERED}
+ {
+ value_.registered = msg.get_handle();
+ }
+
+ /**
+ * @brief Return the union holding the value of the message.
+ *
+ */
+ NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; }
+
+ /**
+ * @brief Return the type information about the value the union holds.
+ *
+ */
+ NVTX3_RELAXED_CONSTEXPR nvtxMessageType_t get_type() const noexcept { return type_; }
+
+ private:
+ nvtxMessageType_t const type_{}; ///< message type
+ nvtxMessageValue_t value_{}; ///< message contents
+};
+
+/**
+ * @brief A numerical value that can be associated with an NVTX event via
+ * its `event_attributes`.
+ *
+ * Example:
+ * ```
+ * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload
+ * from
+ * // the `int32_t` value 42
+ *
+ * // `range0` will have an int32_t payload of 42
+ * nvtx3::thread_range range0{attr};
+ *
+ * // range1 has double payload of 3.14
+ * nvtx3::thread_range range1{ nvtx3::payload{3.14} };
+ * ```
+ */
+class payload {
+ public:
+ using value_type = typename nvtxEventAttributes_v2::payload_t;
+
+ /**
+ * @brief Construct a `payload` from a signed, 8 byte integer.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(int64_t value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_INT64}, value_{}
+ {
+ value_.llValue = value;
+ }
+
+ /**
+ * @brief Construct a `payload` from a signed, 4 byte integer.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(int32_t value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_INT32}, value_{}
+ {
+ value_.iValue = value;
+ }
+
+ /**
+ * @brief Construct a `payload` from an unsigned, 8 byte integer.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(uint64_t value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT64}, value_{}
+ {
+ value_.ullValue = value;
+ }
+
+ /**
+ * @brief Construct a `payload` from an unsigned, 4 byte integer.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(uint32_t value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT32}, value_{}
+ {
+ value_.uiValue = value;
+ }
+
+ /**
+ * @brief Construct a `payload` from a single-precision floating point
+ * value.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(float value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_FLOAT}, value_{}
+ {
+ value_.fValue = value;
+ }
+
+ /**
+ * @brief Construct a `payload` from a double-precision floating point
+ * value.
+ *
+ * @param value Value to use as contents of the payload
+ */
+ NVTX3_RELAXED_CONSTEXPR explicit payload(double value) noexcept
+ : type_{NVTX_PAYLOAD_TYPE_DOUBLE}, value_{}
+ {
+ value_.dValue = value;
+ }
+
+ /**
+ * @brief Return the union holding the value of the payload
+ *
+ */
+ NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; }
+
+ /**
+ * @brief Return the information about the type the union holds.
+ *
+ */
+ NVTX3_RELAXED_CONSTEXPR nvtxPayloadType_t get_type() const noexcept { return type_; }
+
+ private:
+ nvtxPayloadType_t const type_; ///< Type of the payload value
+ value_type value_; ///< Union holding the payload value
+};
+
+/**
+ * @brief Describes the attributes of a NVTX event.
+ *
+ * NVTX events can be customized via four "attributes":
+ *
+ * - color: color used to visualize the event in tools such as Nsight
+ * Systems. See `color`.
+ * - message: Custom message string. See `message`.
+ * - payload: User-defined numerical value. See `payload`.
+ * - category: Intra-domain grouping. See `category`.
+ *
+ * These component attributes are specified via an `event_attributes` object.
+ * See `nvtx3::color`, `nvtx3::message`, `nvtx3::payload`, and
+ * `nvtx3::category` for how these individual attributes are constructed.
+ *
+ * While it is possible to specify all four attributes, it is common to want
+ * to only specify a subset of attributes and use default values for the
+ * others. For convenience, `event_attributes` can be constructed from any
+ * number of attribute components in any order.
+ *
+ * Example:
+ * \code{.cpp}
+ * event_attributes attr{}; // No arguments, use defaults for all attributes
+ *
+ * event_attributes attr{"message"}; // Custom message, rest defaulted
+ *
+ * // Custom color & message
+ * event_attributes attr{"message", nvtx3::rgb{127, 255, 0}};
+ *
+ * /// Custom color & message, can use any order of arguments
+ * event_attributes attr{nvtx3::rgb{127, 255, 0}, "message"};
+ *
+ *
+ * // Custom color, message, payload, category
+ * event_attributes attr{nvtx3::rgb{127, 255, 0},
+ * "message",
+ * nvtx3::payload{42},
+ * nvtx3::category{1}};
+ *
+ * // Custom color, message, payload, category, can use any order of arguments
+ * event_attributes attr{nvtx3::payload{42},
+ * nvtx3::category{1},
+ * "message",
+ * nvtx3::rgb{127, 255, 0}};
+ *
+ * // Multiple arguments of the same type are allowed, but only the first is
+ * // used. All others are ignored
+ * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload
+ * is 42
+ *
+ * // Range `r` will be customized according the attributes in `attr`
+ * nvtx3::thread_range r{attr};
+ *
+ * // For convenience, the arguments that can be passed to the
+ * `event_attributes`
+ * // constructor may be passed to the `domain_thread_range` contructor where
+ * // they will be forwarded to the `EventAttribute`s constructor
+ * nvtx3::thread_range r{nvtx3::payload{42}, nvtx3::category{1}, "message"};
+ * \endcode
+ *
+ */
+class event_attributes {
+ public:
+ using value_type = nvtxEventAttributes_t;
+
+ /**
+ * @brief Default constructor creates an `event_attributes` with no
+ * category, color, payload, nor message.
+ */
+ constexpr event_attributes() noexcept
+ : attributes_{
+ NVTX_VERSION, // version
+ sizeof(nvtxEventAttributes_t), // size
+ 0, // category
+ NVTX_COLOR_UNKNOWN, // color type
+ 0, // color value
+ NVTX_PAYLOAD_UNKNOWN, // payload type
+ 0, // reserved 4B
+ 0, // payload value (union)
+ NVTX_MESSAGE_UNKNOWN, // message type
+ 0 // message value (union)
+ }
+ {
+ }
+
+ /**
+ * @brief Variadic constructor where the first argument is a `category`.
+ *
+ * Sets the value of the `EventAttribute`s category based on `c` and
+ * forwards the remaining variadic parameter pack to the next constructor.
+ *
+ */
+ template <typename... Args>
+ NVTX3_RELAXED_CONSTEXPR explicit event_attributes(category const& c, Args const&... args) noexcept
+ : event_attributes(args...)
+ {
+ attributes_.category = c.get_id();
+ }
+
+ /**
+ * @brief Variadic constructor where the first argument is a `color`.
+ *
+ * Sets the value of the `EventAttribute`s color based on `c` and forwards
+ * the remaining variadic parameter pack to the next constructor.
+ *
+ */
+ template <typename... Args>
+ NVTX3_RELAXED_CONSTEXPR explicit event_attributes(color const& c, Args const&... args) noexcept
+ : event_attributes(args...)
+ {
+ attributes_.color = c.get_value();
+ attributes_.colorType = c.get_type();
+ }
+
+ /**
+ * @brief Variadic constructor where the first argument is a `payload`.
+ *
+ * Sets the value of the `EventAttribute`s payload based on `p` and forwards
+ * the remaining variadic parameter pack to the next constructor.
+ *
+ */
+ template <typename... Args>
+ NVTX3_RELAXED_CONSTEXPR explicit event_attributes(payload const& p, Args const&... args) noexcept
+ : event_attributes(args...)
+ {
+ attributes_.payload = p.get_value();
+ attributes_.payloadType = p.get_type();
+ }
+
+ /**
+ * @brief Variadic constructor where the first argument is a `message`.
+ *
+ * Sets the value of the `EventAttribute`s message based on `m` and forwards
+ * the remaining variadic parameter pack to the next constructor.
+ *
+ */
+ template <typename... Args>
+ NVTX3_RELAXED_CONSTEXPR explicit event_attributes(message const& m, Args const&... args) noexcept
+ : event_attributes(args...)
+ {
+ attributes_.message = m.get_value();
+ attributes_.messageType = m.get_type();
+ }
+
+ ~event_attributes() = default;
+ event_attributes(event_attributes const&) = default;
+ event_attributes& operator=(event_attributes const&) = default;
+ event_attributes(event_attributes&&) = default;
+ event_attributes& operator=(event_attributes&&) = default;
+
+ /**
+ * @brief Get raw pointer to underlying NVTX attributes object.
+ *
+ */
+ constexpr value_type const* get() const noexcept { return &attributes_; }
+
+ private:
+ value_type attributes_{}; ///< The NVTX attributes structure
+};
+
+/**
+ * @brief A RAII object for creating a NVTX range local to a thread within a
+ * domain.
+ *
+ * When constructed, begins a nested NVTX range on the calling thread in the
+ * specified domain. Upon destruction, ends the NVTX range.
+ *
+ * Behavior is undefined if a `domain_thread_range` object is
+ * created/destroyed on different threads.
+ *
+ * `domain_thread_range` is neither moveable nor copyable.
+ *
+ * `domain_thread_range`s may be nested within other ranges.
+ *
+ * The domain of the range is specified by the template type parameter `D`.
+ * By default, the `domain::global` is used, which scopes the range to the
+ * global NVTX domain. The convenience alias `thread_range` is provided for
+ * ranges scoped to the global domain.
+ *
+ * A custom domain can be defined by creating a type, `D`, with a static
+ * member `D::name` whose value is used to name the domain associated with
+ * `D`. `D::name` must resolve to either `char const*` or `wchar_t const*`
+ *
+ * Example:
+ * ```
+ * // Define a type `my_domain` with a member `name` used to name the domain
+ * // associated with the type `my_domain`.
+ * struct my_domain{
+ * static constexpr const char * name{"my domain"};
+ * };
+ * ```
+ *
+ * Usage:
+ * ```
+ * nvtx3::domain_thread_range<> r0{"range 0"}; // Range in global domain
+ *
+ * nvtx3::thread_range r1{"range 1"}; // Alias for range in global domain
+ *
+ * nvtx3::domain_thread_range<my_domain> r2{"range 2"}; // Range in custom
+ * domain
+ *
+ * // specify an alias to a range that uses a custom domain
+ * using my_thread_range = nvtx3::domain_thread_range<my_domain>;
+ *
+ * my_thread_range r3{"range 3"}; // Alias for range in custom domain
+ * ```
+ */
+template <class D = domain::global>
+class domain_thread_range {
+ public:
+ /**
+ * @brief Construct a `domain_thread_range` with the specified
+ * `event_attributes`
+ *
+ * Example:
+ * ```
+ * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}};
+ * nvtx3::domain_thread_range<> range{attr}; // Creates a range with message
+ * contents
+ * // "msg" and green color
+ * ```
+ *
+ * @param[in] attr `event_attributes` that describes the desired attributes
+ * of the range.
+ */
+ explicit domain_thread_range(event_attributes const& attr) noexcept
+ {
+#ifndef NVTX_DISABLE
+ nvtxDomainRangePushEx(domain::get<D>(), attr.get());
+#else
+ (void)attr;
+#endif
+ }
+
+ /**
+ * @brief Constructs a `domain_thread_range` from the constructor arguments
+ * of an `event_attributes`.
+ *
+ * Forwards the arguments `first, args...` to construct an
+ * `event_attributes` object. The `event_attributes` object is then
+ * associated with the `domain_thread_range`.
+ *
+ * For more detail, see `event_attributes` documentation.
+ *
+ * Example:
+ * ```
+ * // Creates a range with message "message" and green color
+ * nvtx3::domain_thread_range<> r{"message", nvtx3::rgb{127,255,0}};
+ * ```
+ *
+ * @note To prevent making needless copies of `event_attributes` objects,
+ * this constructor is disabled when the first argument is an
+ * `event_attributes` object, instead preferring the explicit
+ * `domain_thread_range(event_attributes const&)` constructor.
+ *
+ * @param[in] first First argument to forward to the `event_attributes`
+ * constructor.
+ * @param[in] args Variadic parameter pack of additional arguments to
+ * forward.
+ *
+ */
+ template <typename First,
+ typename... Args,
+ typename = typename std::enable_if<
+ !std::is_same<event_attributes, typename std::decay<First>>::value>>
+ explicit domain_thread_range(First const& first, Args const&... args) noexcept
+ : domain_thread_range{event_attributes{first, args...}}
+ {
+ }
+
+ /**
+ * @brief Default constructor creates a `domain_thread_range` with no
+ * message, color, payload, nor category.
+ *
+ */
+ domain_thread_range() : domain_thread_range{event_attributes{}} {}
+
+ /**
+ * @brief Delete `operator new` to disallow heap allocated objects.
+ *
+ * `domain_thread_range` must follow RAII semantics to guarantee proper push/pop semantics.
+ *
+ */
+ void* operator new(std::size_t) = delete;
+
+ domain_thread_range(domain_thread_range const&) = delete;
+ domain_thread_range& operator=(domain_thread_range const&) = delete;
+ domain_thread_range(domain_thread_range&&) = delete;
+ domain_thread_range& operator=(domain_thread_range&&) = delete;
+
+ /**
+ * @brief Destroy the domain_thread_range, ending the NVTX range event.
+ */
+ ~domain_thread_range() noexcept
+ {
+#ifndef NVTX_DISABLE
+ nvtxDomainRangePop(domain::get<D>());
+#endif
+ }
+};
+
+/**
+ * @brief Alias for a `domain_thread_range` in the global NVTX domain.
+ *
+ */
+using thread_range = domain_thread_range<>;
+
+/**
+ * @brief Handle used for correlating explicit range start and end events.
+ *
+ */
+struct range_handle {
+ /// Type used for the handle's value
+ using value_type = nvtxRangeId_t;
+
+ /**
+ * @brief Construct a `range_handle` from the given id.
+ *
+ */
+ constexpr explicit range_handle(value_type id) noexcept : _range_id{id} {}
+
+ /**
+ * @brief Returns the `range_handle`'s value
+ *
+ * @return value_type The handle's value
+ */
+ constexpr value_type get_value() const noexcept { return _range_id; }
+
+ private:
+ value_type _range_id{}; ///< The underlying NVTX range id
+};
+
+/**
+ * @brief Manually begin an NVTX range.
+ *
+ * Explicitly begins an NVTX range and returns a unique handle. To end the
+ * range, pass the handle to `end_range()`.
+ *
+ * `start_range/end_range` are the most explicit and lowest level APIs provided
+ * for creating ranges. Use of `nvtx3::domain_process_range` should be
+ * preferred unless one is unable to tie the range to the lifetime of an object.
+ *
+ * Example:
+ * ```
+ * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}};
+ * nvtx3::range_handle h = nvxt3::start_range(attr); // Manually begins a range
+ * ...
+ * nvtx3::end_range(h); // Ends the range
+ * ```
+ *
+ * @tparam D Type containing `name` member used to identify the `domain`
+ * to which the range belongs. Else, `domain::global` to indicate that the
+ * global NVTX domain should be used.
+ * @param[in] attr `event_attributes` that describes the desired attributes
+ * of the range.
+ * @return Unique handle to be passed to `end_range` to end the range.
+ */
+template <typename D = domain::global>
+range_handle start_range(event_attributes const& attr) noexcept
+{
+#ifndef NVTX_DISABLE
+ return range_handle{nvtxDomainRangeStartEx(domain::get<D>(), attr.get())};
+#else
+ (void)attr;
+ return range_handle{};
+#endif
+}
+
+/**
+ * @brief Manually begin an NVTX range.
+ *
+ * Explicitly begins an NVTX range and returns a unique handle. To end the
+ * range, pass the handle to `end_range()`.
+ *
+ * Forwards the arguments `first, args...` to construct an `event_attributes`
+ * object. The `event_attributes` object is then associated with the range.
+ *
+ * For more detail, see `event_attributes` documentation.
+ *
+ * Example:
+ * ```
+ * nvtx3::range_handle h = nvxt3::start_range("msg", nvtx3::rgb{127,255,0}); //
+ * Begin range
+ * ...
+ * nvtx3::end_range(h); // Ends the range
+ * ```
+ *
+ * `start_range/end_range` are the most explicit and lowest level APIs provided
+ * for creating ranges. Use of `nvtx3::domain_process_range` should be
+ * preferred unless one is unable to tie the range to the lifetime of an object.
+ *
+ * @param first[in] First argument to pass to an `event_attributes`
+ * @param args[in] Variadiac parameter pack of the rest of the arguments for an
+ * `event_attributes`.
+ * @return Unique handle to be passed to `end_range` to end the range.
+ */
+template <typename First,
+ typename... Args,
+ typename = typename std::enable_if<
+ !std::is_same<event_attributes, typename std::decay<First>>::value>>
+range_handle start_range(First const& first, Args const&... args) noexcept
+{
+#ifndef NVTX_DISABLE
+ return start_range(event_attributes{first, args...});
+#else
+ (void)first;
+ return range_handle{};
+#endif
+}
+
+/**
+ * @brief Manually end the range associated with the handle `r`.
+ *
+ * Explicitly ends the NVTX range indicated by the handle `r` returned from a
+ * prior call to `start_range`. The range may end on a different thread from
+ * where it began.
+ *
+ * This function does not have a Domain tag type template parameter as the
+ * handle `r` already indicates the domain to which the range belongs.
+ *
+ * @param r Handle to a range started by a prior call to `start_range`.
+ */
+inline void end_range(range_handle r)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangeEnd(r.get_value());
+#else
+ (void)r;
+#endif
+}
+
+/**
+ * @brief A RAII object for creating a NVTX range within a domain that can
+ * be created and destroyed on different threads.
+ *
+ * When constructed, begins a NVTX range in the specified domain. Upon
+ * destruction, ends the NVTX range.
+ *
+ * Similar to `nvtx3::domain_thread_range`, the only difference being that
+ * `domain_process_range` can start and end on different threads.
+ *
+ * Use of `nvtx3::domain_thread_range` should be preferred unless one needs
+ * the ability to start and end a range on different threads.
+ *
+ * `domain_process_range` is moveable, but not copyable.
+ *
+ * @tparam D Type containing `name` member used to identify the `domain`
+ * to which the `domain_process_range` belongs. Else, `domain::global` to
+ * indicate that the global NVTX domain should be used.
+ */
+template <typename D = domain::global>
+class domain_process_range {
+ public:
+ /**
+ * @brief Construct a new domain process range object
+ *
+ * @param attr
+ */
+ explicit domain_process_range(event_attributes const& attr) noexcept
+ : handle_{new range_handle{start_range<D>(attr)}}
+ {
+ }
+
+ /**
+ * @brief Construct a new domain process range object
+ *
+ * @param first
+ * @param args
+ */
+ template <typename First,
+ typename... Args,
+ typename = typename std::enable_if<
+ !std::is_same<event_attributes, typename std::decay<First>>::value>>
+ explicit domain_process_range(First const& first, Args const&... args) noexcept
+ : domain_process_range{event_attributes{first, args...}}
+ {
+ }
+
+ /**
+ * @brief Construct a new domain process range object
+ *
+ */
+ constexpr domain_process_range() noexcept : domain_process_range{event_attributes{}} {}
+
+ /**
+ * @brief Destroy the `domain_process_range` ending the range.
+ *
+ */
+ ~domain_process_range()
+ {
+ if (handle_) { end_range(*handle_); }
+ }
+
+ /**
+ * @brief Move constructor allows taking ownership of the NVTX range from
+ * another `domain_process_range`.
+ *
+ * @param other
+ */
+ domain_process_range(domain_process_range&& other) = default;
+
+ /**
+ * @brief Move assignment operator allows taking ownership of an NVTX range
+ * from another `domain_process_range`.
+ *
+ * @param other
+ * @return domain_process_range&
+ */
+ domain_process_range& operator=(domain_process_range&& other) = default;
+
+ /// Copy construction is not allowed to prevent multiple objects from owning
+ /// the same range handle
+ domain_process_range(domain_process_range const&) = delete;
+
+ /// Copy assignment is not allowed to prevent multiple objects from owning the
+ /// same range handle
+ domain_process_range& operator=(domain_process_range const&) = delete;
+
+ private:
+ std::unique_ptr<range_handle> handle_; ///< Range handle used to correlate
+ ///< the start/end of the range
+};
+
+/**
+ * @brief Alias for a `domain_process_range` in the global NVTX domain.
+ *
+ */
+using process_range = domain_process_range<>;
+
+/**
+ * @brief Annotates an instantaneous point in time with the attributes specified
+ * by `attr`.
+ *
+ * Unlike a "range", a mark is an instantaneous event in an application, e.g.,
+ * locking/unlocking a mutex.
+ *
+ * \code{.cpp}
+ * std::mutex global_lock;
+ * void lock_mutex(){
+ * global_lock.lock();
+ * nvtx3::mark("lock_mutex");
+ * }
+ * \endcode
+ *
+ * @tparam D Type containing `name` member used to identify the `domain`
+ * to which the `domain_process_range` belongs. Else, `domain::global` to
+ * indicate that the global NVTX domain should be used.
+ * @param[in] attr `event_attributes` that describes the desired attributes
+ * of the mark.
+ */
+template <typename D = domain::global>
+inline void mark(event_attributes const& attr) noexcept
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainMarkEx(domain::get<D>(), attr.get());
+#else
+ (void)(attr);
+#endif
+}
+
+} // namespace NVTX3_VERSION_NAMESPACE
+
+} // namespace nvtx3
+
+/**
+ * @brief Convenience macro for generating a range in the specified `domain`
+ * from the lifetime of a function
+ *
+ * This macro is useful for generating an NVTX range in `domain` from
+ * the entry point of a function to its exit. It is intended to be the first
+ * line of the function.
+ *
+ * Constructs a static `registered_string` using the name of the immediately
+ * enclosing function returned by `__func__` and constructs a
+ * `nvtx3::thread_range` using the registered function name as the range's
+ * message.
+ *
+ * Example:
+ * ```
+ * struct my_domain{static constexpr char const* name{"my_domain"};};
+ *
+ * void foo(...){
+ * NVTX3_FUNC_RANGE_IN(my_domain); // Range begins on entry to foo()
+ * // do stuff
+ * ...
+ * } // Range ends on return from foo()
+ * ```
+ *
+ * @param[in] D Type containing `name` member used to identify the
+ * `domain` to which the `registered_string` belongs. Else,
+ * `domain::global` to indicate that the global NVTX domain should be used.
+ */
+#ifndef NVTX_DISABLE
+#define NVTX3_V1_FUNC_RANGE_IN(D) \
+ static ::nvtx3::v1::registered_string<D> const nvtx3_func_name__{__func__}; \
+ static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \
+ ::nvtx3::v1::domain_thread_range<D> const nvtx3_range__{nvtx3_func_attr__};
+#else
+#define NVTX3_V1_FUNC_RANGE_IN(D)
+#endif
+
+/**
+ * @brief Convenience macro for generating a range in the global domain from the
+ * lifetime of a function.
+ *
+ * This macro is useful for generating an NVTX range in the global domain from
+ * the entry point of a function to its exit. It is intended to be the first
+ * line of the function.
+ *
+ * Constructs a static `registered_string` using the name of the immediately
+ * enclosing function returned by `__func__` and constructs a
+ * `nvtx3::thread_range` using the registered function name as the range's
+ * message.
+ *
+ * Example:
+ * ```
+ * void foo(...){
+ * NVTX3_FUNC_RANGE(); // Range begins on entry to foo()
+ * // do stuff
+ * ...
+ * } // Range ends on return from foo()
+ * ```
+ */
+#define NVTX3_V1_FUNC_RANGE() NVTX3_V1_FUNC_RANGE_IN(::nvtx3::v1::domain::global)
+
+/* When inlining this version, versioned macros must have unversioned aliases.
+ * For each NVTX3_Vx_ #define, make an NVTX3_ alias of it here.*/
+#if defined(NVTX3_INLINE_THIS_VERSION)
+/* clang format off */
+#define NVTX3_FUNC_RANGE_IN NVTX3_V1_FUNC_RANGE_IN
+#define NVTX3_FUNC_RANGE NVTX3_V1_FUNC_RANGE
+/* clang format on */
+#endif
+
+#endif // NVTX3_CPP_DEFINITIONS_V1_0
+
+/* Add functionality for new minor versions here, by copying the above section enclosed
+ * in #ifndef NVTX3_CPP_DEFINITIONS_Vx_y, and incrementing the minor version. This code
+ * is an example of how additions for version 1.2 would look, indented for clarity. Note
+ * that the versioned symbols and macros are always provided, and the unversioned symbols
+ * are only provided if NVTX3_INLINE_THIS_VERSION was defined at the top of this header.
+ *
+ * \code{.cpp}
+ * #ifndef NVTX3_CPP_DEFINITIONS_V1_2
+ * #define NVTX3_CPP_DEFINITIONS_V1_2
+ * namespace nvtx3 {
+ * NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE {
+ * class new_class {};
+ * inline void new_function() {}
+ * }
+ * }
+ *
+ * // Macros must have the major version in their names:
+ * #define NVTX3_V1_NEW_MACRO_A() ...
+ * #define NVTX3_V1_NEW_MACRO_B() ...
+ *
+ * // If inlining, make aliases for the macros with the version number omitted
+ * #if defined(NVTX3_INLINE_THIS_VERSION)
+ * #define NVTX3_NEW_MACRO_A NVTX3_V1_NEW_MACRO_A
+ * #define NVTX3_NEW_MACRO_B NVTX3_V1_NEW_MACRO_B
+ * #endif
+ * #endif // NVTX3_CPP_DEFINITIONS_V1_2
+ * \endcode
+ */
+
+/* Undefine all temporarily-defined unversioned macros, which would conflict with
+ * subsequent includes of different versions of this header. */
+#undef NVTX3_CPP_VERSION_MAJOR
+#undef NVTX3_CPP_VERSION_MINOR
+#undef NVTX3_CONCAT
+#undef NVTX3_NAMESPACE_FOR
+#undef NVTX3_VERSION_NAMESPACE
+#undef NVTX3_INLINE_IF_REQUESTED
+#undef NVTX3_RELAXED_CONSTEXPR
+
+#if defined(NVTX3_INLINE_THIS_VERSION)
+#undef NVTX3_INLINE_THIS_VERSION
+#endif
diff --git a/src/include/nvtx3/nvToolsExt.h b/src/include/nvtx3/nvToolsExt.h
new file mode 100644
index 0000000..ce4b0be
--- /dev/null
+++ b/src/include/nvtx3/nvToolsExt.h
@@ -0,0 +1,1470 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+/** \file nvToolsExt.h
+ */
+
+/* ========================================================================= */
+/** \mainpage
+ * \tableofcontents
+ * \section INTRODUCTION Introduction
+ *
+ * The NVIDIA Tools Extension library is a set of functions that a
+ * developer can use to provide additional information to tools.
+ * The additional information is used by the tool to improve
+ * analysis and visualization of data.
+ *
+ * The library introduces close to zero overhead if no tool is
+ * attached to the application. The overhead when a tool is
+ * attached is specific to the tool.
+ *
+ * \section INITIALIZATION_SECTION Initialization
+ *
+ * Typically the tool's library that plugs into NVTX is indirectly
+ * loaded via enviromental properties that are platform specific.
+ * For some platform or special cases, the user may be required
+ * to instead explicity initialize instead though. This can also
+ * be helpful to control when the API loads a tool's library instead
+ * of what would typically be the first function call to emit info.
+ * For these rare case, see \ref INITIALIZATION for additional information.
+ *
+ * \section MARKERS_AND_RANGES Markers and Ranges
+ *
+ * Markers and ranges are used to describe events at a specific time (markers)
+ * or over a time span (ranges) during the execution of the application
+ * respectively.
+ *
+ * \subsection MARKERS Markers
+ *
+ * Markers denote specific moments in time.
+ *
+ *
+ * See \ref DOMAINS and \ref EVENT_ATTRIBUTES for additional information on
+ * how to specify the domain.
+ *
+ * \subsection THREAD_RANGES Thread Ranges
+ *
+ * Thread ranges denote nested time ranges. Nesting is maintained per thread
+ * per domain and does not require any additional correlation mechanism. The
+ * duration of a thread range is defined by the corresponding pair of
+ * nvtxRangePush* to nvtxRangePop API calls.
+ *
+ * See \ref DOMAINS and \ref EVENT_ATTRIBUTES for additional information on
+ * how to specify the domain.
+ *
+ * \subsection PROCESS_RANGES Process Ranges
+ *
+ * Process ranges denote a time span that can expose arbitrary concurrency, as
+ * opposed to thread ranges that only support nesting. In addition the range
+ * start event can happen on a different thread than the end marker. For the
+ * correlation of a start/end pair an unique correlation ID is used that is
+ * returned from the start API call and needs to be passed into the end API
+ * call.
+ *
+ * \subsection EVENT_ATTRIBUTES Event Attributes
+ *
+ * \ref MARKERS_AND_RANGES can be annotated with various attributes to provide
+ * additional information for an event or to guide the tool's visualization of
+ * the data. Each of the attributes is optional and if left unused the
+ * attributes fall back to a default value. The attributes include:
+ * - color
+ * - category
+ *
+ * To specify any attribute other than the text message, the \ref
+ * EVENT_ATTRIBUTE_STRUCTURE "Event Attribute Structure" must be used.
+ *
+ * \section DOMAINS Domains
+ *
+ * Domains enable developers to scope annotations. By default all events and
+ * annotations are in the default domain. Additional domains can be registered.
+ * This allows developers to scope markers, ranges, and resources names to
+ * avoid conflicts.
+ *
+ * The function ::nvtxDomainCreateA or ::nvtxDomainCreateW is used to create
+ * a named domain.
+ *
+ * Each domain maintains its own
+ * - categories
+ * - thread range stacks
+ * - registered strings
+ *
+ * The function ::nvtxDomainDestroy marks the end of the domain. Destroying
+ * a domain unregisters and destroys all objects associated with it such as
+ * registered strings, resource objects, named categories, and started ranges.
+ *
+ * \section RESOURCE_NAMING Resource Naming
+ *
+ * This section covers calls that allow to annotate objects with user-provided
+ * names in order to allow for a better analysis of complex trace data. All of
+ * the functions take the handle or the ID of the object to name and the name.
+ * The functions can be called multiple times during the execution of an
+ * application, however, in that case it is implementation dependent which
+ * name will be reported by the tool.
+ *
+ * \subsection CATEGORY_NAMING Category Naming
+ *
+ * Some function in this library support associating an integer category
+ * to enable filtering and sorting. The category naming functions allow
+ * the application to associate a user friendly name with the integer
+ * category. Support for domains have been added in NVTX_VERSION_2 to
+ * avoid collisions when domains are developed independantly.
+ *
+ * \subsection RESOURCE_OBJECTS Resource Objects
+ *
+ * Resource objects are a generic mechanism for attaching data to an application
+ * resource. The identifier field makes the association to a pointer or handle,
+ * while the type field helps provide deeper understanding of the identifier as
+ * well as enabling differentiation in cases where handles generated by different
+ * APIs may collide. The resource object may also have an associated message to
+ * associate with the application resource, enabling further annotation of this
+ * object and how it is used.
+ *
+ * The resource object was introduced in NVTX_VERSION_2 to supersede existing naming
+ * functions and allow the application resource identified by those functions to be
+ * associated to a domain. The other naming functions are still supported for backward
+ * compatibility but will be associated only to the default domain.
+ *
+ * \subsection RESOURCE_NAMING_OS Resource Naming
+ *
+ * Some operating system resources creation APIs do not support providing a user friendly
+ * name, such as some OS thread creation APIs. This API support resource naming though
+ * both through resource objects and functions following the pattern
+ * nvtxName[RESOURCE_TYPE][A|W](identifier, name). Resource objects introduced in NVTX_VERSION 2
+ * supersede the other functions with a a more general method of assigning names to OS resources,
+ * along with associating them to domains too. The older nvtxName* functions are only associated
+ * with the default domain.
+ * \section EXTENSIONS Optional Extensions
+ * Optional extensions will either appear within the existing sections the extend or appear
+ * in the "Related Pages" when they introduce new concepts.
+ */
+
+ /**
+ * Tools Extension API version
+ */
+#if defined(NVTX_VERSION) && NVTX_VERSION < 3
+#error "Trying to #include NVTX version 3 in a source file where an older NVTX version has already been included. If you are not directly using NVTX (the NVIDIA Tools Extension library), you are getting this error because libraries you are using have included different versions of NVTX. Suggested solutions are: (1) reorder #includes so the newest NVTX version is included first, (2) avoid using the conflicting libraries in the same .c/.cpp file, or (3) update the library using the older NVTX version to use the newer version instead."
+#endif
+
+/* Header guard */
+#if !defined(NVTX_VERSION)
+#define NVTX_VERSION 3
+
+#if defined(_MSC_VER)
+#define NVTX_API __stdcall
+#define NVTX_INLINE_STATIC __inline static
+#else /*defined(__GNUC__)*/
+#define NVTX_API
+#define NVTX_INLINE_STATIC inline static
+#endif /* Platform */
+
+#if defined(NVTX_NO_IMPL)
+/* When omitting implementation, avoid declaring functions inline */
+/* without definitions, since this causes compiler warnings. */
+#define NVTX_DECLSPEC
+#elif defined(NVTX_EXPORT_API)
+/* Allow overriding definition of NVTX_DECLSPEC when exporting API. */
+/* Default is empty, meaning non-inline with external linkage. */
+#if !defined(NVTX_DECLSPEC)
+#define NVTX_DECLSPEC
+#endif
+#else
+/* Normal NVTX usage defines the NVTX API inline with static */
+/* (internal) linkage. */
+#define NVTX_DECLSPEC NVTX_INLINE_STATIC
+#endif
+
+#include "nvtxDetail/nvtxLinkOnce.h"
+
+#define NVTX_VERSIONED_IDENTIFIER_L3(NAME, VERSION) NAME##_v##VERSION
+#define NVTX_VERSIONED_IDENTIFIER_L2(NAME, VERSION) NVTX_VERSIONED_IDENTIFIER_L3(NAME, VERSION)
+#define NVTX_VERSIONED_IDENTIFIER(NAME) NVTX_VERSIONED_IDENTIFIER_L2(NAME, NVTX_VERSION)
+
+/**
+ * The nvToolsExt library depends on stdint.h. If the build tool chain in use
+ * does not include stdint.h then define NVTX_STDINT_TYPES_ALREADY_DEFINED
+ * and define the following types:
+ * <ul>
+ * <li>uint8_t
+ * <li>int8_t
+ * <li>uint16_t
+ * <li>int16_t
+ * <li>uint32_t
+ * <li>int32_t
+ * <li>uint64_t
+ * <li>int64_t
+ * <li>uintptr_t
+ * <li>intptr_t
+ * </ul>
+ * #define NVTX_STDINT_TYPES_ALREADY_DEFINED if you are using your own header file.
+ */
+#ifndef NVTX_STDINT_TYPES_ALREADY_DEFINED
+#include <stdint.h>
+#endif
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/**
+* Result Codes
+*/
+
+#define NVTX_SUCCESS 0
+#define NVTX_FAIL 1
+#define NVTX_ERR_INIT_LOAD_PROPERTY 2
+#define NVTX_ERR_INIT_ACCESS_LIBRARY 3
+#define NVTX_ERR_INIT_LOAD_LIBRARY 4
+#define NVTX_ERR_INIT_MISSING_LIBRARY_ENTRY_POINT 5
+#define NVTX_ERR_INIT_FAILED_LIBRARY_ENTRY_POINT 6
+#define NVTX_ERR_NO_INJECTION_LIBRARY_AVAILABLE 7
+
+/**
+ * Size of the nvtxEventAttributes_t structure.
+ */
+#define NVTX_EVENT_ATTRIB_STRUCT_SIZE ( (uint16_t)( sizeof(nvtxEventAttributes_t) ) )
+
+#define NVTX_NO_PUSH_POP_TRACKING ((int)-2)
+
+typedef uint64_t nvtxRangeId_t;
+
+/* Forward declaration of opaque domain registration structure */
+struct nvtxDomainRegistration_st;
+typedef struct nvtxDomainRegistration_st nvtxDomainRegistration;
+
+/* \brief Domain Handle Structure.
+* \anchor DOMAIN_HANDLE_STRUCTURE
+*
+* This structure is opaque to the user and is used as a handle to reference
+* a domain. This type is returned from tools when using the NVTX API to
+* create a domain.
+*
+*/
+typedef nvtxDomainRegistration* nvtxDomainHandle_t;
+
+/* Forward declaration of opaque string registration structure */
+struct nvtxStringRegistration_st;
+typedef struct nvtxStringRegistration_st nvtxStringRegistration;
+
+/* \brief Registered String Handle Structure.
+* \anchor REGISTERED_STRING_HANDLE_STRUCTURE
+*
+* This structure is opaque to the user and is used as a handle to reference
+* a registered string. This type is returned from tools when using the NVTX
+* API to create a registered string.
+*
+*/
+typedef nvtxStringRegistration* nvtxStringHandle_t;
+
+/* ========================================================================= */
+/** \defgroup GENERAL General
+ * @{
+ */
+
+/** ---------------------------------------------------------------------------
+ * Color Types
+ * ------------------------------------------------------------------------- */
+typedef enum nvtxColorType_t
+{
+ NVTX_COLOR_UNKNOWN = 0, /**< Color attribute is unused. */
+ NVTX_COLOR_ARGB = 1 /**< An ARGB color is provided. */
+} nvtxColorType_t;
+
+/** ---------------------------------------------------------------------------
+ * Message Types
+ * ------------------------------------------------------------------------- */
+typedef enum nvtxMessageType_t
+{
+ NVTX_MESSAGE_UNKNOWN = 0, /**< Message payload is unused. */
+ NVTX_MESSAGE_TYPE_ASCII = 1, /**< A character sequence is used as payload. */
+ NVTX_MESSAGE_TYPE_UNICODE = 2, /**< A wide character sequence is used as payload. */
+ /* NVTX_VERSION_2 */
+ NVTX_MESSAGE_TYPE_REGISTERED = 3, /**< A unique string handle that was registered
+ with \ref nvtxDomainRegisterStringA() or
+ \ref nvtxDomainRegisterStringW(). */
+} nvtxMessageType_t;
+
+typedef union nvtxMessageValue_t
+{
+ const char* ascii;
+ const wchar_t* unicode;
+ /* NVTX_VERSION_2 */
+ nvtxStringHandle_t registered;
+} nvtxMessageValue_t;
+
+
+/** @} */ /*END defgroup*/
+/* ------------------------------------------------------------------------- */
+/** \brief Force initialization (optional)
+*
+* Force NVTX library to initialize. The first call to any NVTX API function
+* will automatically initialize the entire API. This can make the first call
+* much slower than subsequent calls. In applications where the first call to
+* NVTX may be in a performance-critical section, calling nvtxInitialize before
+* any performance-critical sections will ensure NVTX initialization occurs at
+* an acceptable time. Since nvtxInitialize takes no parameters and has no
+* expected behavior besides initialization, it is convenient to add a call to
+* nvtxInitialize in NVTX-instrumented applications that need to force earlier
+* initialization without changing any other code. For example, if an app's
+* first NVTX call is nvtxDomainCreate, and it is difficult to move that call
+* earlier because the domain handle must be stored in an object only created
+* at that point, adding a call to nvtxInitialize at the top of main() will
+* ensure the later call to nvtxDomainCreate is as fast as possible.
+*
+* \version \NVTX_VERSION_3
+*
+* \param reserved - must be zero or NULL.
+*
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxInitialize(const void* reserved);
+/** @} */
+
+
+/** @} */ /*END defgroup*/
+
+/* ========================================================================= */
+/** \defgroup EVENT_ATTRIBUTES Event Attributes
+* @{
+*/
+
+/** ---------------------------------------------------------------------------
+* Payload Types
+* ------------------------------------------------------------------------- */
+typedef enum nvtxPayloadType_t
+{
+ NVTX_PAYLOAD_UNKNOWN = 0, /**< Color payload is unused. */
+ NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 = 1, /**< A 64 bit unsigned integer value is used as payload. */
+ NVTX_PAYLOAD_TYPE_INT64 = 2, /**< A 64 bit signed integer value is used as payload. */
+ NVTX_PAYLOAD_TYPE_DOUBLE = 3, /**< A 64 bit floating point value is used as payload. */
+ /* NVTX_VERSION_2 */
+ NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 = 4, /**< A 32 bit floating point value is used as payload. */
+ NVTX_PAYLOAD_TYPE_INT32 = 5, /**< A 32 bit floating point value is used as payload. */
+ NVTX_PAYLOAD_TYPE_FLOAT = 6 /**< A 32 bit floating point value is used as payload. */
+} nvtxPayloadType_t;
+
+/** \brief Event Attribute Structure.
+ * \anchor EVENT_ATTRIBUTE_STRUCTURE
+ *
+ * This structure is used to describe the attributes of an event. The layout of
+ * the structure is defined by a specific version of the tools extension
+ * library and can change between different versions of the Tools Extension
+ * library.
+ *
+ * \par Initializing the Attributes
+ *
+ * The caller should always perform the following three tasks when using
+ * attributes:
+ * <ul>
+ * <li>Zero the structure
+ * <li>Set the version field
+ * <li>Set the size field
+ * </ul>
+ *
+ * Zeroing the structure sets all the event attributes types and values
+ * to the default value.
+ *
+ * The version and size field are used by the Tools Extension
+ * implementation to handle multiple versions of the attributes structure.
+ *
+ * It is recommended that the caller use one of the following to methods
+ * to initialize the event attributes structure:
+ *
+ * \par Method 1: Initializing nvtxEventAttributes for future compatibility
+ * \code
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * eventAttrib.version = NVTX_VERSION;
+ * eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+ * \endcode
+ *
+ * \par Method 2: Initializing nvtxEventAttributes for a specific version
+ * \code
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * eventAttrib.version = 1;
+ * eventAttrib.size = (uint16_t)(sizeof(nvtxEventAttributes_v1));
+ * \endcode
+ *
+ * If the caller uses Method 1 it is critical that the entire binary
+ * layout of the structure be configured to 0 so that all fields
+ * are initialized to the default value.
+ *
+ * The caller should either use both NVTX_VERSION and
+ * NVTX_EVENT_ATTRIB_STRUCT_SIZE (Method 1) or use explicit values
+ * and a versioned type (Method 2). Using a mix of the two methods
+ * will likely cause either source level incompatibility or binary
+ * incompatibility in the future.
+ *
+ * \par Settings Attribute Types and Values
+ *
+ *
+ * \par Example:
+ * \code
+ * // Initialize
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * eventAttrib.version = NVTX_VERSION;
+ * eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+ *
+ * // Configure the Attributes
+ * eventAttrib.colorType = NVTX_COLOR_ARGB;
+ * eventAttrib.color = 0xFF880000;
+ * eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+ * eventAttrib.message.ascii = "Example";
+ * \endcode
+ *
+ * In the example the caller does not have to set the value of
+ * \ref ::nvtxEventAttributes_v2::category or
+ * \ref ::nvtxEventAttributes_v2::payload as these fields were set to
+ * the default value by {0}.
+ * \sa
+ * ::nvtxDomainMarkEx
+ * ::nvtxDomainRangeStartEx
+ * ::nvtxDomainRangePushEx
+ */
+typedef struct nvtxEventAttributes_v2
+{
+ /**
+ * \brief Version flag of the structure.
+ *
+ * Needs to be set to NVTX_VERSION to indicate the version of NVTX APIs
+ * supported in this header file. This can optionally be overridden to
+ * another version of the tools extension library.
+ */
+ uint16_t version;
+
+ /**
+ * \brief Size of the structure.
+ *
+ * Needs to be set to the size in bytes of the event attribute
+ * structure used to specify the event.
+ */
+ uint16_t size;
+
+ /**
+ * \brief ID of the category the event is assigned to.
+ *
+ * A category is a user-controlled ID that can be used to group
+ * events. The tool may use category IDs to improve filtering or
+ * enable grouping of events in the same category. The functions
+ * \ref ::nvtxNameCategoryA or \ref ::nvtxNameCategoryW can be used
+ * to name a category.
+ *
+ * Default Value is 0
+ */
+ uint32_t category;
+
+ /** \brief Color type specified in this attribute structure.
+ *
+ * Defines the color format of the attribute structure's \ref COLOR_FIELD
+ * "color" field.
+ *
+ * Default Value is NVTX_COLOR_UNKNOWN
+ */
+ int32_t colorType; /* nvtxColorType_t */
+
+ /** \brief Color assigned to this event. \anchor COLOR_FIELD
+ *
+ * The color that the tool should use to visualize the event.
+ */
+ uint32_t color;
+
+ /**
+ * \brief Payload type specified in this attribute structure.
+ *
+ * Defines the payload format of the attribute structure's \ref PAYLOAD_FIELD
+ * "payload" field.
+ *
+ * Default Value is NVTX_PAYLOAD_UNKNOWN
+ */
+ int32_t payloadType; /* nvtxPayloadType_t */
+
+ int32_t reserved0;
+
+ /**
+ * \brief Payload assigned to this event. \anchor PAYLOAD_FIELD
+ *
+ * A numerical value that can be used to annotate an event. The tool could
+ * use the payload data to reconstruct graphs and diagrams.
+ */
+ union payload_t
+ {
+ uint64_t ullValue;
+ int64_t llValue;
+ double dValue;
+ /* NVTX_VERSION_2 */
+ uint32_t uiValue;
+ int32_t iValue;
+ float fValue;
+ } payload;
+
+ /** \brief Message type specified in this attribute structure.
+ *
+ * Defines the message format of the attribute structure's \ref MESSAGE_FIELD
+ * "message" field.
+ *
+ * Default Value is NVTX_MESSAGE_UNKNOWN
+ */
+ int32_t messageType; /* nvtxMessageType_t */
+
+ /** \brief Message assigned to this attribute structure. \anchor MESSAGE_FIELD
+ *
+ * The text message that is attached to an event.
+ */
+ nvtxMessageValue_t message;
+
+} nvtxEventAttributes_v2;
+
+typedef struct nvtxEventAttributes_v2 nvtxEventAttributes_t;
+
+/** @} */ /*END defgroup*/
+/* ========================================================================= */
+/** \defgroup MARKERS_AND_RANGES Markers and Ranges
+ *
+ * See \ref MARKERS_AND_RANGES for more details
+ *
+ * @{
+ */
+
+/** \name Marker */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Marks an instantaneous event in the application.
+*
+* A marker can contain a text message or specify additional information
+* using the event attributes structure. These attributes include a text
+* message, color, category, and a payload. Each of the attributes is optional
+* and can only be sent out using the \ref nvtxDomainMarkEx function.
+*
+* nvtxDomainMarkEx(NULL, event) is equivalent to calling
+* nvtxMarkEx(event).
+*
+* \param domain - The domain of scoping the category.
+* \param eventAttrib - The event attribute structure defining the marker's
+* attribute types and attribute values.
+*
+* \sa
+* ::nvtxMarkEx
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxDomainMarkEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Marks an instantaneous event in the application.
+ *
+ * A marker can contain a text message or specify additional information
+ * using the event attributes structure. These attributes include a text
+ * message, color, category, and a payload. Each of the attributes is optional
+ * and can only be sent out using the \ref nvtxMarkEx function.
+ * If \ref nvtxMarkA or \ref nvtxMarkW are used to specify the marker
+ * or if an attribute is unspecified then a default value will be used.
+ *
+ * \param eventAttrib - The event attribute structure defining the marker's
+ * attribute types and attribute values.
+ *
+ * \par Example:
+ * \code
+ * // zero the structure
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * // set the version and the size information
+ * eventAttrib.version = NVTX_VERSION;
+ * eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+ * // configure the attributes. 0 is the default for all attributes.
+ * eventAttrib.colorType = NVTX_COLOR_ARGB;
+ * eventAttrib.color = 0xFF880000;
+ * eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+ * eventAttrib.message.ascii = "Example nvtxMarkEx";
+ * nvtxMarkEx(&eventAttrib);
+ * \endcode
+ *
+ * \sa
+ * ::nvtxDomainMarkEx
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxMarkEx(const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Marks an instantaneous event in the application.
+ *
+ * A marker created using \ref nvtxMarkA or \ref nvtxMarkW contains only a
+ * text message.
+ *
+ * \param message - The message associated to this marker event.
+ *
+ * \par Example:
+ * \code
+ * nvtxMarkA("Example nvtxMarkA");
+ * nvtxMarkW(L"Example nvtxMarkW");
+ * \endcode
+ *
+ * \sa
+ * ::nvtxDomainMarkEx
+ * ::nvtxMarkEx
+ *
+ * \version \NVTX_VERSION_0
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxMarkA(const char* message);
+NVTX_DECLSPEC void NVTX_API nvtxMarkW(const wchar_t* message);
+/** @} */
+
+
+/** \name Process Ranges */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a process range in a domain.
+*
+* \param domain - The domain of scoping the category.
+* \param eventAttrib - The event attribute structure defining the range's
+* attribute types and attribute values.
+*
+* \return The unique ID used to correlate a pair of Start and End events.
+*
+* \remarks Ranges defined by Start/End can overlap.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("my domain");
+* nvtxEventAttributes_t eventAttrib = {0};
+* eventAttrib.version = NVTX_VERSION;
+* eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* eventAttrib.message.ascii = "my range";
+* nvtxRangeId_t rangeId = nvtxDomainRangeStartEx(&eventAttrib);
+* // ...
+* nvtxDomainRangeEnd(rangeId);
+* \endcode
+*
+* \sa
+* ::nvtxDomainRangeEnd
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxDomainRangeStartEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a process range.
+ *
+ * \param eventAttrib - The event attribute structure defining the range's
+ * attribute types and attribute values.
+ *
+ * \return The unique ID used to correlate a pair of Start and End events.
+ *
+ * \remarks Ranges defined by Start/End can overlap.
+ *
+ * \par Example:
+ * \code
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * eventAttrib.version = NVTX_VERSION;
+ * eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+ * eventAttrib.category = 3;
+ * eventAttrib.colorType = NVTX_COLOR_ARGB;
+ * eventAttrib.color = 0xFF0088FF;
+ * eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+ * eventAttrib.message.ascii = "Example Range";
+ * nvtxRangeId_t rangeId = nvtxRangeStartEx(&eventAttrib);
+ * // ...
+ * nvtxRangeEnd(rangeId);
+ * \endcode
+ *
+ * \sa
+ * ::nvtxRangeEnd
+ * ::nvtxDomainRangeStartEx
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartEx(const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a process range.
+ *
+ * \param message - The event message associated to this range event.
+ *
+ * \return The unique ID used to correlate a pair of Start and End events.
+ *
+ * \remarks Ranges defined by Start/End can overlap.
+ *
+ * \par Example:
+ * \code
+ * nvtxRangeId_t r1 = nvtxRangeStartA("Range 1");
+ * nvtxRangeId_t r2 = nvtxRangeStartW(L"Range 2");
+ * nvtxRangeEnd(r1);
+ * nvtxRangeEnd(r2);
+ * \endcode
+ *
+ * \sa
+ * ::nvtxRangeEnd
+ * ::nvtxRangeStartEx
+ * ::nvtxDomainRangeStartEx
+ *
+ * \version \NVTX_VERSION_0
+ * @{ */
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartA(const char* message);
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartW(const wchar_t* message);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Ends a process range.
+*
+* \param domain - The domain
+* \param id - The correlation ID returned from a nvtxRangeStart call.
+*
+* \remarks This function is offered completeness but is an alias for ::nvtxRangeEnd.
+* It does not need a domain param since that is associated iwth the range ID at ::nvtxDomainRangeStartEx
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("my domain");
+* nvtxEventAttributes_t eventAttrib = {0};
+* eventAttrib.version = NVTX_VERSION;
+* eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* eventAttrib.message.ascii = "my range";
+* nvtxRangeId_t rangeId = nvtxDomainRangeStartEx(&eventAttrib);
+* // ...
+* nvtxDomainRangeEnd(rangeId);
+* \endcode
+*
+* \sa
+* ::nvtxDomainRangeStartEx
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxDomainRangeEnd(nvtxDomainHandle_t domain, nvtxRangeId_t id);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Ends a process range.
+ *
+ * \param id - The correlation ID returned from an nvtxRangeStart call.
+ *
+ * \sa
+ * ::nvtxDomainRangeStartEx
+ * ::nvtxRangeStartEx
+ * ::nvtxRangeStartA
+ * ::nvtxRangeStartW
+ *
+ * \version \NVTX_VERSION_0
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxRangeEnd(nvtxRangeId_t id);
+/** @} */
+
+/** \name Thread Ranges */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a nested thread range.
+*
+* \param domain - The domain of scoping.
+* \param eventAttrib - The event attribute structure defining the range's
+* attribute types and attribute values.
+*
+* \return The 0 based level of range being started. This value is scoped to the domain.
+* If an error occurs, a negative value is returned.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("example domain");
+* nvtxEventAttributes_t eventAttrib = {0};
+* eventAttrib.version = NVTX_VERSION;
+* eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib.colorType = NVTX_COLOR_ARGB;
+* eventAttrib.color = 0xFFFF0000;
+* eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* eventAttrib.message.ascii = "Level 0";
+* nvtxDomainRangePushEx(domain, &eventAttrib);
+*
+* // Re-use eventAttrib
+* eventAttrib.messageType = NVTX_MESSAGE_TYPE_UNICODE;
+* eventAttrib.message.unicode = L"Level 1";
+* nvtxDomainRangePushEx(domain, &eventAttrib);
+*
+* nvtxDomainRangePop(domain); //level 1
+* nvtxDomainRangePop(domain); //level 0
+* \endcode
+*
+* \sa
+* ::nvtxDomainRangePop
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC int NVTX_API nvtxDomainRangePushEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a nested thread range.
+ *
+ * \param eventAttrib - The event attribute structure defining the range's
+ * attribute types and attribute values.
+ *
+ * \return The 0 based level of range being started. This level is per domain.
+ * If an error occurs a negative value is returned.
+ *
+ * \par Example:
+ * \code
+ * nvtxEventAttributes_t eventAttrib = {0};
+ * eventAttrib.version = NVTX_VERSION;
+ * eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+ * eventAttrib.colorType = NVTX_COLOR_ARGB;
+ * eventAttrib.color = 0xFFFF0000;
+ * eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
+ * eventAttrib.message.ascii = "Level 0";
+ * nvtxRangePushEx(&eventAttrib);
+ *
+ * // Re-use eventAttrib
+ * eventAttrib.messageType = NVTX_MESSAGE_TYPE_UNICODE;
+ * eventAttrib.message.unicode = L"Level 1";
+ * nvtxRangePushEx(&eventAttrib);
+ *
+ * nvtxRangePop();
+ * nvtxRangePop();
+ * \endcode
+ *
+ * \sa
+ * ::nvtxDomainRangePushEx
+ * ::nvtxRangePop
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC int NVTX_API nvtxRangePushEx(const nvtxEventAttributes_t* eventAttrib);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Starts a nested thread range.
+ *
+ * \param message - The event message associated to this range event.
+ *
+ * \return The 0 based level of range being started. If an error occurs a
+ * negative value is returned.
+ *
+ * \par Example:
+ * \code
+ * nvtxRangePushA("Level 0");
+ * nvtxRangePushW(L"Level 1");
+ * nvtxRangePop();
+ * nvtxRangePop();
+ * \endcode
+ *
+ * \sa
+ * ::nvtxDomainRangePushEx
+ * ::nvtxRangePop
+ *
+ * \version \NVTX_VERSION_0
+ * @{ */
+NVTX_DECLSPEC int NVTX_API nvtxRangePushA(const char* message);
+NVTX_DECLSPEC int NVTX_API nvtxRangePushW(const wchar_t* message);
+/** @} */
+
+
+/* ------------------------------------------------------------------------- */
+/** \brief Ends a nested thread range.
+*
+* \return The level of the range being ended. If an error occurs a negative
+* value is returned on the current thread.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreate("example library");
+* nvtxDomainRangePushA(domain, "Level 0");
+* nvtxDomainRangePushW(domain, L"Level 1");
+* nvtxDomainRangePop(domain);
+* nvtxDomainRangePop(domain);
+* \endcode
+*
+* \sa
+* ::nvtxRangePushEx
+* ::nvtxRangePushA
+* ::nvtxRangePushW
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC int NVTX_API nvtxDomainRangePop(nvtxDomainHandle_t domain);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Ends a nested thread range.
+ *
+ * \return The level of the range being ended. If an error occurs a negative
+ * value is returned on the current thread.
+ *
+ * \par Example:
+ * \code
+ * nvtxRangePushA("Level 0");
+ * nvtxRangePushW(L"Level 1");
+ * nvtxRangePop();
+ * nvtxRangePop();
+ * \endcode
+ *
+ * \sa
+ * ::nvtxRangePushEx
+ * ::nvtxRangePushA
+ * ::nvtxRangePushW
+ *
+ * \version \NVTX_VERSION_0
+ * @{ */
+NVTX_DECLSPEC int NVTX_API nvtxRangePop(void);
+/** @} */
+
+
+/** @} */ /*END defgroup*/
+/* ========================================================================= */
+/** \defgroup RESOURCE_NAMING Resource Naming
+ *
+ * See \ref RESOURCE_NAMING for more details
+ *
+ * @{
+ */
+
+
+/* ------------------------------------------------------------------------- */
+/** \name Functions for Generic Resource Naming*/
+/* ------------------------------------------------------------------------- */
+
+/* ------------------------------------------------------------------------- */
+/** \cond SHOW_HIDDEN
+* \brief Resource typing helpers.
+*
+* Classes are used to make it easy to create a series of resource types
+* per API without collisions
+*/
+#define NVTX_RESOURCE_MAKE_TYPE(CLASS, INDEX) ((((uint32_t)(NVTX_RESOURCE_CLASS_ ## CLASS))<<16)|((uint32_t)(INDEX)))
+#define NVTX_RESOURCE_CLASS_GENERIC 1
+/** \endcond */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Generic resource type for when a resource class is not available.
+*
+* \sa
+* ::nvtxDomainResourceCreate
+*
+* \version \NVTX_VERSION_2
+*/
+typedef enum nvtxResourceGenericType_t
+{
+ NVTX_RESOURCE_TYPE_UNKNOWN = 0,
+ NVTX_RESOURCE_TYPE_GENERIC_POINTER = NVTX_RESOURCE_MAKE_TYPE(GENERIC, 1), /**< Generic pointer assumed to have no collisions with other pointers. */
+ NVTX_RESOURCE_TYPE_GENERIC_HANDLE = NVTX_RESOURCE_MAKE_TYPE(GENERIC, 2), /**< Generic handle assumed to have no collisions with other handles. */
+ NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE = NVTX_RESOURCE_MAKE_TYPE(GENERIC, 3), /**< OS native thread identifier. */
+ NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX = NVTX_RESOURCE_MAKE_TYPE(GENERIC, 4) /**< POSIX pthread identifier. */
+} nvtxResourceGenericType_t;
+
+
+
+/** \brief Resource Attribute Structure.
+* \anchor RESOURCE_ATTRIBUTE_STRUCTURE
+*
+* This structure is used to describe the attributes of a resource. The layout of
+* the structure is defined by a specific version of the tools extension
+* library and can change between different versions of the Tools Extension
+* library.
+*
+* \par Initializing the Attributes
+*
+* The caller should always perform the following three tasks when using
+* attributes:
+* <ul>
+* <li>Zero the structure
+* <li>Set the version field
+* <li>Set the size field
+* </ul>
+*
+* Zeroing the structure sets all the resource attributes types and values
+* to the default value.
+*
+* The version and size field are used by the Tools Extension
+* implementation to handle multiple versions of the attributes structure.
+*
+* It is recommended that the caller use one of the following to methods
+* to initialize the event attributes structure:
+*
+* \par Method 1: Initializing nvtxEventAttributes for future compatibility
+* \code
+* nvtxResourceAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_RESOURCE_ATTRIB_STRUCT_SIZE;
+* \endcode
+*
+* \par Method 2: Initializing nvtxEventAttributes for a specific version
+* \code
+* nvtxResourceAttributes_v0 attribs = {0};
+* attribs.version = 2;
+* attribs.size = (uint16_t)(sizeof(nvtxResourceAttributes_v0));
+* \endcode
+*
+* If the caller uses Method 1 it is critical that the entire binary
+* layout of the structure be configured to 0 so that all fields
+* are initialized to the default value.
+*
+* The caller should either use both NVTX_VERSION and
+* NVTX_RESOURCE_ATTRIB_STRUCT_SIZE (Method 1) or use explicit values
+* and a versioned type (Method 2). Using a mix of the two methods
+* will likely cause either source level incompatibility or binary
+* incompatibility in the future.
+*
+* \par Settings Attribute Types and Values
+*
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("example domain");
+*
+* // Initialize
+* nvtxResourceAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_RESOURCE_ATTRIB_STRUCT_SIZE;
+*
+* // Configure the Attributes
+* attribs.identifierType = NVTX_RESOURCE_TYPE_GENERIC_POINTER;
+* attribs.identifier.pValue = (const void*)pMutex;
+* attribs.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* attribs.message.ascii = "Single thread access to database.";
+*
+* nvtxResourceHandle_t handle = nvtxDomainResourceCreate(domain, attribs);
+* \endcode
+*
+* \sa
+* ::nvtxDomainResourceCreate
+*/
+typedef struct nvtxResourceAttributes_v0
+{
+ /**
+ * \brief Version flag of the structure.
+ *
+ * Needs to be set to NVTX_VERSION to indicate the version of NVTX APIs
+ * supported in this header file. This can optionally be overridden to
+ * another version of the tools extension library.
+ */
+ uint16_t version;
+
+ /**
+ * \brief Size of the structure.
+ *
+ * Needs to be set to the size in bytes of this attribute
+ * structure.
+ */
+ uint16_t size;
+
+ /**
+ * \brief Identifier type specifies how to interpret the identifier field
+ *
+ * Defines the identifier format of the attribute structure's \ref RESOURCE_IDENTIFIER_FIELD
+ * "identifier" field.
+ *
+ * Default Value is NVTX_RESOURCE_TYPE_UNKNOWN
+ */
+ int32_t identifierType; /* values from enums following the pattern nvtxResource[name]Type_t */
+
+ /**
+ * \brief Identifier for the resource.
+ * \anchor RESOURCE_IDENTIFIER_FIELD
+ *
+ * An identifier may be a pointer or a handle to an OS or middleware API object.
+ * The resource type will assist in avoiding collisions where handles values may collide.
+ */
+ union identifier_t
+ {
+ const void* pValue;
+ uint64_t ullValue;
+ } identifier;
+
+ /** \brief Message type specified in this attribute structure.
+ *
+ * Defines the message format of the attribute structure's \ref RESOURCE_MESSAGE_FIELD
+ * "message" field.
+ *
+ * Default Value is NVTX_MESSAGE_UNKNOWN
+ */
+ int32_t messageType; /* nvtxMessageType_t */
+
+ /** \brief Message assigned to this attribute structure. \anchor RESOURCE_MESSAGE_FIELD
+ *
+ * The text message that is attached to a resource.
+ */
+ nvtxMessageValue_t message;
+
+} nvtxResourceAttributes_v0;
+
+typedef struct nvtxResourceAttributes_v0 nvtxResourceAttributes_t;
+
+/* \cond SHOW_HIDDEN
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_RESOURCE_ATTRIB_STRUCT_SIZE ( (uint16_t)( sizeof(nvtxResourceAttributes_v0) ) )
+typedef struct nvtxResourceHandle* nvtxResourceHandle_t;
+/** \endcond */
+
+
+
+/* ------------------------------------------------------------------------- */
+/** \brief Create a resource object to track and associate data with OS and middleware objects
+*
+* Allows users to associate an API handle or pointer with a user-provided name.
+*
+*
+* \param domain - Domain to own the resource object
+* \param attribs - Attributes to be associated with the resource
+*
+* \return A handle that represents the newly created resource object.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("example domain");
+* nvtxResourceAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_RESOURCE_ATTRIB_STRUCT_SIZE;
+* attribs.identifierType = NVTX_RESOURCE_TYPE_GENERIC_POINTER;
+* attribs.identifier.pValue = (const void*)pMutex;
+* attribs.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* attribs.message.ascii = "Single thread access to database.";
+* nvtxResourceHandle_t handle = nvtxDomainResourceCreate(domain, attribs);
+* \endcode
+*
+* \sa
+* ::nvtxResourceAttributes_t
+* ::nvtxDomainResourceDestroy
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC nvtxResourceHandle_t NVTX_API nvtxDomainResourceCreate(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Destroy a resource object to track and associate data with OS and middleware objects
+*
+* Allows users to associate an API handle or pointer with a user-provided name.
+*
+* \param resource - Handle to the resource in which to operate.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("example domain");
+* nvtxResourceAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_RESOURCE_ATTRIB_STRUCT_SIZE;
+* attribs.identifierType = NVTX_RESOURCE_TYPE_GENERIC_POINTER;
+* attribs.identifier.pValue = (const void*)pMutex;
+* attribs.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* attribs.message.ascii = "Single thread access to database.";
+* nvtxResourceHandle_t handle = nvtxDomainResourceCreate(domain, attribs);
+* nvtxDomainResourceDestroy(handle);
+* \endcode
+*
+* \sa
+* ::nvtxDomainResourceCreate
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxDomainResourceDestroy(nvtxResourceHandle_t resource);
+/** @} */
+
+
+/** \name Functions for NVTX Category Naming*/
+
+/* ------------------------------------------------------------------------- */
+/**
+* \brief Annotate an NVTX category used within a domain.
+*
+* Categories are used to group sets of events. Each category is identified
+* through a unique ID and that ID is passed into any of the marker/range
+* events to assign that event to a specific category. The nvtxDomainNameCategory
+* function calls allow the user to assign a name to a category ID that is
+* specific to the domain.
+*
+* nvtxDomainNameCategory(NULL, category, name) is equivalent to calling
+* nvtxNameCategory(category, name).
+*
+* \param domain - The domain of scoping the category.
+* \param category - The category ID to name.
+* \param name - The name of the category.
+*
+* \remarks The category names are tracked per domain.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("example");
+* nvtxDomainNameCategoryA(domain, 1, "Memory Allocation");
+* nvtxDomainNameCategoryW(domain, 2, L"Memory Transfer");
+* \endcode
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryA(nvtxDomainHandle_t domain, uint32_t category, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryW(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name);
+/** @} */
+
+/** \brief Annotate an NVTX category.
+ *
+ * Categories are used to group sets of events. Each category is identified
+ * through a unique ID and that ID is passed into any of the marker/range
+ * events to assign that event to a specific category. The nvtxNameCategory
+ * function calls allow the user to assign a name to a category ID.
+ *
+ * \param category - The category ID to name.
+ * \param name - The name of the category.
+ *
+ * \remarks The category names are tracked per process.
+ *
+ * \par Example:
+ * \code
+ * nvtxNameCategory(1, "Memory Allocation");
+ * nvtxNameCategory(2, "Memory Transfer");
+ * nvtxNameCategory(3, "Memory Object Lifetime");
+ * \endcode
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCategoryA(uint32_t category, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCategoryW(uint32_t category, const wchar_t* name);
+/** @} */
+
+/** \name Functions for OS Threads Naming*/
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotate an OS thread.
+ *
+ * Allows the user to name an active thread of the current process. If an
+ * invalid thread ID is provided or a thread ID from a different process is
+ * used the behavior of the tool is implementation dependent.
+ *
+ * Tools expect thread ID to be a number that uniquely identifies the thread
+ * at the time of the call. Note that a thread's ID can be reused after
+ * it is destroyed. Tools may choose how to handle aliasing of thread IDs.
+ *
+ * POSIX pthread_t type returned by pthread_self() may not comply with these
+ * expectations. Please use OS-specific thread ID instead of pthread_t.
+ *
+ * The thread name is associated to the default domain. To support domains
+ * use resource objects via ::nvtxDomainResourceCreate.
+ *
+ * \param threadId - The ID of the thread to name.
+ * \param name - The name of the thread.
+ *
+ * \par Examples:
+ * MS Windows:
+ * \code
+ * #include <windows.h>
+ * nvtxNameOsThread(GetCurrentThreadId(), "Current thread");
+ * nvtxNameOsThread(GetThreadId(SomeThreadHandle), "Other thread");
+ * \endcode
+ *
+ * Android:
+ * \code
+ * #include <unistd.h>
+ * nvtxNameOsThreadA(gettid(), "Current thread");
+ * nvtxNameOsThreadA(getpid(), "Main thread");
+ * \endcode
+ *
+ * Linux:
+ * \code
+ * #include <sys/syscall.h>
+ * nvtxNameOsThreadA(syscall(SYS_gettid), "Current thread");
+ * \endcode
+ * \code
+ * #include <unistd.h>
+ * nvtxNameOsThreadA(getpid(), "Main thread");
+ * \endcode
+ *
+ * OS X:
+ * \code
+ * #include <sys/syscall.h>
+ * nvtxNameOsThreadA(syscall(SYS_thread_selfid), "Current thread");
+ * \endcode
+ * \code
+ * #include <pthread.h>
+ * __uint64_t id;
+ * pthread_threadid_np(pthread_self(), &id);
+ * nvtxNameOsThreadA(id, "Current thread");
+ * pthread_threadid_np(somePThreadId, &id);
+ * nvtxNameOsThreadA(id, "Other thread");
+ * \endcode
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadA(uint32_t threadId, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadW(uint32_t threadId, const wchar_t* name);
+/** @} */
+
+
+/** @} */ /*END defgroup*/
+/* ========================================================================= */
+/** \defgroup STRING_REGISTRATION String Registration
+*
+* Registered strings are intended to increase performance by lowering instrumentation
+* overhead. String may be registered once and the handle may be passed in place of
+* a string where an the APIs may allow.
+*
+* See \ref STRING_REGISTRATION for more details
+*
+* @{
+*/
+
+/* ------------------------------------------------------------------------- */
+/** \brief Register a string.
+
+* Registers an immutable string with NVTX. Once registered the pointer used
+* to register the domain name can be used in nvtxEventAttributes_t
+* \ref MESSAGE_FIELD. This allows NVTX implementation to skip copying the
+* contents of the message on each event invocation.
+*
+* String registration is an optimization. It is recommended to use string
+* registration if the string will be passed to an event many times.
+*
+* String are not unregistered, except that by unregistering the entire domain
+*
+* \param domain - Domain handle. If NULL then the global domain is used.
+* \param string - A unique pointer to a sequence of characters.
+*
+* \return A handle representing the registered string.
+*
+* \par Example:
+* \code
+* nvtxDomainCreateA("com.nvidia.nvtx.example");
+* nvtxStringHandle_t message = nvtxDomainRegisterStringA(domain, "registered string");
+* nvtxEventAttributes_t eventAttrib = {0};
+* eventAttrib.version = NVTX_VERSION;
+* eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib.messageType = NVTX_MESSAGE_TYPE_REGISTERED;
+* eventAttrib.message.registered = message;
+* \endcode
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringA(nvtxDomainHandle_t domain, const char* string);
+NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringW(nvtxDomainHandle_t domain, const wchar_t* string);
+/** @} */
+
+/** @} */ /*END defgroup*/
+/* ========================================================================= */
+/** \defgroup DOMAINS Domains
+*
+* Domains are used to group events to a developer defined scope. Middleware
+* vendors may also scope their own events to avoid collisions with the
+* the application developer's events, so that the application developer may
+* inspect both parts and easily differentiate or filter them. By default
+* all events are scoped to a global domain where NULL is provided or when
+* using APIs provided b versions of NVTX below v2
+*
+* Domains are intended to be typically long lived objects with the intention
+* of logically separating events of large modules from each other such as
+* middleware libraries from each other and the main application.
+*
+* See \ref DOMAINS for more details
+*
+* @{
+*/
+
+/* ------------------------------------------------------------------------- */
+/** \brief Register a NVTX domain.
+*
+* Domains are used to scope annotations. All NVTX_VERSION_0 and NVTX_VERSION_1
+* annotations are scoped to the global domain. The function nvtxDomainCreate
+* creates a new named domain.
+*
+* Each domain maintains its own nvtxRangePush and nvtxRangePop stack.
+*
+* \param name - A unique string representing the domain.
+*
+* \return A handle representing the domain.
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("com.nvidia.nvtx.example");
+*
+* nvtxMarkA("nvtxMarkA to global domain");
+*
+* nvtxEventAttributes_t eventAttrib1 = {0};
+* eventAttrib1.version = NVTX_VERSION;
+* eventAttrib1.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib1.message.ascii = "nvtxDomainMarkEx to global domain";
+* nvtxDomainMarkEx(NULL, &eventAttrib1);
+*
+* nvtxEventAttributes_t eventAttrib2 = {0};
+* eventAttrib2.version = NVTX_VERSION;
+* eventAttrib2.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
+* eventAttrib2.message.ascii = "nvtxDomainMarkEx to com.nvidia.nvtx.example";
+* nvtxDomainMarkEx(domain, &eventAttrib2);
+* nvtxDomainDestroy(domain);
+* \endcode
+*
+* \sa
+* ::nvtxDomainDestroy
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateA(const char* name);
+NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateW(const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Unregister a NVTX domain.
+*
+* Unregisters the domain handle and frees all domain specific resources.
+*
+* \param domain - the domain handle
+*
+* \par Example:
+* \code
+* nvtxDomainHandle_t domain = nvtxDomainCreateA("com.nvidia.nvtx.example");
+* nvtxDomainDestroy(domain);
+* \endcode
+*
+* \sa
+* ::nvtxDomainCreateA
+* ::nvtxDomainCreateW
+*
+* \version \NVTX_VERSION_2
+* @{ */
+NVTX_DECLSPEC void NVTX_API nvtxDomainDestroy(nvtxDomainHandle_t domain);
+/** @} */
+
+
+/** @} */ /*END defgroup*/
+/* ========================================================================= */
+/** \cond SHOW_HIDDEN */
+
+#ifdef UNICODE
+ #define nvtxMark nvtxMarkW
+ #define nvtxRangeStart nvtxRangeStartW
+ #define nvtxRangePush nvtxRangePushW
+ #define nvtxNameCategory nvtxNameCategoryW
+ #define nvtxNameOsThread nvtxNameOsThreadW
+ /* NVTX_VERSION_2 */
+ #define nvtxDomainCreate nvtxDomainCreateW
+ #define nvtxDomainRegisterString nvtxDomainRegisterStringW
+ #define nvtxDomainNameCategory nvtxDomainNameCategoryW
+#else
+ #define nvtxMark nvtxMarkA
+ #define nvtxRangeStart nvtxRangeStartA
+ #define nvtxRangePush nvtxRangePushA
+ #define nvtxNameCategory nvtxNameCategoryA
+ #define nvtxNameOsThread nvtxNameOsThreadA
+ /* NVTX_VERSION_2 */
+ #define nvtxDomainCreate nvtxDomainCreateA
+ #define nvtxDomainRegisterString nvtxDomainRegisterStringA
+ #define nvtxDomainNameCategory nvtxDomainNameCategoryA
+#endif
+
+/** \endcond */
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
+
+#define NVTX_IMPL_GUARD /* Ensure other headers cannot included directly */
+
+#include "nvtxDetail/nvtxTypes.h"
+
+#ifndef NVTX_NO_IMPL
+#include "nvtxDetail/nvtxImpl.h"
+#endif /*NVTX_NO_IMPL*/
+
+#undef NVTX_IMPL_GUARD
+
+#endif /* !defined(NVTX_VERSION) */
diff --git a/src/include/nvtx3/nvToolsExtCuda.h b/src/include/nvtx3/nvToolsExtCuda.h
new file mode 100644
index 0000000..b1e654c
--- /dev/null
+++ b/src/include/nvtx3/nvToolsExtCuda.h
@@ -0,0 +1,141 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#include "nvToolsExt.h"
+
+#include "cuda.h"
+
+#ifndef NVTOOLSEXT_CUDA_V3
+#define NVTOOLSEXT_CUDA_V3
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* ========================================================================= */
+/** \name Functions for CUDA Resource Naming
+*/
+/** \addtogroup RESOURCE_NAMING
+ * \section RESOURCE_NAMING_CUDA CUDA Resource Naming
+ *
+ * This section covers the API functions that allow to annotate CUDA resources
+ * with user-provided names.
+ *
+ * @{
+ */
+
+/* ------------------------------------------------------------------------- */
+/* \cond SHOW_HIDDEN
+* \brief Used to build a non-colliding value for resource types separated class
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_RESOURCE_CLASS_CUDA 4
+/** \endcond */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Resource types for CUDA
+*/
+typedef enum nvtxResourceCUDAType_t
+{
+ NVTX_RESOURCE_TYPE_CUDA_DEVICE = NVTX_RESOURCE_MAKE_TYPE(CUDA, 1), /* CUdevice */
+ NVTX_RESOURCE_TYPE_CUDA_CONTEXT = NVTX_RESOURCE_MAKE_TYPE(CUDA, 2), /* CUcontext */
+ NVTX_RESOURCE_TYPE_CUDA_STREAM = NVTX_RESOURCE_MAKE_TYPE(CUDA, 3), /* CUstream */
+ NVTX_RESOURCE_TYPE_CUDA_EVENT = NVTX_RESOURCE_MAKE_TYPE(CUDA, 4), /* CUevent */
+} nvtxResourceCUDAType_t;
+
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA device.
+ *
+ * Allows the user to associate a CUDA device with a user-provided name.
+ *
+ * \param device - The handle of the CUDA device to name.
+ * \param name - The name of the CUDA device.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceA(CUdevice device, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceW(CUdevice device, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA context.
+ *
+ * Allows the user to associate a CUDA context with a user-provided name.
+ *
+ * \param context - The handle of the CUDA context to name.
+ * \param name - The name of the CUDA context.
+ *
+ * \par Example:
+ * \code
+ * CUresult status = cuCtxCreate( &cuContext, 0, cuDevice );
+ * if ( CUDA_SUCCESS != status )
+ * goto Error;
+ * nvtxNameCuContext(cuContext, "CTX_NAME");
+ * \endcode
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCuContextA(CUcontext context, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCuContextW(CUcontext context, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA stream.
+ *
+ * Allows the user to associate a CUDA stream with a user-provided name.
+ *
+ * \param stream - The handle of the CUDA stream to name.
+ * \param name - The name of the CUDA stream.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamA(CUstream stream, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamW(CUstream stream, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA event.
+ *
+ * Allows the user to associate a CUDA event with a user-provided name.
+ *
+ * \param event - The handle of the CUDA event to name.
+ * \param name - The name of the CUDA event.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCuEventA(CUevent event, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCuEventW(CUevent event, const wchar_t* name);
+/** @} */
+
+/** @} */ /* END RESOURCE_NAMING */
+
+/* ========================================================================= */
+#ifdef UNICODE
+ #define nvtxNameCuDevice nvtxNameCuDeviceW
+ #define nvtxNameCuContext nvtxNameCuContextW
+ #define nvtxNameCuStream nvtxNameCuStreamW
+ #define nvtxNameCuEvent nvtxNameCuEventW
+#else
+ #define nvtxNameCuDevice nvtxNameCuDeviceA
+ #define nvtxNameCuContext nvtxNameCuContextA
+ #define nvtxNameCuStream nvtxNameCuStreamA
+ #define nvtxNameCuEvent nvtxNameCuEventA
+#endif
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#ifndef NVTX_NO_IMPL
+#define NVTX_IMPL_GUARD_CUDA /* Ensure other headers cannot included directly */
+#include "nvtxDetail/nvtxImplCuda_v3.h"
+#undef NVTX_IMPL_GUARD_CUDA
+#endif /*NVTX_NO_IMPL*/
+
+#endif /* NVTOOLSEXT_CUDA_V3 */
diff --git a/src/include/nvtx3/nvToolsExtCudaRt.h b/src/include/nvtx3/nvToolsExtCudaRt.h
new file mode 100644
index 0000000..002f6e9
--- /dev/null
+++ b/src/include/nvtx3/nvToolsExtCudaRt.h
@@ -0,0 +1,117 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#include "nvToolsExt.h"
+
+#include "cuda.h"
+#include "driver_types.h"
+
+#ifndef NVTOOLSEXT_CUDART_V3
+#define NVTOOLSEXT_CUDART_V3
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* ========================================================================= */
+/** \name Functions for CUDA Resource Naming
+*/
+/** \addtogroup RESOURCE_NAMING
+ * \section RESOURCE_NAMING_CUDART CUDA Runtime Resource Naming
+ *
+ * This section covers the API functions that allow to annotate CUDA resources
+ * with user-provided names.
+ *
+ * @{
+ */
+
+/* ------------------------------------------------------------------------- */
+/* \cond SHOW_HIDDEN
+* \brief Used to build a non-colliding value for resource types separated class
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_RESOURCE_CLASS_CUDART 5
+/** \endcond */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Resource types for CUDART
+*/
+typedef enum nvtxResourceCUDARTType_t
+{
+ NVTX_RESOURCE_TYPE_CUDART_DEVICE = NVTX_RESOURCE_MAKE_TYPE(CUDART, 0), /* int device */
+ NVTX_RESOURCE_TYPE_CUDART_STREAM = NVTX_RESOURCE_MAKE_TYPE(CUDART, 1), /* cudaStream_t */
+ NVTX_RESOURCE_TYPE_CUDART_EVENT = NVTX_RESOURCE_MAKE_TYPE(CUDART, 2), /* cudaEvent_t */
+} nvtxResourceCUDARTType_t;
+
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA device.
+ *
+ * Allows the user to associate a CUDA device with a user-provided name.
+ *
+ * \param device - The id of the CUDA device to name.
+ * \param name - The name of the CUDA device.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceA(int device, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceW(int device, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA stream.
+ *
+ * Allows the user to associate a CUDA stream with a user-provided name.
+ *
+ * \param stream - The handle of the CUDA stream to name.
+ * \param name - The name of the CUDA stream.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamA(cudaStream_t stream, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamW(cudaStream_t stream, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates a CUDA event.
+ *
+ * Allows the user to associate a CUDA event with a user-provided name.
+ *
+ * \param event - The handle of the CUDA event to name.
+ * \param name - The name of the CUDA event.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventA(cudaEvent_t event, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventW(cudaEvent_t event, const wchar_t* name);
+/** @} */
+
+/** @} */ /* END RESOURCE_NAMING */
+
+/* ========================================================================= */
+#ifdef UNICODE
+ #define nvtxNameCudaDevice nvtxNameCudaDeviceW
+ #define nvtxNameCudaStream nvtxNameCudaStreamW
+ #define nvtxNameCudaEvent nvtxNameCudaEventW
+#else
+ #define nvtxNameCudaDevice nvtxNameCudaDeviceA
+ #define nvtxNameCudaStream nvtxNameCudaStreamA
+ #define nvtxNameCudaEvent nvtxNameCudaEventA
+#endif
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#ifndef NVTX_NO_IMPL
+#define NVTX_IMPL_GUARD_CUDART /* Ensure other headers cannot included directly */
+#include "nvtxDetail/nvtxImplCudaRt_v3.h"
+#undef NVTX_IMPL_GUARD_CUDART
+#endif /*NVTX_NO_IMPL*/
+
+#endif /* NVTOOLSEXT_CUDART_V3 */
diff --git a/src/include/nvtx3/nvToolsExtOpenCL.h b/src/include/nvtx3/nvToolsExtOpenCL.h
new file mode 100644
index 0000000..611c0cb
--- /dev/null
+++ b/src/include/nvtx3/nvToolsExtOpenCL.h
@@ -0,0 +1,191 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#include "nvToolsExt.h"
+
+#include <CL/cl.h>
+
+#ifndef NVTOOLSEXT_OPENCL_V3
+#define NVTOOLSEXT_OPENCL_V3
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* ========================================================================= */
+/** \name Functions for OpenCL Resource Naming
+ */
+/** \addtogroup RESOURCE_NAMING
+ * \section RESOURCE_NAMING_OPENCL OpenCL Resource Naming
+ *
+ * This section covers the API functions that allow to annotate OpenCL resources
+ * with user-provided names.
+ *
+ * @{
+ */
+
+/* ------------------------------------------------------------------------- */
+/* \cond SHOW_HIDDEN
+* \brief Used to build a non-colliding value for resource types separated class
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_RESOURCE_CLASS_OPENCL 6
+/** \endcond */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Resource types for OpenCL
+*/
+typedef enum nvtxResourceOpenCLType_t
+{
+ NVTX_RESOURCE_TYPE_OPENCL_DEVICE = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 1),
+ NVTX_RESOURCE_TYPE_OPENCL_CONTEXT = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 2),
+ NVTX_RESOURCE_TYPE_OPENCL_COMMANDQUEUE = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 3),
+ NVTX_RESOURCE_TYPE_OPENCL_MEMOBJECT = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 4),
+ NVTX_RESOURCE_TYPE_OPENCL_SAMPLER = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 5),
+ NVTX_RESOURCE_TYPE_OPENCL_PROGRAM = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 6),
+ NVTX_RESOURCE_TYPE_OPENCL_EVENT = NVTX_RESOURCE_MAKE_TYPE(OPENCL, 7),
+} nvtxResourceOpenCLType_t;
+
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL device.
+ *
+ * Allows to associate an OpenCL device with a user-provided name.
+ *
+ * \param device - The handle of the OpenCL device to name.
+ * \param name - The name of the OpenCL device.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceA(cl_device_id device, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceW(cl_device_id device, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL context.
+ *
+ * Allows to associate an OpenCL context with a user-provided name.
+ *
+ * \param context - The handle of the OpenCL context to name.
+ * \param name - The name of the OpenCL context.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClContextA(cl_context context, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClContextW(cl_context context, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL command queue.
+ *
+ * Allows to associate an OpenCL command queue with a user-provided name.
+ *
+ * \param command_queue - The handle of the OpenCL command queue to name.
+ * \param name - The name of the OpenCL command queue.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueA(cl_command_queue command_queue, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueW(cl_command_queue command_queue, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL memory object.
+ *
+ * Allows to associate an OpenCL memory object with a user-provided name.
+ *
+ * \param memobj - The handle of the OpenCL memory object to name.
+ * \param name - The name of the OpenCL memory object.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectA(cl_mem memobj, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectW(cl_mem memobj, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL sampler.
+ *
+ * Allows to associate an OpenCL sampler with a user-provided name.
+ *
+ * \param sampler - The handle of the OpenCL sampler to name.
+ * \param name - The name of the OpenCL sampler.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerA(cl_sampler sampler, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerW(cl_sampler sampler, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL program.
+ *
+ * Allows to associate an OpenCL program with a user-provided name.
+ *
+ * \param program - The handle of the OpenCL program to name.
+ * \param name - The name of the OpenCL program.
+ *
+ * \code
+ * cpProgram = clCreateProgramWithSource(cxGPUContext, 1,
+ * (const char **) &cSourceCL, &program_length, &ciErrNum);
+ * shrCheckErrorEX(ciErrNum, CL_SUCCESS, pCleanup);
+ * nvtxNameClProgram(cpProgram, L"PROGRAM_NAME");
+ * \endcode
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClProgramA(cl_program program, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClProgramW(cl_program program, const wchar_t* name);
+/** @} */
+
+/* ------------------------------------------------------------------------- */
+/** \brief Annotates an OpenCL event.
+ *
+ * Allows to associate an OpenCL event with a user-provided name.
+ *
+ * \param evnt - The handle of the OpenCL event to name.
+ * \param name - The name of the OpenCL event.
+ *
+ * \version \NVTX_VERSION_1
+ * @{ */
+NVTX_DECLSPEC void NVTX_API nvtxNameClEventA(cl_event evnt, const char* name);
+NVTX_DECLSPEC void NVTX_API nvtxNameClEventW(cl_event evnt, const wchar_t* name);
+/** @} */
+
+/** @} */ /* END RESOURCE_NAMING */
+
+/* ========================================================================= */
+#ifdef UNICODE
+ #define nvtxNameClDevice nvtxNameClDeviceW
+ #define nvtxNameClContext nvtxNameClContextW
+ #define nvtxNameClCommandQueue nvtxNameClCommandQueueW
+ #define nvtxNameClMemObject nvtxNameClMemObjectW
+ #define nvtxNameClSampler nvtxNameClSamplerW
+ #define nvtxNameClProgram nvtxNameClProgramW
+ #define nvtxNameClEvent nvtxNameClEventW
+#else
+ #define nvtxNameClDevice nvtxNameClDeviceA
+ #define nvtxNameClContext nvtxNameClContextA
+ #define nvtxNameClCommandQueue nvtxNameClCommandQueueA
+ #define nvtxNameClMemObject nvtxNameClMemObjectA
+ #define nvtxNameClSampler nvtxNameClSamplerA
+ #define nvtxNameClProgram nvtxNameClProgramA
+ #define nvtxNameClEvent nvtxNameClEventA
+#endif
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#ifndef NVTX_NO_IMPL
+#define NVTX_IMPL_GUARD_OPENCL /* Ensure other headers cannot included directly */
+#include "nvtxDetail/nvtxImplOpenCL_v3.h"
+#undef NVTX_IMPL_GUARD_OPENCL
+#endif /*NVTX_NO_IMPL*/
+
+#endif /* NVTOOLSEXT_OPENCL_V3 */
diff --git a/src/include/nvtx3/nvToolsExtSync.h b/src/include/nvtx3/nvToolsExtSync.h
new file mode 100644
index 0000000..5d24729
--- /dev/null
+++ b/src/include/nvtx3/nvToolsExtSync.h
@@ -0,0 +1,382 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#include "nvToolsExt.h"
+
+#ifndef NVTOOLSEXT_SYNC_V3
+#define NVTOOLSEXT_SYNC_V3
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* \cond SHOW_HIDDEN
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_SYNCUSER_ATTRIB_STRUCT_SIZE ( (uint16_t)( sizeof(nvtxSyncUserAttributes_v0) ) )
+/** \endcond */
+
+
+/**
+* \page PAGE_SYNCHRONIZATION Synchronization
+*
+* This section covers a subset of the API that allow users to track additional
+* synchronization details of their application. Naming OS synchronization primitives
+* may allow users to better understand the data collected by traced synchronization
+* APIs. Additionally, a user defined synchronization object can allow the users to
+* to tell the tools when the user is building their own synchronization system
+* that do not rely on the OS to provide behaviors and instead use techniques like
+* atomic operations and spinlocks.
+*
+* See module \ref SYNCHRONIZATION for details.
+*
+* \par Example:
+* \code
+* class MyMutex
+* {
+* volatile long bLocked;
+* nvtxSyncUser_t hSync;
+* public:
+* MyMutex(const char* name, nvtxDomainHandle_t d){
+* bLocked = 0;
+*
+* nvtxSyncUserAttributes_t attribs = { 0 };
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_SYNCUSER_ATTRIB_STRUCT_SIZE;
+* attribs.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* attribs.message.ascii = name;
+* hSync = nvtxDomainSyncUserCreate(d, &attribs);
+* }
+*
+* ~MyMutex() {
+* nvtxDomainSyncUserDestroy(hSync);
+* }
+*
+* bool Lock() {
+* nvtxDomainSyncUserAcquireStart(hSync);
+* bool acquired = __sync_bool_compare_and_swap(&bLocked, 0, 1);//atomic compiler intrinsic
+
+* if (acquired) {
+* nvtxDomainSyncUserAcquireSuccess(hSync);
+* }
+* else {
+* nvtxDomainSyncUserAcquireFailed(hSync);
+* }
+* return acquired;
+* }
+
+* void Unlock() {
+* nvtxDomainSyncUserReleasing(hSync);
+* bLocked = false;
+* }
+* };
+* \endcode
+*
+* \version \NVTX_VERSION_2
+*/
+
+/* ------------------------------------------------------------------------- */
+/* \cond SHOW_HIDDEN
+* \brief Used to build a non-colliding value for resource types separated class
+* \version \NVTX_VERSION_2
+*/
+#define NVTX_RESOURCE_CLASS_SYNC_OS 2 /**< Synchronization objects that are OS specific. */
+#define NVTX_RESOURCE_CLASS_SYNC_PTHREAD 3 /**< Synchronization objects that are from the POSIX Threads API (pthread)*/
+/** \endcond */
+
+
+/* ------------------------------------------------------------------------- */
+/** \defgroup SYNCHRONIZATION Synchronization
+* See page \ref PAGE_SYNCHRONIZATION.
+* @{
+*/
+
+/** \brief Resource type values for OSs with POSIX Thread API support
+ */
+typedef enum nvtxResourceSyncPosixThreadType_t
+{
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_MUTEX = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 1), /* pthread_mutex_t */
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_CONDITION = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 2), /* pthread_cond_t */
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_RWLOCK = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 3), /* pthread_rwlock_t */
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_BARRIER = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 4), /* pthread_barrier_t */
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_SPINLOCK = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 5), /* pthread_spinlock_t */
+ NVTX_RESOURCE_TYPE_SYNC_PTHREAD_ONCE = NVTX_RESOURCE_MAKE_TYPE(SYNC_PTHREAD, 6) /* pthread_once_t */
+} nvtxResourceSyncPosixThreadType_t;
+
+/** \brief Resource type values for Windows OSs
+*/
+typedef enum nvtxResourceSyncWindowsType_t
+{
+ NVTX_RESOURCE_TYPE_SYNC_WINDOWS_MUTEX = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 1),
+ NVTX_RESOURCE_TYPE_SYNC_WINDOWS_SEMAPHORE = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 2),
+ NVTX_RESOURCE_TYPE_SYNC_WINDOWS_EVENT = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 3),
+ NVTX_RESOURCE_TYPE_SYNC_WINDOWS_CRITICAL_SECTION = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 4),
+ NVTX_RESOURCE_TYPE_SYNC_WINDOWS_SRWLOCK = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 5)
+} nvtxResourceSyncWindowsType_t;
+
+/** \brief Resource type values for Linux and Linux derived OSs such as Android
+* \sa
+* ::nvtxResourceSyncPosixThreadType_t
+*/
+typedef enum nvtxResourceSyncLinuxType_t
+{
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_MUTEX = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 1),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_FUTEX = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 2),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_SEMAPHORE = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 3),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_COMPLETION = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 4),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_SPINLOCK = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 5),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_SEQLOCK = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 6),
+ NVTX_RESOURCE_TYPE_SYNC_LINUX_RCU = NVTX_RESOURCE_MAKE_TYPE(SYNC_OS, 7)
+} nvtxResourceSyncLinuxType_t;
+
+/** \brief Resource type values for Android come from Linux.
+* \sa
+* ::nvtxResourceSyncLinuxType_t
+* ::nvtxResourceSyncPosixThreadType_t
+*/
+typedef enum nvtxResourceSyncLinuxType_t nvtxResourceSyncAndroidType_t;
+
+/** \brief User Defined Synchronization Object Handle .
+* \anchor SYNCUSER_HANDLE_STRUCTURE
+*
+* This structure is opaque to the user and is used as a handle to reference
+* a user defined syncrhonization object. The tools will return a pointer through the API for the application
+* to hold on it's behalf to reference the string in the future.
+*
+*/
+typedef struct nvtxSyncUser* nvtxSyncUser_t;
+
+/** \brief User Defined Synchronization Object Attributes Structure.
+* \anchor USERDEF_SYNC_ATTRIBUTES_STRUCTURE
+*
+* This structure is used to describe the attributes of a user defined synchronization
+* object. The layout of the structure is defined by a specific version of the tools
+* extension library and can change between different versions of the Tools Extension
+* library.
+*
+* \par Initializing the Attributes
+*
+* The caller should always perform the following three tasks when using
+* attributes:
+* <ul>
+* <li>Zero the structure
+* <li>Set the version field
+* <li>Set the size field
+* </ul>
+*
+* Zeroing the structure sets all the event attributes types and values
+* to the default value.
+*
+* The version and size field are used by the Tools Extension
+* implementation to handle multiple versions of the attributes structure.
+*
+* It is recommended that the caller use one of the following to methods
+* to initialize the event attributes structure:
+*
+* \par Method 1: Initializing nvtxEventAttributes for future compatibility
+* \code
+* nvtxSyncUserAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_SYNCUSER_ATTRIB_STRUCT_SIZE;
+* \endcode
+*
+* \par Method 2: Initializing nvtxSyncUserAttributes_t for a specific version
+* \code
+* nvtxSyncUserAttributes_t attribs = {0};
+* attribs.version = 1;
+* attribs.size = (uint16_t)(sizeof(nvtxSyncUserAttributes_t));
+* \endcode
+*
+* If the caller uses Method 1 it is critical that the entire binary
+* layout of the structure be configured to 0 so that all fields
+* are initialized to the default value.
+*
+* The caller should either use both NVTX_VERSION and
+* NVTX_SYNCUSER_ATTRIB_STRUCT_SIZE (Method 1) or use explicit values
+* and a versioned type (Method 2). Using a mix of the two methods
+* will likely cause either source level incompatibility or binary
+* incompatibility in the future.
+*
+* \par Settings Attribute Types and Values
+*
+*
+* \par Example:
+* \code
+* // Initialize
+* nvtxSyncUserAttributes_t attribs = {0};
+* attribs.version = NVTX_VERSION;
+* attribs.size = NVTX_SYNCUSER_ATTRIB_STRUCT_SIZE;
+*
+* // Configure the Attributes
+* attribs.messageType = NVTX_MESSAGE_TYPE_ASCII;
+* attribs.message.ascii = "Example";
+* \endcode
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+*/
+typedef struct nvtxSyncUserAttributes_v0
+{
+ /**
+ * \brief Version flag of the structure.
+ *
+ * Needs to be set to NVTX_VERSION to indicate the version of NVTX APIs
+ * supported in this header file. This can optionally be overridden to
+ * another version of the tools extension library.
+ */
+ uint16_t version;
+
+ /**
+ * \brief Size of the structure.
+ *
+ * Needs to be set to the size in bytes of the event attribute
+ * structure used to specify the event.
+ */
+ uint16_t size;
+
+ /** \brief Message type specified in this attribute structure.
+ *
+ * Defines the message format of the attribute structure's \ref nvtxSyncUserAttributes_v0::message
+ * "message" field.
+ *
+ * Default Value is NVTX_MESSAGE_UNKNOWN
+ */
+ int32_t messageType; /* nvtxMessageType_t */
+
+ /** \brief Message assigned to this attribute structure.
+ *
+ * The text message that is attached to an event.
+ */
+ nvtxMessageValue_t message;
+
+} nvtxSyncUserAttributes_v0;
+
+typedef struct nvtxSyncUserAttributes_v0 nvtxSyncUserAttributes_t;
+
+/* ------------------------------------------------------------------------- */
+/** \brief Create a user defined synchronization object
+* This is used to track non-OS synchronization working with spinlocks and atomics
+*
+* \param domain - Domain to own the resource
+* \param attribs - A structure to assign multiple attributes to the object.
+*
+* \return A handle that represents the newly created user defined synchronization object.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/
+NVTX_DECLSPEC nvtxSyncUser_t NVTX_API nvtxDomainSyncUserCreate(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
+
+/* ------------------------------------------------------------------------- */
+/** \brief Destroy a user defined synchronization object
+* This is used to track non-OS synchronization working with spinlocks and atomics
+*
+* \param handle - A handle to the object to operate on.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserDestroy(nvtxSyncUser_t handle);
+
+/* ------------------------------------------------------------------------- */
+/** \brief Signal to tools that an attempt to acquire a user defined synchronization object
+*
+* \param handle - A handle to the object to operate on.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireStart(nvtxSyncUser_t handle);
+
+/* ------------------------------------------------------------------------- */
+/** \brief Signal to tools of failure in acquiring a user defined synchronization object
+* This should be called after \ref nvtxDomainSyncUserAcquireStart
+*
+* \param handle - A handle to the object to operate on.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireFailed(nvtxSyncUser_t handle);
+
+/* ------------------------------------------------------------------------- */
+/** \brief Signal to tools of success in acquiring a user defined synchronization object
+* This should be called after \ref nvtxDomainSyncUserAcquireStart.
+*
+* \param handle - A handle to the object to operate on.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireSuccess(nvtxSyncUser_t handle);
+
+/* ------------------------------------------------------------------------- */
+/** \brief Signal to tools of releasing a reservation on user defined synchronization object
+* This should be called after \ref nvtxDomainSyncUserAcquireSuccess.
+*
+* \param handle - A handle to the object to operate on.
+*
+* \sa
+* ::nvtxDomainSyncUserCreate
+* ::nvtxDomainSyncUserDestroy
+* ::nvtxDomainSyncUserAcquireStart
+* ::nvtxDomainSyncUserAcquireFailed
+* ::nvtxDomainSyncUserAcquireSuccess
+* ::nvtxDomainSyncUserReleasing
+*
+* \version \NVTX_VERSION_2
+*/
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserReleasing(nvtxSyncUser_t handle);
+
+
+/** @} */ /*END defgroup*/
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#ifndef NVTX_NO_IMPL
+#define NVTX_IMPL_GUARD_SYNC /* Ensure other headers cannot included directly */
+#include "nvtxDetail/nvtxImplSync_v3.h"
+#undef NVTX_IMPL_GUARD_SYNC
+#endif /*NVTX_NO_IMPL*/
+
+#endif /* NVTOOLSEXT_SYNC_V3 */
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImpl.h b/src/include/nvtx3/nvtxDetail/nvtxImpl.h
new file mode 100644
index 0000000..be27f43
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImpl.h
@@ -0,0 +1,438 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD
+#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+/* ---- Include required platform headers ---- */
+
+#if defined(_WIN32)
+
+#include <Windows.h>
+
+#else
+#include <unistd.h>
+
+#if defined(__ANDROID__)
+#include <android/api-level.h>
+#endif
+
+#if defined(__linux__) || defined(__CYGWIN__)
+#include <sched.h>
+#endif
+
+#include <limits.h>
+#include <dlfcn.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <errno.h>
+
+#include <string.h>
+#include <sys/types.h>
+#include <pthread.h>
+#include <stdlib.h>
+#include <wchar.h>
+
+#endif
+
+/* ---- Define macros used in this file ---- */
+
+#define NVTX_INIT_STATE_FRESH 0
+#define NVTX_INIT_STATE_STARTED 1
+#define NVTX_INIT_STATE_COMPLETE 2
+
+#ifdef NVTX_DEBUG_PRINT
+#ifdef __ANDROID__
+#include <android/log.h>
+#define NVTX_ERR(...) __android_log_print(ANDROID_LOG_ERROR, "NVTOOLSEXT", __VA_ARGS__);
+#define NVTX_INFO(...) __android_log_print(ANDROID_LOG_INFO, "NVTOOLSEXT", __VA_ARGS__);
+#else
+#include <stdio.h>
+#define NVTX_ERR(...) fprintf(stderr, "NVTX_ERROR: " __VA_ARGS__)
+#define NVTX_INFO(...) fprintf(stderr, "NVTX_INFO: " __VA_ARGS__)
+#endif
+#else /* !defined(NVTX_DEBUG_PRINT) */
+#define NVTX_ERR(...)
+#define NVTX_INFO(...)
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+#ifdef __GNUC__
+#pragma GCC visibility push(hidden)
+#endif
+
+/* ---- Forward declare all functions referenced in globals ---- */
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)(void);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)(
+ NvtxCallbackModule module,
+ NvtxFunctionTable* out_table,
+ unsigned int* out_size);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)(
+ uint32_t version);
+NVTX_LINKONCE_FWDDECL_FUNCTION const void* NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable)(
+ uint32_t exportTableId);
+
+#include "nvtxInitDecls.h"
+
+/* ---- Define all globals ---- */
+
+typedef struct nvtxGlobals_t
+{
+ volatile unsigned int initState;
+ NvtxExportTableCallbacks etblCallbacks;
+ NvtxExportTableVersionInfo etblVersionInfo;
+
+ /* Implementation function pointers */
+ nvtxMarkEx_impl_fntype nvtxMarkEx_impl_fnptr;
+ nvtxMarkA_impl_fntype nvtxMarkA_impl_fnptr;
+ nvtxMarkW_impl_fntype nvtxMarkW_impl_fnptr;
+ nvtxRangeStartEx_impl_fntype nvtxRangeStartEx_impl_fnptr;
+ nvtxRangeStartA_impl_fntype nvtxRangeStartA_impl_fnptr;
+ nvtxRangeStartW_impl_fntype nvtxRangeStartW_impl_fnptr;
+ nvtxRangeEnd_impl_fntype nvtxRangeEnd_impl_fnptr;
+ nvtxRangePushEx_impl_fntype nvtxRangePushEx_impl_fnptr;
+ nvtxRangePushA_impl_fntype nvtxRangePushA_impl_fnptr;
+ nvtxRangePushW_impl_fntype nvtxRangePushW_impl_fnptr;
+ nvtxRangePop_impl_fntype nvtxRangePop_impl_fnptr;
+ nvtxNameCategoryA_impl_fntype nvtxNameCategoryA_impl_fnptr;
+ nvtxNameCategoryW_impl_fntype nvtxNameCategoryW_impl_fnptr;
+ nvtxNameOsThreadA_impl_fntype nvtxNameOsThreadA_impl_fnptr;
+ nvtxNameOsThreadW_impl_fntype nvtxNameOsThreadW_impl_fnptr;
+
+ nvtxNameCuDeviceA_fakeimpl_fntype nvtxNameCuDeviceA_impl_fnptr;
+ nvtxNameCuDeviceW_fakeimpl_fntype nvtxNameCuDeviceW_impl_fnptr;
+ nvtxNameCuContextA_fakeimpl_fntype nvtxNameCuContextA_impl_fnptr;
+ nvtxNameCuContextW_fakeimpl_fntype nvtxNameCuContextW_impl_fnptr;
+ nvtxNameCuStreamA_fakeimpl_fntype nvtxNameCuStreamA_impl_fnptr;
+ nvtxNameCuStreamW_fakeimpl_fntype nvtxNameCuStreamW_impl_fnptr;
+ nvtxNameCuEventA_fakeimpl_fntype nvtxNameCuEventA_impl_fnptr;
+ nvtxNameCuEventW_fakeimpl_fntype nvtxNameCuEventW_impl_fnptr;
+
+ nvtxNameClDeviceA_fakeimpl_fntype nvtxNameClDeviceA_impl_fnptr;
+ nvtxNameClDeviceW_fakeimpl_fntype nvtxNameClDeviceW_impl_fnptr;
+ nvtxNameClContextA_fakeimpl_fntype nvtxNameClContextA_impl_fnptr;
+ nvtxNameClContextW_fakeimpl_fntype nvtxNameClContextW_impl_fnptr;
+ nvtxNameClCommandQueueA_fakeimpl_fntype nvtxNameClCommandQueueA_impl_fnptr;
+ nvtxNameClCommandQueueW_fakeimpl_fntype nvtxNameClCommandQueueW_impl_fnptr;
+ nvtxNameClMemObjectA_fakeimpl_fntype nvtxNameClMemObjectA_impl_fnptr;
+ nvtxNameClMemObjectW_fakeimpl_fntype nvtxNameClMemObjectW_impl_fnptr;
+ nvtxNameClSamplerA_fakeimpl_fntype nvtxNameClSamplerA_impl_fnptr;
+ nvtxNameClSamplerW_fakeimpl_fntype nvtxNameClSamplerW_impl_fnptr;
+ nvtxNameClProgramA_fakeimpl_fntype nvtxNameClProgramA_impl_fnptr;
+ nvtxNameClProgramW_fakeimpl_fntype nvtxNameClProgramW_impl_fnptr;
+ nvtxNameClEventA_fakeimpl_fntype nvtxNameClEventA_impl_fnptr;
+ nvtxNameClEventW_fakeimpl_fntype nvtxNameClEventW_impl_fnptr;
+
+ nvtxNameCudaDeviceA_impl_fntype nvtxNameCudaDeviceA_impl_fnptr;
+ nvtxNameCudaDeviceW_impl_fntype nvtxNameCudaDeviceW_impl_fnptr;
+ nvtxNameCudaStreamA_fakeimpl_fntype nvtxNameCudaStreamA_impl_fnptr;
+ nvtxNameCudaStreamW_fakeimpl_fntype nvtxNameCudaStreamW_impl_fnptr;
+ nvtxNameCudaEventA_fakeimpl_fntype nvtxNameCudaEventA_impl_fnptr;
+ nvtxNameCudaEventW_fakeimpl_fntype nvtxNameCudaEventW_impl_fnptr;
+
+ nvtxDomainMarkEx_impl_fntype nvtxDomainMarkEx_impl_fnptr;
+ nvtxDomainRangeStartEx_impl_fntype nvtxDomainRangeStartEx_impl_fnptr;
+ nvtxDomainRangeEnd_impl_fntype nvtxDomainRangeEnd_impl_fnptr;
+ nvtxDomainRangePushEx_impl_fntype nvtxDomainRangePushEx_impl_fnptr;
+ nvtxDomainRangePop_impl_fntype nvtxDomainRangePop_impl_fnptr;
+ nvtxDomainResourceCreate_impl_fntype nvtxDomainResourceCreate_impl_fnptr;
+ nvtxDomainResourceDestroy_impl_fntype nvtxDomainResourceDestroy_impl_fnptr;
+ nvtxDomainNameCategoryA_impl_fntype nvtxDomainNameCategoryA_impl_fnptr;
+ nvtxDomainNameCategoryW_impl_fntype nvtxDomainNameCategoryW_impl_fnptr;
+ nvtxDomainRegisterStringA_impl_fntype nvtxDomainRegisterStringA_impl_fnptr;
+ nvtxDomainRegisterStringW_impl_fntype nvtxDomainRegisterStringW_impl_fnptr;
+ nvtxDomainCreateA_impl_fntype nvtxDomainCreateA_impl_fnptr;
+ nvtxDomainCreateW_impl_fntype nvtxDomainCreateW_impl_fnptr;
+ nvtxDomainDestroy_impl_fntype nvtxDomainDestroy_impl_fnptr;
+ nvtxInitialize_impl_fntype nvtxInitialize_impl_fnptr;
+
+ nvtxDomainSyncUserCreate_impl_fntype nvtxDomainSyncUserCreate_impl_fnptr;
+ nvtxDomainSyncUserDestroy_impl_fntype nvtxDomainSyncUserDestroy_impl_fnptr;
+ nvtxDomainSyncUserAcquireStart_impl_fntype nvtxDomainSyncUserAcquireStart_impl_fnptr;
+ nvtxDomainSyncUserAcquireFailed_impl_fntype nvtxDomainSyncUserAcquireFailed_impl_fnptr;
+ nvtxDomainSyncUserAcquireSuccess_impl_fntype nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
+ nvtxDomainSyncUserReleasing_impl_fntype nvtxDomainSyncUserReleasing_impl_fnptr;
+
+ /* Tables of function pointers -- Extra null added to the end to ensure
+ * a crash instead of silent corruption if a tool reads off the end. */
+ NvtxFunctionPointer* functionTable_CORE [NVTX_CBID_CORE_SIZE + 1];
+ NvtxFunctionPointer* functionTable_CUDA [NVTX_CBID_CUDA_SIZE + 1];
+ NvtxFunctionPointer* functionTable_OPENCL[NVTX_CBID_OPENCL_SIZE + 1];
+ NvtxFunctionPointer* functionTable_CUDART[NVTX_CBID_CUDART_SIZE + 1];
+ NvtxFunctionPointer* functionTable_CORE2 [NVTX_CBID_CORE2_SIZE + 1];
+ NvtxFunctionPointer* functionTable_SYNC [NVTX_CBID_SYNC_SIZE + 1];
+} nvtxGlobals_t;
+
+NVTX_LINKONCE_DEFINE_GLOBAL nvtxGlobals_t NVTX_VERSIONED_IDENTIFIER(nvtxGlobals) =
+{
+ NVTX_INIT_STATE_FRESH,
+
+ {
+ sizeof(NvtxExportTableCallbacks),
+ NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)
+ },
+ {
+ sizeof(NvtxExportTableVersionInfo),
+ NVTX_VERSION,
+ 0,
+ NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)
+ },
+
+ /* Implementation function pointers */
+ NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init),
+
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init),
+
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init),
+
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init),
+
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init),
+
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init),
+ NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init),
+
+ /* Tables of function pointers */
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr,
+ 0
+ },
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr,
+ 0
+ },
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr,
+ 0
+ },
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr,
+ 0
+ },
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr,
+ 0
+ },
+ {
+ 0,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr,
+ (NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr,
+ 0
+ }
+};
+
+/* ---- Define static inline implementations of core API functions ---- */
+
+#include "nvtxImplCore.h"
+
+/* ---- Define implementations of export table functions ---- */
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)(
+ NvtxCallbackModule module,
+ NvtxFunctionTable* out_table,
+ unsigned int* out_size)
+{
+ unsigned int bytes = 0;
+ NvtxFunctionTable table = (NvtxFunctionTable)0;
+
+ switch (module)
+ {
+ case NVTX_CB_MODULE_CORE:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE);
+ break;
+ case NVTX_CB_MODULE_CUDA:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDA;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDA);
+ break;
+ case NVTX_CB_MODULE_OPENCL:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_OPENCL;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_OPENCL);
+ break;
+ case NVTX_CB_MODULE_CUDART:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDART;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDART);
+ break;
+ case NVTX_CB_MODULE_CORE2:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE2;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE2);
+ break;
+ case NVTX_CB_MODULE_SYNC:
+ table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_SYNC;
+ bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_SYNC);
+ break;
+ default: return 0;
+ }
+
+ if (out_size)
+ *out_size = (bytes / (unsigned int)sizeof(NvtxFunctionPointer*)) - 1;
+
+ if (out_table)
+ *out_table = table;
+
+ return 1;
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION const void* NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable)(uint32_t exportTableId)
+{
+ switch (exportTableId)
+ {
+ case NVTX_ETID_CALLBACKS: return &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).etblCallbacks;
+ case NVTX_ETID_VERSIONINFO: return &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).etblVersionInfo;
+ default: return 0;
+ }
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)(uint32_t version)
+{
+ /* Reserved for custom implementations to resolve problems with tools */
+ (void)version;
+}
+
+/* ---- Define implementations of init versions of all API functions ---- */
+
+#include "nvtxInitDefs.h"
+
+/* ---- Define implementations of initialization functions ---- */
+
+#include "nvtxInit.h"
+
+#ifdef __GNUC__
+#pragma GCC visibility pop
+#endif
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCore.h b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h
new file mode 100644
index 0000000..9f014ca
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h
@@ -0,0 +1,307 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+NVTX_DECLSPEC void NVTX_API nvtxMarkEx(const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxMarkEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr;
+ if(local!=0)
+ (*local)(eventAttrib);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxMarkA(const char* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxMarkA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr;
+ if(local!=0)
+ (*local)(message);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxMarkW(const wchar_t* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxMarkW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr;
+ if(local!=0)
+ (*local)(message);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartEx(const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangeStartEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr;
+ if(local!=0)
+ return (*local)(eventAttrib);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxRangeId_t)0;
+}
+
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartA(const char* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangeStartA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxRangeId_t)0;
+}
+
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartW(const wchar_t* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangeStartW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxRangeId_t)0;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxRangeEnd(nvtxRangeId_t id)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangeEnd_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr;
+ if(local!=0)
+ (*local)(id);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxRangePushEx(const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangePushEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr;
+ if(local!=0)
+ return (*local)(eventAttrib);
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxRangePushA(const char* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangePushA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxRangePushW(const wchar_t* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangePushW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxRangePop(void)
+{
+#ifndef NVTX_DISABLE
+ nvtxRangePop_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr;
+ if(local!=0)
+ return (*local)();
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCategoryA(uint32_t category, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCategoryA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr;
+ if(local!=0)
+ (*local)(category, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCategoryW(uint32_t category, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCategoryW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr;
+ if(local!=0)
+ (*local)(category, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadA(uint32_t threadId, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameOsThreadA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr;
+ if(local!=0)
+ (*local)(threadId, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadW(uint32_t threadId, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameOsThreadW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr;
+ if(local!=0)
+ (*local)(threadId, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainMarkEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainMarkEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr;
+ if(local!=0)
+ (*local)(domain, eventAttrib);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxDomainRangeStartEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRangeStartEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, eventAttrib);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxRangeId_t)0;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainRangeEnd(nvtxDomainHandle_t domain, nvtxRangeId_t id)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRangeEnd_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr;
+ if(local!=0)
+ (*local)(domain, id);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxDomainRangePushEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRangePushEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, eventAttrib);
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC int NVTX_API nvtxDomainRangePop(nvtxDomainHandle_t domain)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRangePop_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain);
+ else
+#endif /*NVTX_DISABLE*/
+ return (int)NVTX_NO_PUSH_POP_TRACKING;
+}
+
+NVTX_DECLSPEC nvtxResourceHandle_t NVTX_API nvtxDomainResourceCreate(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainResourceCreate_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, attribs);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxResourceHandle_t)0;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainResourceDestroy(nvtxResourceHandle_t resource)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainResourceDestroy_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr;
+ if(local!=0)
+ (*local)(resource);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryA(nvtxDomainHandle_t domain, uint32_t category, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainNameCategoryA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr;
+ if(local!=0)
+ (*local)(domain, category, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryW(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainNameCategoryW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr;
+ if(local!=0)
+ (*local)(domain, category, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringA(nvtxDomainHandle_t domain, const char* string)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRegisterStringA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, string);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxStringHandle_t)0;
+}
+
+NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringW(nvtxDomainHandle_t domain, const wchar_t* string)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainRegisterStringW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, string);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxStringHandle_t)0;
+}
+
+NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateA(const char* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainCreateA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxDomainHandle_t)0;
+}
+
+NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateW(const wchar_t* message)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainCreateW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr;
+ if(local!=0)
+ return (*local)(message);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxDomainHandle_t)0;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainDestroy(nvtxDomainHandle_t domain)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainDestroy_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr;
+ if(local!=0)
+ (*local)(domain);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxInitialize(const void* reserved)
+{
+#ifndef NVTX_DISABLE
+ nvtxInitialize_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr;
+ if(local!=0)
+ (*local)(reserved);
+#endif /*NVTX_DISABLE*/
+}
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h
new file mode 100644
index 0000000..d4c0cdf
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h
@@ -0,0 +1,81 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD_CUDART
+#error Never include this file directly -- it is automatically included by nvToolsExtCudaRt.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+typedef void (NVTX_API * nvtxNameCudaDeviceA_impl_fntype)(int device, const char* name);
+typedef void (NVTX_API * nvtxNameCudaDeviceW_impl_fntype)(int device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCudaStreamA_impl_fntype)(cudaStream_t stream, const char* name);
+typedef void (NVTX_API * nvtxNameCudaStreamW_impl_fntype)(cudaStream_t stream, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCudaEventA_impl_fntype)(cudaEvent_t event, const char* name);
+typedef void (NVTX_API * nvtxNameCudaEventW_impl_fntype)(cudaEvent_t event, const wchar_t* name);
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceA(int device, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaDeviceA_impl_fntype local = (nvtxNameCudaDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceW(int device, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaDeviceW_impl_fntype local = (nvtxNameCudaDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamA(cudaStream_t stream, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaStreamA_impl_fntype local = (nvtxNameCudaStreamA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr;
+ if(local!=0)
+ (*local)(stream, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamW(cudaStream_t stream, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaStreamW_impl_fntype local = (nvtxNameCudaStreamW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr;
+ if(local!=0)
+ (*local)(stream, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventA(cudaEvent_t event, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaEventA_impl_fntype local = (nvtxNameCudaEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr;
+ if(local!=0)
+ (*local)(event, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventW(cudaEvent_t event, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCudaEventW_impl_fntype local = (nvtxNameCudaEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr;
+ if(local!=0)
+ (*local)(event, name);
+#endif /*NVTX_DISABLE*/
+}
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
+
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h
new file mode 100644
index 0000000..4b5d6c7
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h
@@ -0,0 +1,102 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD_CUDA
+#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+typedef void (NVTX_API * nvtxNameCuDeviceA_impl_fntype)(CUdevice device, const char* name);
+typedef void (NVTX_API * nvtxNameCuDeviceW_impl_fntype)(CUdevice device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuContextA_impl_fntype)(CUcontext context, const char* name);
+typedef void (NVTX_API * nvtxNameCuContextW_impl_fntype)(CUcontext context, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuStreamA_impl_fntype)(CUstream stream, const char* name);
+typedef void (NVTX_API * nvtxNameCuStreamW_impl_fntype)(CUstream stream, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuEventA_impl_fntype)(CUevent event, const char* name);
+typedef void (NVTX_API * nvtxNameCuEventW_impl_fntype)(CUevent event, const wchar_t* name);
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceA(CUdevice device, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuDeviceA_impl_fntype local = (nvtxNameCuDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceW(CUdevice device, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuDeviceW_impl_fntype local = (nvtxNameCuDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuContextA(CUcontext context, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuContextA_impl_fntype local = (nvtxNameCuContextA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr;
+ if(local!=0)
+ (*local)(context, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuContextW(CUcontext context, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuContextW_impl_fntype local = (nvtxNameCuContextW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr;
+ if(local!=0)
+ (*local)(context, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamA(CUstream stream, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuStreamA_impl_fntype local = (nvtxNameCuStreamA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr;
+ if(local!=0)
+ (*local)(stream, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamW(CUstream stream, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuStreamW_impl_fntype local = (nvtxNameCuStreamW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr;
+ if(local!=0)
+ (*local)(stream, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuEventA(CUevent event, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuEventA_impl_fntype local = (nvtxNameCuEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr;
+ if(local!=0)
+ (*local)(event, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameCuEventW(CUevent event, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameCuEventW_impl_fntype local = (nvtxNameCuEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr;
+ if(local!=0)
+ (*local)(event, name);
+#endif /*NVTX_DISABLE*/
+}
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
+
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h
new file mode 100644
index 0000000..4a026f0
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h
@@ -0,0 +1,161 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD_OPENCL
+#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+typedef void (NVTX_API * nvtxNameClDeviceA_impl_fntype)(cl_device_id device, const char* name);
+typedef void (NVTX_API * nvtxNameClDeviceW_impl_fntype)(cl_device_id device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClContextA_impl_fntype)(cl_context context, const char* name);
+typedef void (NVTX_API * nvtxNameClContextW_impl_fntype)(cl_context context, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClCommandQueueA_impl_fntype)(cl_command_queue command_queue, const char* name);
+typedef void (NVTX_API * nvtxNameClCommandQueueW_impl_fntype)(cl_command_queue command_queue, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClMemObjectA_impl_fntype)(cl_mem memobj, const char* name);
+typedef void (NVTX_API * nvtxNameClMemObjectW_impl_fntype)(cl_mem memobj, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClSamplerA_impl_fntype)(cl_sampler sampler, const char* name);
+typedef void (NVTX_API * nvtxNameClSamplerW_impl_fntype)(cl_sampler sampler, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClProgramA_impl_fntype)(cl_program program, const char* name);
+typedef void (NVTX_API * nvtxNameClProgramW_impl_fntype)(cl_program program, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClEventA_impl_fntype)(cl_event evnt, const char* name);
+typedef void (NVTX_API * nvtxNameClEventW_impl_fntype)(cl_event evnt, const wchar_t* name);
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceA(cl_device_id device, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClDeviceA_impl_fntype local = (nvtxNameClDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceW(cl_device_id device, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClDeviceW_impl_fntype local = (nvtxNameClDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr;
+ if(local!=0)
+ (*local)(device, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClContextA(cl_context context, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClContextA_impl_fntype local = (nvtxNameClContextA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr;
+ if(local!=0)
+ (*local)(context, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClContextW(cl_context context, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClContextW_impl_fntype local = (nvtxNameClContextW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr;
+ if(local!=0)
+ (*local)(context, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueA(cl_command_queue command_queue, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClCommandQueueA_impl_fntype local = (nvtxNameClCommandQueueA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr;
+ if(local!=0)
+ (*local)(command_queue, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueW(cl_command_queue command_queue, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClCommandQueueW_impl_fntype local = (nvtxNameClCommandQueueW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr;
+ if(local!=0)
+ (*local)(command_queue, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectA(cl_mem memobj, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClMemObjectA_impl_fntype local = (nvtxNameClMemObjectA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr;
+ if(local!=0)
+ (*local)(memobj, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectW(cl_mem memobj, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClMemObjectW_impl_fntype local = (nvtxNameClMemObjectW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr;
+ if(local!=0)
+ (*local)(memobj, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerA(cl_sampler sampler, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClSamplerA_impl_fntype local = (nvtxNameClSamplerA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr;
+ if(local!=0)
+ (*local)(sampler, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerW(cl_sampler sampler, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClSamplerW_impl_fntype local = (nvtxNameClSamplerW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr;
+ if(local!=0)
+ (*local)(sampler, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClProgramA(cl_program program, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClProgramA_impl_fntype local = (nvtxNameClProgramA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr;
+ if(local!=0)
+ (*local)(program, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClProgramW(cl_program program, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClProgramW_impl_fntype local = (nvtxNameClProgramW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr;
+ if(local!=0)
+ (*local)(program, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClEventA(cl_event evnt, const char* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClEventA_impl_fntype local = (nvtxNameClEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr;
+ if(local!=0)
+ (*local)(evnt, name);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxNameClEventW(cl_event evnt, const wchar_t* name)
+{
+#ifndef NVTX_DISABLE
+ nvtxNameClEventW_impl_fntype local = (nvtxNameClEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr;
+ if(local!=0)
+ (*local)(evnt, name);
+#endif /*NVTX_DISABLE*/
+}
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h
new file mode 100644
index 0000000..90616da
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h
@@ -0,0 +1,83 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD_SYNC
+#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+typedef nvtxSyncUser_t (NVTX_API * nvtxDomainSyncUserCreate_impl_fntype)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
+typedef void (NVTX_API * nvtxDomainSyncUserDestroy_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireStart_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireFailed_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireSuccess_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserReleasing_impl_fntype)(nvtxSyncUser_t handle);
+
+NVTX_DECLSPEC nvtxSyncUser_t NVTX_API nvtxDomainSyncUserCreate(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserCreate_impl_fntype local = (nvtxDomainSyncUserCreate_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr;
+ if(local!=0)
+ return (*local)(domain, attribs);
+ else
+#endif /*NVTX_DISABLE*/
+ return (nvtxSyncUser_t)0;
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserDestroy(nvtxSyncUser_t handle)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserDestroy_impl_fntype local = (nvtxDomainSyncUserDestroy_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr;
+ if(local!=0)
+ (*local)(handle);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireStart(nvtxSyncUser_t handle)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserAcquireStart_impl_fntype local = (nvtxDomainSyncUserAcquireStart_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr;
+ if(local!=0)
+ (*local)(handle);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireFailed(nvtxSyncUser_t handle)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserAcquireFailed_impl_fntype local = (nvtxDomainSyncUserAcquireFailed_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr;
+ if(local!=0)
+ (*local)(handle);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireSuccess(nvtxSyncUser_t handle)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserAcquireSuccess_impl_fntype local = (nvtxDomainSyncUserAcquireSuccess_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
+ if(local!=0)
+ (*local)(handle);
+#endif /*NVTX_DISABLE*/
+}
+
+NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserReleasing(nvtxSyncUser_t handle)
+{
+#ifndef NVTX_DISABLE
+ nvtxDomainSyncUserReleasing_impl_fntype local = (nvtxDomainSyncUserReleasing_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr;
+ if(local!=0)
+ (*local)(handle);
+#endif /*NVTX_DISABLE*/
+}
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif /* __cplusplus */
diff --git a/src/include/nvtx3/nvtxDetail/nvtxInit.h b/src/include/nvtx3/nvtxDetail/nvtxInit.h
new file mode 100644
index 0000000..44dcc0f
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxInit.h
@@ -0,0 +1,312 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD
+#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+/* ---- Platform-independent helper definitions and functions ---- */
+
+/* Prefer macros over inline functions to reduce symbol resolution at link time */
+
+#if defined(_WIN32)
+#define NVTX_PATHCHAR wchar_t
+#define NVTX_STR(x) L##x
+#define NVTX_GETENV _wgetenv
+#define NVTX_BUFSIZE MAX_PATH
+#define NVTX_DLLHANDLE HMODULE
+#define NVTX_DLLOPEN(x) LoadLibraryW(x)
+#define NVTX_DLLFUNC GetProcAddress
+#define NVTX_DLLCLOSE FreeLibrary
+#define NVTX_YIELD() SwitchToThread()
+#define NVTX_MEMBAR() MemoryBarrier()
+#define NVTX_ATOMIC_WRITE_32(address, value) InterlockedExchange((volatile LONG*)address, value)
+#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) old = InterlockedCompareExchange((volatile LONG*)address, exchange, comparand)
+#elif defined(__GNUC__)
+#define NVTX_PATHCHAR char
+#define NVTX_STR(x) x
+#define NVTX_GETENV getenv
+#define NVTX_BUFSIZE PATH_MAX
+#define NVTX_DLLHANDLE void*
+#define NVTX_DLLOPEN(x) dlopen(x, RTLD_LAZY)
+#define NVTX_DLLFUNC dlsym
+#define NVTX_DLLCLOSE dlclose
+#define NVTX_YIELD() sched_yield()
+#define NVTX_MEMBAR() __sync_synchronize()
+/* Ensure full memory barrier for atomics, to match Windows functions */
+#define NVTX_ATOMIC_WRITE_32(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value)
+#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand)
+#else
+#error The library does not support your configuration!
+#endif
+
+/* Define this to 1 for platforms that where pre-injected libraries can be discovered. */
+#if defined(_WIN32)
+/* TODO */
+#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0
+#else
+#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0
+#endif
+
+/* Define this to 1 for platforms that support environment variables */
+/* TODO: Detect UWP, a.k.a. Windows Store app, and set this to 0. */
+/* Try: #if defined(WINAPI_FAMILY_PARTITION) && WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */
+#define NVTX_SUPPORT_ENV_VARS 1
+
+/* Define this to 1 for platforms that support dynamic/shared libraries */
+#define NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY 1
+
+/* Injection libraries implementing InitializeInjectionNvtx2 may be statically linked,
+* and this will override any dynamic injection. Useful for platforms where dynamic
+* injection is not available. Since weak symbols not explicitly marked extern are
+* guaranteed to be initialized to zero if no definitions are found by the linker, the
+* dynamic injection process proceeds normally if pfnInitializeInjectionNvtx2 is 0. */
+#if defined(__GNUC__) && !defined(_WIN32) && !defined(__CYGWIN__)
+#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 1
+/* To statically inject an NVTX library, define InitializeInjectionNvtx2_fnptr as a normal
+* symbol (not weak) pointing to the implementation of InitializeInjectionNvtx2 (which
+* does not need to be named "InitializeInjectionNvtx2" as is necessary in a dynamic
+* injection library. */
+__attribute__((weak)) NvtxInitializeInjectionNvtxFunc_t InitializeInjectionNvtx2_fnptr;
+#else
+#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 0
+#endif
+
+/* This function tries to find or load an NVTX injection library and get the
+* address of its InitializeInjection2 function. If such a function pointer
+* is found, it is called, and passed the address of this NVTX instance's
+* nvtxGetExportTable function, so the injection can attach to this instance.
+* If the initialization fails for any reason, any dynamic library loaded will
+* be freed, and all NVTX implementation functions will be set to no-ops. If
+* initialization succeeds, NVTX functions not attached to the tool will be set
+* to no-ops. This is implemented as one function instead of several small
+* functions to minimize the number of weak symbols the linker must resolve.
+* Order of search is:
+* - Pre-injected library exporting InitializeInjectionNvtx2
+* - Loadable library exporting InitializeInjectionNvtx2
+* - Path specified by env var NVTX_INJECTION??_PATH (?? is 32 or 64)
+* - On Android, libNvtxInjection??.so within the package (?? is 32 or 64)
+* - Statically-linked injection library defining InitializeInjectionNvtx2_fnptr
+*/
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)(void);
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)(void)
+{
+ const char* const initFuncName = "InitializeInjectionNvtx2";
+ NvtxInitializeInjectionNvtxFunc_t init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)0;
+ NVTX_DLLHANDLE injectionLibraryHandle = (NVTX_DLLHANDLE)0;
+ int entryPointStatus = 0;
+
+#if NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY
+ /* Use POSIX global symbol chain to query for init function from any module */
+ init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)NVTX_DLLFUNC(0, initFuncName);
+#endif
+
+#if NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY
+ /* Try discovering dynamic injection library to load */
+ if (!init_fnptr)
+ {
+#if NVTX_SUPPORT_ENV_VARS
+ /* If env var NVTX_INJECTION64_PATH is set, it should contain the path
+ * to a 64-bit dynamic NVTX injection library (and similar for 32-bit). */
+ const NVTX_PATHCHAR* const nvtxEnvVarName = (sizeof(void*) == 4)
+ ? NVTX_STR("NVTX_INJECTION32_PATH")
+ : NVTX_STR("NVTX_INJECTION64_PATH");
+#endif /* NVTX_SUPPORT_ENV_VARS */
+ NVTX_PATHCHAR injectionLibraryPathBuf[NVTX_BUFSIZE];
+ const NVTX_PATHCHAR* injectionLibraryPath = (const NVTX_PATHCHAR*)0;
+
+ /* Refer to this variable explicitly in case all references to it are #if'ed out */
+ (void)injectionLibraryPathBuf;
+
+#if NVTX_SUPPORT_ENV_VARS
+ /* Disable the warning for getenv & _wgetenv -- this usage is safe because
+ * these functions are not called again before using the returned value. */
+#if defined(_MSC_VER)
+#pragma warning( push )
+#pragma warning( disable : 4996 )
+#endif
+ injectionLibraryPath = NVTX_GETENV(nvtxEnvVarName);
+#if defined(_MSC_VER)
+#pragma warning( pop )
+#endif
+#endif
+
+#if defined(__ANDROID__)
+ if (!injectionLibraryPath)
+ {
+ const char *bits = (sizeof(void*) == 4) ? "32" : "64";
+ char cmdlineBuf[32];
+ char pkgName[PATH_MAX];
+ int count;
+ int pid;
+ FILE *fp;
+ size_t bytesRead;
+ size_t pos;
+
+ pid = (int)getpid();
+ count = snprintf(cmdlineBuf, sizeof(cmdlineBuf), "/proc/%d/cmdline", pid);
+ if (count <= 0 || count >= (int)sizeof(cmdlineBuf))
+ {
+ NVTX_ERR("Path buffer too small for: /proc/%d/cmdline\n", pid);
+ return NVTX_ERR_INIT_ACCESS_LIBRARY;
+ }
+
+ fp = fopen(cmdlineBuf, "r");
+ if (!fp)
+ {
+ NVTX_ERR("File couldn't be opened: %s\n", cmdlineBuf);
+ return NVTX_ERR_INIT_ACCESS_LIBRARY;
+ }
+
+ bytesRead = fread(pkgName, 1, sizeof(pkgName) - 1, fp);
+ fclose(fp);
+ if (bytesRead == 0)
+ {
+ NVTX_ERR("Package name couldn't be read from file: %s\n", cmdlineBuf);
+ return NVTX_ERR_INIT_ACCESS_LIBRARY;
+ }
+
+ pkgName[bytesRead] = 0;
+
+ /* String can contain colon as a process separator. In this case the package name is before the colon. */
+ pos = 0;
+ while (pos < bytesRead && pkgName[pos] != ':' && pkgName[pos] != '\0')
+ {
+ ++pos;
+ }
+ pkgName[pos] = 0;
+
+ count = snprintf(injectionLibraryPathBuf, NVTX_BUFSIZE, "/data/data/%s/files/libNvtxInjection%s.so", pkgName, bits);
+ if (count <= 0 || count >= NVTX_BUFSIZE)
+ {
+ NVTX_ERR("Path buffer too small for: /data/data/%s/files/libNvtxInjection%s.so\n", pkgName, bits);
+ return NVTX_ERR_INIT_ACCESS_LIBRARY;
+ }
+
+ /* On Android, verify path is accessible due to aggressive file access restrictions. */
+ /* For dlopen, if the filename contains a leading slash, then it is interpreted as a */
+ /* relative or absolute pathname; otherwise it will follow the rules in ld.so. */
+ if (injectionLibraryPathBuf[0] == '/')
+ {
+#if (__ANDROID_API__ < 21)
+ int access_err = access(injectionLibraryPathBuf, F_OK | R_OK);
+#else
+ int access_err = faccessat(AT_FDCWD, injectionLibraryPathBuf, F_OK | R_OK, 0);
+#endif
+ if (access_err != 0)
+ {
+ NVTX_ERR("Injection library path wasn't accessible [code=%s] [path=%s]\n", strerror(errno), injectionLibraryPathBuf);
+ return NVTX_ERR_INIT_ACCESS_LIBRARY;
+ }
+ }
+ injectionLibraryPath = injectionLibraryPathBuf;
+ }
+#endif
+
+ /* At this point, injectionLibraryPath is specified if a dynamic
+ * injection library was specified by a tool. */
+ if (injectionLibraryPath)
+ {
+ /* Load the injection library */
+ injectionLibraryHandle = NVTX_DLLOPEN(injectionLibraryPath);
+ if (!injectionLibraryHandle)
+ {
+ NVTX_ERR("Failed to load injection library\n");
+ return NVTX_ERR_INIT_LOAD_LIBRARY;
+ }
+ else
+ {
+ /* Attempt to get the injection library's entry-point */
+ init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)NVTX_DLLFUNC(injectionLibraryHandle, initFuncName);
+ if (!init_fnptr)
+ {
+ NVTX_DLLCLOSE(injectionLibraryHandle);
+ NVTX_ERR("Failed to get address of function InitializeInjectionNvtx2 from injection library\n");
+ return NVTX_ERR_INIT_MISSING_LIBRARY_ENTRY_POINT;
+ }
+ }
+ }
+ }
+#endif
+
+#if NVTX_SUPPORT_STATIC_INJECTION_LIBRARY
+ if (!init_fnptr)
+ {
+ /* Check weakly-defined function pointer. A statically-linked injection can define this as
+ * a normal symbol and it will take precedence over a dynamic injection. */
+ if (InitializeInjectionNvtx2_fnptr)
+ {
+ init_fnptr = InitializeInjectionNvtx2_fnptr;
+ }
+ }
+#endif
+
+ /* At this point, if init_fnptr is not set, then no tool has specified
+ * an NVTX injection library -- return non-success result so all NVTX
+ * API functions will be set to no-ops. */
+ if (!init_fnptr)
+ {
+ return NVTX_ERR_NO_INJECTION_LIBRARY_AVAILABLE;
+ }
+
+ /* Invoke injection library's initialization function. If it returns
+ * 0 (failure) and a dynamic injection was loaded, unload it. */
+ entryPointStatus = init_fnptr(NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable));
+ if (entryPointStatus == 0)
+ {
+ NVTX_ERR("Failed to initialize injection library -- initialization function returned 0\n");
+ if (injectionLibraryHandle)
+ {
+ NVTX_DLLCLOSE(injectionLibraryHandle);
+ }
+ return NVTX_ERR_INIT_FAILED_LIBRARY_ENTRY_POINT;
+ }
+
+ return NVTX_SUCCESS;
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)(void)
+{
+ unsigned int old;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState == NVTX_INIT_STATE_COMPLETE)
+ {
+ return;
+ }
+
+ NVTX_ATOMIC_CAS_32(
+ old,
+ &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState,
+ NVTX_INIT_STATE_STARTED,
+ NVTX_INIT_STATE_FRESH);
+ if (old == NVTX_INIT_STATE_FRESH)
+ {
+ int result;
+ int forceAllToNoops;
+
+ /* Load & initialize injection library -- it will assign the function pointers */
+ result = NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)();
+
+ /* Set all pointers not assigned by the injection to null */
+ forceAllToNoops = result != NVTX_SUCCESS; /* Set all to null if injection init failed */
+ NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(forceAllToNoops);
+
+ /* Signal that initialization has finished, so now the assigned function pointers will be used */
+ NVTX_ATOMIC_WRITE_32(
+ &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState,
+ NVTX_INIT_STATE_COMPLETE);
+ }
+ else /* Spin-wait until initialization has finished */
+ {
+ NVTX_MEMBAR();
+ while (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState != NVTX_INIT_STATE_COMPLETE)
+ {
+ NVTX_YIELD();
+ NVTX_MEMBAR();
+ }
+ }
+}
diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h
new file mode 100644
index 0000000..261681b
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h
@@ -0,0 +1,81 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD
+#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init)(const char* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init)(const wchar_t* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init)(const char* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init)(const wchar_t* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init)(nvtxRangeId_t id);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init)(const char* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init)(const wchar_t* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init)(void);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init)(uint32_t category, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init)(uint32_t category, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init)(uint32_t threadId, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init)(uint32_t threadId, const wchar_t* name);
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init)(nvtx_CUdevice device, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init)(nvtx_CUdevice device, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init)(nvtx_CUcontext context, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init)(nvtx_CUcontext context, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init)(nvtx_CUstream stream, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init)(nvtx_CUstream stream, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init)(nvtx_CUevent event, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init)(nvtx_CUevent event, const wchar_t* name);
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init)(nvtx_cl_device_id device, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init)(nvtx_cl_device_id device, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init)(nvtx_cl_context context, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init)(nvtx_cl_context context, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init)(nvtx_cl_command_queue command_queue, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init)(nvtx_cl_command_queue command_queue, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init)(nvtx_cl_mem memobj, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init)(nvtx_cl_mem memobj, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init)(nvtx_cl_sampler sampler, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init)(nvtx_cl_sampler sampler, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init)(nvtx_cl_program program, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init)(nvtx_cl_program program, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init)(nvtx_cl_event evnt, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init)(nvtx_cl_event evnt, const wchar_t* name);
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init)(int device, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init)(int device, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init)(nvtx_cudaStream_t stream, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init)(nvtx_cudaStream_t stream, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init)(nvtx_cudaEvent_t event, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init)(nvtx_cudaEvent_t event, const wchar_t* name);
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init)(nvtxDomainHandle_t domain, nvtxRangeId_t id);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init)(nvtxDomainHandle_t domain);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxResourceHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init)(nvtxResourceHandle_t resource);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const char* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init)(nvtxDomainHandle_t domain, const char* string);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init)(nvtxDomainHandle_t domain, const wchar_t* string);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init)(const char* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init)(const wchar_t* message);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init)(nvtxDomainHandle_t domain);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init)(const void* reserved);
+
+NVTX_LINKONCE_FWDDECL_FUNCTION nvtxSyncUser_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init)(nvtxSyncUser_t handle);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init)(nvtxSyncUser_t handle);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init)(nvtxSyncUser_t handle);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init)(nvtxSyncUser_t handle);
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init)(nvtxSyncUser_t handle);
diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h
new file mode 100644
index 0000000..ded156c
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h
@@ -0,0 +1,573 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef NVTX_IMPL_GUARD
+#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
+#endif
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxMarkEx(eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init)(const char* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxMarkA(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init)(const wchar_t* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxMarkW(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangeStartEx(eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init)(const char* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangeStartA(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init)(const wchar_t* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangeStartW(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init)(nvtxRangeId_t id){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxRangeEnd(id);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangePushEx(eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init)(const char* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangePushA(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init)(const wchar_t* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangePushW(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init)(void){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxRangePop();
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init)(uint32_t category, const char* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxNameCategoryA(category, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init)(uint32_t category, const wchar_t* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxNameCategoryW(category, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init)(uint32_t threadId, const char* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxNameOsThreadA(threadId, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init)(uint32_t threadId, const wchar_t* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxNameOsThreadW(threadId, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainMarkEx(domain, eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainRangeStartEx(domain, eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init)(nvtxDomainHandle_t domain, nvtxRangeId_t id){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainRangeEnd(domain, id);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainRangePushEx(domain, eventAttrib);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init)(nvtxDomainHandle_t domain){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainRangePop(domain);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxResourceHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainResourceCreate(domain, attribs);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init)(nvtxResourceHandle_t resource){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainResourceDestroy(resource);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const char* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainNameCategoryA(domain, category, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainNameCategoryW(domain, category, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init)(nvtxDomainHandle_t domain, const char* string){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainRegisterStringA(domain, string);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init)(nvtxDomainHandle_t domain, const wchar_t* string){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainRegisterStringW(domain, string);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init)(const char* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainCreateA(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init)(const wchar_t* message){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ return nvtxDomainCreateW(message);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init)(nvtxDomainHandle_t domain){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxDomainDestroy(domain);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init)(const void* reserved){
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ nvtxInitialize(reserved);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init)(nvtx_CUdevice device, const char* name){
+ nvtxNameCuDeviceA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init)(nvtx_CUdevice device, const wchar_t* name){
+ nvtxNameCuDeviceW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init)(nvtx_CUcontext context, const char* name){
+ nvtxNameCuContextA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr;
+ if (local)
+ local(context, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init)(nvtx_CUcontext context, const wchar_t* name){
+ nvtxNameCuContextW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr;
+ if (local)
+ local(context, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init)(nvtx_CUstream stream, const char* name){
+ nvtxNameCuStreamA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr;
+ if (local)
+ local(stream, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init)(nvtx_CUstream stream, const wchar_t* name){
+ nvtxNameCuStreamW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr;
+ if (local)
+ local(stream, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init)(nvtx_CUevent event, const char* name){
+ nvtxNameCuEventA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr;
+ if (local)
+ local(event, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init)(nvtx_CUevent event, const wchar_t* name){
+ nvtxNameCuEventW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr;
+ if (local)
+ local(event, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init)(int device, const char* name){
+ nvtxNameCudaDeviceA_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init)(int device, const wchar_t* name){
+ nvtxNameCudaDeviceW_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init)(nvtx_cudaStream_t stream, const char* name){
+ nvtxNameCudaStreamA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr;
+ if (local)
+ local(stream, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init)(nvtx_cudaStream_t stream, const wchar_t* name){
+ nvtxNameCudaStreamW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr;
+ if (local)
+ local(stream, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init)(nvtx_cudaEvent_t event, const char* name){
+ nvtxNameCudaEventA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr;
+ if (local)
+ local(event, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init)(nvtx_cudaEvent_t event, const wchar_t* name){
+ nvtxNameCudaEventW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr;
+ if (local)
+ local(event, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init)(nvtx_cl_device_id device, const char* name){
+ nvtxNameClDeviceA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init)(nvtx_cl_device_id device, const wchar_t* name){
+ nvtxNameClDeviceW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr;
+ if (local)
+ local(device, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init)(nvtx_cl_context context, const char* name){
+ nvtxNameClContextA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr;
+ if (local)
+ local(context, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init)(nvtx_cl_context context, const wchar_t* name){
+ nvtxNameClContextW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr;
+ if (local)
+ local(context, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init)(nvtx_cl_command_queue command_queue, const char* name){
+ nvtxNameClCommandQueueA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr;
+ if (local)
+ local(command_queue, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init)(nvtx_cl_command_queue command_queue, const wchar_t* name){
+ nvtxNameClCommandQueueW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr;
+ if (local)
+ local(command_queue, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init)(nvtx_cl_mem memobj, const char* name){
+ nvtxNameClMemObjectA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr;
+ if (local)
+ local(memobj, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init)(nvtx_cl_mem memobj, const wchar_t* name){
+ nvtxNameClMemObjectW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr;
+ if (local)
+ local(memobj, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init)(nvtx_cl_sampler sampler, const char* name){
+ nvtxNameClSamplerA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr;
+ if (local)
+ local(sampler, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init)(nvtx_cl_sampler sampler, const wchar_t* name){
+ nvtxNameClSamplerW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr;
+ if (local)
+ local(sampler, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init)(nvtx_cl_program program, const char* name){
+ nvtxNameClProgramA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr;
+ if (local)
+ local(program, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init)(nvtx_cl_program program, const wchar_t* name){
+ nvtxNameClProgramW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr;
+ if (local)
+ local(program, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init)(nvtx_cl_event evnt, const char* name){
+ nvtxNameClEventA_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr;
+ if (local)
+ local(evnt, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init)(nvtx_cl_event evnt, const wchar_t* name){
+ nvtxNameClEventW_fakeimpl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr;
+ if (local)
+ local(evnt, name);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION nvtxSyncUser_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs){
+ nvtxDomainSyncUserCreate_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr;
+ if (local) {
+ return local(domain, attribs);
+ }
+ return (nvtxSyncUser_t)0;
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init)(nvtxSyncUser_t handle){
+ nvtxDomainSyncUserDestroy_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr;
+ if (local)
+ local(handle);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init)(nvtxSyncUser_t handle){
+ nvtxDomainSyncUserAcquireStart_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr;
+ if (local)
+ local(handle);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init)(nvtxSyncUser_t handle){
+ nvtxDomainSyncUserAcquireFailed_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr;
+ if (local)
+ local(handle);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init)(nvtxSyncUser_t handle){
+ nvtxDomainSyncUserAcquireSuccess_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
+ if (local)
+ local(handle);
+}
+
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init)(nvtxSyncUser_t handle){
+ nvtxDomainSyncUserReleasing_impl_fntype local;
+ NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
+ local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr;
+ if (local)
+ local(handle);
+}
+
+NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(int forceAllToNoops);
+NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(int forceAllToNoops)
+{
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr = NULL;
+
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr = NULL;
+
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr = NULL;
+
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr = NULL;
+
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr = NULL;
+
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr = NULL;
+ if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init) || forceAllToNoops)
+ NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr = NULL;
+}
diff --git a/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h
new file mode 100644
index 0000000..908ce88
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h
@@ -0,0 +1,83 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+#ifndef __NVTX_LINKONCE_H__
+#define __NVTX_LINKONCE_H__
+
+/* This header defines macros to permit making definitions of global variables
+ * and functions in C/C++ header files which may be included multiple times in
+ * a translation unit or linkage unit. It allows authoring header-only libraries
+ * which can be used by multiple other header-only libraries (either as the same
+ * copy or multiple copies), and does not require any build changes, such as
+ * adding another .c file, linking a static library, or deploying a dynamic
+ * library. Globals defined with these macros have the property that they have
+ * the same address, pointing to a single instance, for the entire linkage unit.
+ * It is expected but not guaranteed that each linkage unit will have a separate
+ * instance.
+ *
+ * In some situations it is desirable to declare a variable without initializing
+ * it, refer to it in code or other variables' initializers, and then initialize
+ * it later. Similarly, functions can be prototyped, have their address taken,
+ * and then have their body defined later. In such cases, use the FWDDECL macros
+ * when forward-declaring LINKONCE global variables without initializers and
+ * function prototypes, and then use the DEFINE macros when later defining them.
+ * Although in many cases the FWDDECL macro is equivalent to the DEFINE macro,
+ * following this pattern makes code maximally portable.
+ */
+
+#if defined(__MINGW32__) /* MinGW */
+ #define NVTX_LINKONCE_WEAK __attribute__((section(".gnu.linkonce.0.")))
+ #if defined(__cplusplus)
+ #define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
+ #define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline NVTX_LINKONCE_WEAK
+ #else
+ #define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
+ #define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
+ #endif
+#elif defined(_MSC_VER) /* MSVC */
+ #if defined(__cplusplus)
+ #define NVTX_LINKONCE_DEFINE_GLOBAL extern "C" __declspec(selectany)
+ #define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline
+ #else
+ #define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
+ #define NVTX_LINKONCE_DEFINE_FUNCTION __inline
+ #endif
+#elif defined(__CYGWIN__) && defined(__clang__) /* Clang on Cygwin */
+ #define NVTX_LINKONCE_WEAK __attribute__((section(".gnu.linkonce.0.")))
+ #if defined(__cplusplus)
+ #define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
+ #define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" NVTX_LINKONCE_WEAK
+ #else
+ #define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
+ #define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
+ #endif
+#elif defined(__CYGWIN__) /* Assume GCC or compatible */
+ #define NVTX_LINKONCE_WEAK __attribute__((weak))
+ #if defined(__cplusplus)
+ #define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
+ #define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline
+ #else
+ #define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
+ #define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
+ #endif
+#else /* All others: Assume GCC, clang, or compatible */
+ #define NVTX_LINKONCE_WEAK __attribute__((weak))
+ #define NVTX_LINKONCE_HIDDEN __attribute__((visibility("hidden")))
+ #if defined(__cplusplus)
+ #define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
+ #define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" NVTX_LINKONCE_HIDDEN inline
+ #else
+ #define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
+ #define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
+ #endif
+#endif
+
+#define NVTX_LINKONCE_FWDDECL_GLOBAL NVTX_LINKONCE_DEFINE_GLOBAL extern
+#define NVTX_LINKONCE_FWDDECL_FUNCTION NVTX_LINKONCE_DEFINE_FUNCTION
+
+#endif /* __NVTX_LINKONCE_H__ */
diff --git a/src/include/nvtx3/nvtxDetail/nvtxTypes.h b/src/include/nvtx3/nvtxDetail/nvtxTypes.h
new file mode 100644
index 0000000..53c6c00
--- /dev/null
+++ b/src/include/nvtx3/nvtxDetail/nvtxTypes.h
@@ -0,0 +1,304 @@
+/*
+* Copyright 2009-2020 NVIDIA Corporation. All rights reserved.
+*
+* Licensed under the Apache License v2.0 with LLVM Exceptions.
+* See https://llvm.org/LICENSE.txt for license information.
+* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+*/
+
+/* This header defines types which are used by the internal implementation
+* of NVTX and callback subscribers. API clients do not use these types,
+* so they are defined here instead of in nvToolsExt.h to clarify they are
+* not part of the NVTX client API. */
+
+#ifndef NVTX_IMPL_GUARD
+#error Never include this file directly -- it is automatically included by nvToolsExt.h.
+#endif
+
+/* ------ Dependency-free types binary-compatible with real types ------- */
+
+/* In order to avoid having the NVTX core API headers depend on non-NVTX
+* headers like cuda.h, NVTX defines binary-compatible types to use for
+* safely making the initialization versions of all NVTX functions without
+* needing to have definitions for the real types. */
+
+typedef int nvtx_CUdevice;
+typedef void* nvtx_CUcontext;
+typedef void* nvtx_CUstream;
+typedef void* nvtx_CUevent;
+
+typedef void* nvtx_cudaStream_t;
+typedef void* nvtx_cudaEvent_t;
+
+typedef void* nvtx_cl_platform_id;
+typedef void* nvtx_cl_device_id;
+typedef void* nvtx_cl_context;
+typedef void* nvtx_cl_command_queue;
+typedef void* nvtx_cl_mem;
+typedef void* nvtx_cl_program;
+typedef void* nvtx_cl_kernel;
+typedef void* nvtx_cl_event;
+typedef void* nvtx_cl_sampler;
+
+typedef struct nvtxSyncUser* nvtxSyncUser_t;
+struct nvtxSyncUserAttributes_v0;
+typedef struct nvtxSyncUserAttributes_v0 nvtxSyncUserAttributes_t;
+
+/* --------- Types for function pointers (with fake API types) ---------- */
+
+typedef void (NVTX_API * nvtxMarkEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
+typedef void (NVTX_API * nvtxMarkA_impl_fntype)(const char* message);
+typedef void (NVTX_API * nvtxMarkW_impl_fntype)(const wchar_t* message);
+typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
+typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartA_impl_fntype)(const char* message);
+typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartW_impl_fntype)(const wchar_t* message);
+typedef void (NVTX_API * nvtxRangeEnd_impl_fntype)(nvtxRangeId_t id);
+typedef int (NVTX_API * nvtxRangePushEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
+typedef int (NVTX_API * nvtxRangePushA_impl_fntype)(const char* message);
+typedef int (NVTX_API * nvtxRangePushW_impl_fntype)(const wchar_t* message);
+typedef int (NVTX_API * nvtxRangePop_impl_fntype)(void);
+typedef void (NVTX_API * nvtxNameCategoryA_impl_fntype)(uint32_t category, const char* name);
+typedef void (NVTX_API * nvtxNameCategoryW_impl_fntype)(uint32_t category, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameOsThreadA_impl_fntype)(uint32_t threadId, const char* name);
+typedef void (NVTX_API * nvtxNameOsThreadW_impl_fntype)(uint32_t threadId, const wchar_t* name);
+
+/* Real impl types are defined in nvtxImplCuda_v3.h, where CUDA headers are included */
+typedef void (NVTX_API * nvtxNameCuDeviceA_fakeimpl_fntype)(nvtx_CUdevice device, const char* name);
+typedef void (NVTX_API * nvtxNameCuDeviceW_fakeimpl_fntype)(nvtx_CUdevice device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuContextA_fakeimpl_fntype)(nvtx_CUcontext context, const char* name);
+typedef void (NVTX_API * nvtxNameCuContextW_fakeimpl_fntype)(nvtx_CUcontext context, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuStreamA_fakeimpl_fntype)(nvtx_CUstream stream, const char* name);
+typedef void (NVTX_API * nvtxNameCuStreamW_fakeimpl_fntype)(nvtx_CUstream stream, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCuEventA_fakeimpl_fntype)(nvtx_CUevent event, const char* name);
+typedef void (NVTX_API * nvtxNameCuEventW_fakeimpl_fntype)(nvtx_CUevent event, const wchar_t* name);
+
+/* Real impl types are defined in nvtxImplOpenCL_v3.h, where OPENCL headers are included */
+typedef void (NVTX_API * nvtxNameClDeviceA_fakeimpl_fntype)(nvtx_cl_device_id device, const char* name);
+typedef void (NVTX_API * nvtxNameClDeviceW_fakeimpl_fntype)(nvtx_cl_device_id device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClContextA_fakeimpl_fntype)(nvtx_cl_context context, const char* name);
+typedef void (NVTX_API * nvtxNameClContextW_fakeimpl_fntype)(nvtx_cl_context context, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClCommandQueueA_fakeimpl_fntype)(nvtx_cl_command_queue command_queue, const char* name);
+typedef void (NVTX_API * nvtxNameClCommandQueueW_fakeimpl_fntype)(nvtx_cl_command_queue command_queue, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClMemObjectA_fakeimpl_fntype)(nvtx_cl_mem memobj, const char* name);
+typedef void (NVTX_API * nvtxNameClMemObjectW_fakeimpl_fntype)(nvtx_cl_mem memobj, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClSamplerA_fakeimpl_fntype)(nvtx_cl_sampler sampler, const char* name);
+typedef void (NVTX_API * nvtxNameClSamplerW_fakeimpl_fntype)(nvtx_cl_sampler sampler, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClProgramA_fakeimpl_fntype)(nvtx_cl_program program, const char* name);
+typedef void (NVTX_API * nvtxNameClProgramW_fakeimpl_fntype)(nvtx_cl_program program, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameClEventA_fakeimpl_fntype)(nvtx_cl_event evnt, const char* name);
+typedef void (NVTX_API * nvtxNameClEventW_fakeimpl_fntype)(nvtx_cl_event evnt, const wchar_t* name);
+
+/* Real impl types are defined in nvtxImplCudaRt_v3.h, where CUDART headers are included */
+typedef void (NVTX_API * nvtxNameCudaDeviceA_impl_fntype)(int device, const char* name);
+typedef void (NVTX_API * nvtxNameCudaDeviceW_impl_fntype)(int device, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCudaStreamA_fakeimpl_fntype)(nvtx_cudaStream_t stream, const char* name);
+typedef void (NVTX_API * nvtxNameCudaStreamW_fakeimpl_fntype)(nvtx_cudaStream_t stream, const wchar_t* name);
+typedef void (NVTX_API * nvtxNameCudaEventA_fakeimpl_fntype)(nvtx_cudaEvent_t event, const char* name);
+typedef void (NVTX_API * nvtxNameCudaEventW_fakeimpl_fntype)(nvtx_cudaEvent_t event, const wchar_t* name);
+
+typedef void (NVTX_API * nvtxDomainMarkEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+typedef nvtxRangeId_t (NVTX_API * nvtxDomainRangeStartEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+typedef void (NVTX_API * nvtxDomainRangeEnd_impl_fntype)(nvtxDomainHandle_t domain, nvtxRangeId_t id);
+typedef int (NVTX_API * nvtxDomainRangePushEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
+typedef int (NVTX_API * nvtxDomainRangePop_impl_fntype)(nvtxDomainHandle_t domain);
+typedef nvtxResourceHandle_t (NVTX_API * nvtxDomainResourceCreate_impl_fntype)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs);
+typedef void (NVTX_API * nvtxDomainResourceDestroy_impl_fntype)(nvtxResourceHandle_t resource);
+typedef void (NVTX_API * nvtxDomainNameCategoryA_impl_fntype)(nvtxDomainHandle_t domain, uint32_t category, const char* name);
+typedef void (NVTX_API * nvtxDomainNameCategoryW_impl_fntype)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name);
+typedef nvtxStringHandle_t (NVTX_API * nvtxDomainRegisterStringA_impl_fntype)(nvtxDomainHandle_t domain, const char* string);
+typedef nvtxStringHandle_t (NVTX_API * nvtxDomainRegisterStringW_impl_fntype)(nvtxDomainHandle_t domain, const wchar_t* string);
+typedef nvtxDomainHandle_t (NVTX_API * nvtxDomainCreateA_impl_fntype)(const char* message);
+typedef nvtxDomainHandle_t (NVTX_API * nvtxDomainCreateW_impl_fntype)(const wchar_t* message);
+typedef void (NVTX_API * nvtxDomainDestroy_impl_fntype)(nvtxDomainHandle_t domain);
+typedef void (NVTX_API * nvtxInitialize_impl_fntype)(const void* reserved);
+
+typedef nvtxSyncUser_t (NVTX_API * nvtxDomainSyncUserCreate_impl_fntype)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
+typedef void (NVTX_API * nvtxDomainSyncUserDestroy_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireStart_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireFailed_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserAcquireSuccess_impl_fntype)(nvtxSyncUser_t handle);
+typedef void (NVTX_API * nvtxDomainSyncUserReleasing_impl_fntype)(nvtxSyncUser_t handle);
+
+/* ---------------- Types for callback subscription --------------------- */
+
+typedef const void *(NVTX_API * NvtxGetExportTableFunc_t)(uint32_t exportTableId);
+typedef int (NVTX_API * NvtxInitializeInjectionNvtxFunc_t)(NvtxGetExportTableFunc_t exportTable);
+
+typedef enum NvtxCallbackModule
+{
+ NVTX_CB_MODULE_INVALID = 0,
+ NVTX_CB_MODULE_CORE = 1,
+ NVTX_CB_MODULE_CUDA = 2,
+ NVTX_CB_MODULE_OPENCL = 3,
+ NVTX_CB_MODULE_CUDART = 4,
+ NVTX_CB_MODULE_CORE2 = 5,
+ NVTX_CB_MODULE_SYNC = 6,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CB_MODULE_SIZE,
+ NVTX_CB_MODULE_FORCE_INT = 0x7fffffff
+} NvtxCallbackModule;
+
+typedef enum NvtxCallbackIdCore
+{
+ NVTX_CBID_CORE_INVALID = 0,
+ NVTX_CBID_CORE_MarkEx = 1,
+ NVTX_CBID_CORE_MarkA = 2,
+ NVTX_CBID_CORE_MarkW = 3,
+ NVTX_CBID_CORE_RangeStartEx = 4,
+ NVTX_CBID_CORE_RangeStartA = 5,
+ NVTX_CBID_CORE_RangeStartW = 6,
+ NVTX_CBID_CORE_RangeEnd = 7,
+ NVTX_CBID_CORE_RangePushEx = 8,
+ NVTX_CBID_CORE_RangePushA = 9,
+ NVTX_CBID_CORE_RangePushW = 10,
+ NVTX_CBID_CORE_RangePop = 11,
+ NVTX_CBID_CORE_NameCategoryA = 12,
+ NVTX_CBID_CORE_NameCategoryW = 13,
+ NVTX_CBID_CORE_NameOsThreadA = 14,
+ NVTX_CBID_CORE_NameOsThreadW = 15,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_CORE_SIZE,
+ NVTX_CBID_CORE_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdCore;
+
+typedef enum NvtxCallbackIdCore2
+{
+ NVTX_CBID_CORE2_INVALID = 0,
+ NVTX_CBID_CORE2_DomainMarkEx = 1,
+ NVTX_CBID_CORE2_DomainRangeStartEx = 2,
+ NVTX_CBID_CORE2_DomainRangeEnd = 3,
+ NVTX_CBID_CORE2_DomainRangePushEx = 4,
+ NVTX_CBID_CORE2_DomainRangePop = 5,
+ NVTX_CBID_CORE2_DomainResourceCreate = 6,
+ NVTX_CBID_CORE2_DomainResourceDestroy = 7,
+ NVTX_CBID_CORE2_DomainNameCategoryA = 8,
+ NVTX_CBID_CORE2_DomainNameCategoryW = 9,
+ NVTX_CBID_CORE2_DomainRegisterStringA = 10,
+ NVTX_CBID_CORE2_DomainRegisterStringW = 11,
+ NVTX_CBID_CORE2_DomainCreateA = 12,
+ NVTX_CBID_CORE2_DomainCreateW = 13,
+ NVTX_CBID_CORE2_DomainDestroy = 14,
+ NVTX_CBID_CORE2_Initialize = 15,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_CORE2_SIZE,
+ NVTX_CBID_CORE2_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdCore2;
+
+typedef enum NvtxCallbackIdCuda
+{
+ NVTX_CBID_CUDA_INVALID = 0,
+ NVTX_CBID_CUDA_NameCuDeviceA = 1,
+ NVTX_CBID_CUDA_NameCuDeviceW = 2,
+ NVTX_CBID_CUDA_NameCuContextA = 3,
+ NVTX_CBID_CUDA_NameCuContextW = 4,
+ NVTX_CBID_CUDA_NameCuStreamA = 5,
+ NVTX_CBID_CUDA_NameCuStreamW = 6,
+ NVTX_CBID_CUDA_NameCuEventA = 7,
+ NVTX_CBID_CUDA_NameCuEventW = 8,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_CUDA_SIZE,
+ NVTX_CBID_CUDA_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdCuda;
+
+typedef enum NvtxCallbackIdCudaRt
+{
+ NVTX_CBID_CUDART_INVALID = 0,
+ NVTX_CBID_CUDART_NameCudaDeviceA = 1,
+ NVTX_CBID_CUDART_NameCudaDeviceW = 2,
+ NVTX_CBID_CUDART_NameCudaStreamA = 3,
+ NVTX_CBID_CUDART_NameCudaStreamW = 4,
+ NVTX_CBID_CUDART_NameCudaEventA = 5,
+ NVTX_CBID_CUDART_NameCudaEventW = 6,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_CUDART_SIZE,
+ NVTX_CBID_CUDART_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdCudaRt;
+
+typedef enum NvtxCallbackIdOpenCL
+{
+ NVTX_CBID_OPENCL_INVALID = 0,
+ NVTX_CBID_OPENCL_NameClDeviceA = 1,
+ NVTX_CBID_OPENCL_NameClDeviceW = 2,
+ NVTX_CBID_OPENCL_NameClContextA = 3,
+ NVTX_CBID_OPENCL_NameClContextW = 4,
+ NVTX_CBID_OPENCL_NameClCommandQueueA = 5,
+ NVTX_CBID_OPENCL_NameClCommandQueueW = 6,
+ NVTX_CBID_OPENCL_NameClMemObjectA = 7,
+ NVTX_CBID_OPENCL_NameClMemObjectW = 8,
+ NVTX_CBID_OPENCL_NameClSamplerA = 9,
+ NVTX_CBID_OPENCL_NameClSamplerW = 10,
+ NVTX_CBID_OPENCL_NameClProgramA = 11,
+ NVTX_CBID_OPENCL_NameClProgramW = 12,
+ NVTX_CBID_OPENCL_NameClEventA = 13,
+ NVTX_CBID_OPENCL_NameClEventW = 14,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_OPENCL_SIZE,
+ NVTX_CBID_OPENCL_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdOpenCL;
+
+typedef enum NvtxCallbackIdSync
+{
+ NVTX_CBID_SYNC_INVALID = 0,
+ NVTX_CBID_SYNC_DomainSyncUserCreate = 1,
+ NVTX_CBID_SYNC_DomainSyncUserDestroy = 2,
+ NVTX_CBID_SYNC_DomainSyncUserAcquireStart = 3,
+ NVTX_CBID_SYNC_DomainSyncUserAcquireFailed = 4,
+ NVTX_CBID_SYNC_DomainSyncUserAcquireSuccess = 5,
+ NVTX_CBID_SYNC_DomainSyncUserReleasing = 6,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_CBID_SYNC_SIZE,
+ NVTX_CBID_SYNC_FORCE_INT = 0x7fffffff
+} NvtxCallbackIdSync;
+
+/* IDs for NVTX Export Tables */
+typedef enum NvtxExportTableID
+{
+ NVTX_ETID_INVALID = 0,
+ NVTX_ETID_CALLBACKS = 1,
+ NVTX_ETID_RESERVED0 = 2,
+ NVTX_ETID_VERSIONINFO = 3,
+ /* --- New constants must only be added directly above this line --- */
+ NVTX_ETID_SIZE,
+ NVTX_ETID_FORCE_INT = 0x7fffffff
+} NvtxExportTableID;
+
+typedef void (* NvtxFunctionPointer)(void); /* generic uncallable function pointer, must be casted to appropriate function type */
+typedef NvtxFunctionPointer** NvtxFunctionTable; /* double pointer because array(1) of pointers(2) to function pointers */
+
+typedef struct NvtxExportTableCallbacks
+{
+ size_t struct_size;
+
+ /* returns an array of pointer to function pointers*/
+ int (NVTX_API *GetModuleFunctionTable)(
+ NvtxCallbackModule module,
+ NvtxFunctionTable* out_table,
+ unsigned int* out_size);
+} NvtxExportTableCallbacks;
+
+typedef struct NvtxExportTableVersionInfo
+{
+ /* sizeof(NvtxExportTableVersionInfo) */
+ size_t struct_size;
+
+ /* The API version comes from the NVTX library linked to the app. The
+ * injection library is can use this info to make some assumptions */
+ uint32_t version;
+
+ /* Reserved for alignment, do not use */
+ uint32_t reserved0;
+
+ /* This must be set by tools when attaching to provide applications
+ * the ability to, in emergency situations, detect problematic tools
+ * versions and modify the NVTX source to prevent attaching anything
+ * that causes trouble in the app. Currently, this value is ignored. */
+ void (NVTX_API *SetInjectionNvtxVersion)(
+ uint32_t version);
+} NvtxExportTableVersionInfo;
+
+
+
+
+
+
+
diff --git a/src/include/p2p.h b/src/include/p2p.h
index 9d3730e..756c8d2 100644
--- a/src/include/p2p.h
+++ b/src/include/p2p.h
@@ -10,23 +10,34 @@
#define NCCL_P2P_H_
struct ncclP2Pinfo {
- const void* sendbuff;
- void* recvbuff;
- ssize_t sendbytes;
- ssize_t recvbytes;
-};
-
-struct ncclP2PConnect {
- int nrecv[MAXCHANNELS];
- int nsend[MAXCHANNELS];
- int* recv;
- int* send;
+ void* buff;
+ ssize_t nbytes;
+ struct ncclP2Pinfo* next;
};
struct ncclP2Plist {
- struct ncclP2Pinfo *peerlist;
- int count;
- struct ncclP2PConnect connect;
+ struct ncclP2Pinfo *head;
+ struct ncclP2Pinfo *tail;
};
+static ncclResult_t enqueueP2pInfo(ncclP2Plist* p2p, void* buff, ssize_t nBytes) {
+ if (p2p == NULL) return ncclInternalError;
+ struct ncclP2Pinfo* next;
+ NCCLCHECK(ncclCalloc(&next, 1));
+ next->buff = buff;
+ next->nbytes = nBytes;
+ if (p2p->tail != NULL) p2p->tail->next = next;
+ p2p->tail = next;
+ if (p2p->head == NULL) p2p->head = next;
+ return ncclSuccess;
+}
+
+static ncclResult_t dequeueP2pInfo(ncclP2Plist* p2p) {
+ if (p2p == NULL) return ncclInternalError;
+ struct ncclP2Pinfo* temp = p2p->head;
+ p2p->head = p2p->head->next;
+ if (p2p->tail == temp) p2p->tail = NULL;
+ free(temp);
+ return ncclSuccess;
+}
#endif
diff --git a/src/include/param.h b/src/include/param.h
index 5431757..e4c11df 100644
--- a/src/include/param.h
+++ b/src/include/param.h
@@ -31,10 +31,11 @@ static void setEnvFile(const char* fileName) {
int s=0; // Env Var Size
while (line[s] != '\0' && line[s] != '=') s++;
if (line[s] == '\0') continue;
- strncpy(envVar, line, std::min(1024,s));
+ strncpy(envVar, line, std::min(1023,s));
envVar[s] = '\0';
s++;
- strncpy(envValue, line+s, 1024);
+ strncpy(envValue, line+s, 1023);
+ envValue[1023]='\0';
setenv(envVar, envValue, 0);
}
if (line) free(line);
diff --git a/src/include/proxy.h b/src/include/proxy.h
index 04daa84..9796baf 100644
--- a/src/include/proxy.h
+++ b/src/include/proxy.h
@@ -18,18 +18,23 @@ struct ncclProxyArgs {
proxyProgressFunc_t progress;
struct ncclChannel* channel;
struct ncclConnector* connector;
+ size_t sendbytes;
+ size_t recvbytes;
int sliceSteps;
int chunkSteps;
int nsteps;
uint64_t opCount;
int protocol;
+ int segment; // Only for profiling
ncclDataType_t dtype;
ncclRedOp_t redOp;
int state; // add component before this line -- it is left out during initialization
// Internal state
- uint64_t head;
- uint64_t tail;
+ uint64_t posted;
+ uint64_t received; // Only used by recv proxy to wait for flush.
+ uint64_t transmitted;
+ uint64_t done;
uint64_t end;
void* requests[NCCL_STEPS];
int idle;
@@ -38,14 +43,30 @@ struct ncclProxyArgs {
pthread_mutex_t mutex;
struct ncclProxyArgs* next;
struct ncclProxyArgs* nextPeer;
+ struct ncclProxyArgs* nextGroup;
+ struct ncclProxyArgs** proxyAppendPtr;
+};
+
+struct ncclProxySharedBuffers {
+ int nslots;
+ int slotSize;
+ char* cudaBuff[2*MAXCHANNELS];
+ int* cudaUsed[2*MAXCHANNELS];
+ char* hostBuff[2*MAXCHANNELS];
+ int* hostUsed[2*MAXCHANNELS];
+ struct ncclProxyArgs* proxyAppend[2*MAXCHANNELS]; // Separate send and recv
};
struct ncclProxyPool;
struct ncclProxyState {
pthread_cond_t cond;
- pthread_mutex_t mutex;
+ pthread_mutex_t opsMutex;
+ pthread_mutex_t poolMutex;
bool stop;
+ struct ncclProxySharedBuffers* sharedBuffs;
struct ncclProxyArgs* ops;
+ struct ncclProxyArgs* nextOps;
+ struct ncclProxyArgs* nextOpsEnd;
struct ncclProxyArgs* pool;
struct ncclProxyPool* pools;
};
@@ -59,11 +80,16 @@ enum proxyMode {
};
ncclResult_t ncclProxySaveColl(struct ncclProxyArgs* args, int pattern, int root, int nranks);
-ncclResult_t ncclProxySaveP2p(struct ncclInfo* info, struct ncclChannel* channel);
+ncclResult_t ncclProxySaveP2p(struct ncclInfo* info, struct ncclChannel* channel, int segment);
ncclResult_t ncclProxyStart(struct ncclComm* comm);
ncclResult_t ncclProxyCreate(struct ncclComm* comm);
ncclResult_t ncclProxyDestroy(struct ncclComm* comm);
+ncclResult_t ncclProxySharedBuffersInit(struct ncclComm* comm, int cuda, int* size, char** ptr);
+ncclResult_t ncclProxySharedBuffersAlloc(struct ncclComm* comm, int cuda, int type, int channel, int size, char** ptr);
+ncclResult_t ncclProxySharedBuffersFree(struct ncclComm* comm, int cuda, int type, int channel, int size, char* ptr);
+ncclResult_t ncclProxySharedBuffersDestroy(struct ncclComm* comm);
+
#include <unistd.h>
// Spin wait until func evaluates to true
diff --git a/src/include/socket.h b/src/include/socket.h
index 46b204d..e903b04 100644
--- a/src/include/socket.h
+++ b/src/include/socket.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -21,6 +21,7 @@
#define SLEEP_INT 1000 // connection retry sleep interval in usec
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
+#define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV)
/* Common socket address storage structure for IPv4/IPv6 */
union socketAddress {
@@ -64,7 +65,7 @@ static inline int envSocketFamily(void) {
static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
#endif
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
@@ -167,9 +168,9 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) {
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
#endif
- char line_a[1024];
+ char line_a[SOCKET_NAME_MAXLEN+1];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
@@ -350,7 +351,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
SYSCHECK(getsockname(sockfd, &localAddr->sa, &size), "getsockname");
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
#endif
@@ -365,6 +366,10 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
/* IPv4/IPv6 support */
int family = remoteAddr->sa.sa_family;
+ if (family != AF_INET && family != AF_INET6) {
+ WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)\n", family, AF_INET, AF_INET6);
+ return ncclInternalError;
+ }
int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
/* Connect to a hostname / port */
@@ -381,10 +386,8 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
- char line[1024];
-#ifdef ENABLE_TRACE
+ char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
-#endif
int ret;
int timedout_retries = 0;
@@ -445,7 +448,7 @@ static ncclResult_t socketSend(int fd, void* ptr, int size) {
return ncclSuccess;
}
-static ncclResult_t socketReceive(int fd, void* ptr, int size) {
+static ncclResult_t socketRecv(int fd, void* ptr, int size) {
int offset = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset));
return ncclSuccess;
diff --git a/src/include/transport.h b/src/include/transport.h
index 5a85688..2ecc727 100644
--- a/src/include/transport.h
+++ b/src/include/transport.h
@@ -41,8 +41,8 @@ struct ncclConnect {
};
struct ncclTransportComm {
- ncclResult_t (*setup)(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*, struct ncclConnect*, struct ncclConnector*, int channelId);
- ncclResult_t (*connect)(struct ncclConnect*, int nranks, int rank, struct ncclConnector*);
+ ncclResult_t (*setup)(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*, struct ncclConnect*, struct ncclConnector*, int channelId);
+ ncclResult_t (*connect)(struct ncclComm* comm, struct ncclConnect*, int nranks, int rank, struct ncclConnector*);
ncclResult_t (*free)(void*);
ncclResult_t (*proxy)(struct ncclProxyArgs*);
};
@@ -54,6 +54,7 @@ struct ncclTransport {
struct ncclTransportComm recv;
};
-ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend);
+ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend);
+ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph);
#endif
diff --git a/src/include/trees.h b/src/include/trees.h
index 7eadd85..ded84a6 100644
--- a/src/include/trees.h
+++ b/src/include/trees.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -7,7 +7,7 @@
#ifndef NCCL_TREES_H_
#define NCCL_TREES_H_
-ncclResult_t ncclGetBtree(int nranks, int rank, int* u0, int* d1, int* d0);
-ncclResult_t ncclGetDtree(int nranks, int rank, int* u0, int* d0_0, int* d0_1, int* u1, int* d1_0, int* d1_1);
+ncclResult_t ncclGetBtree(int nranks, int rank, int* u0, int* d1, int* d0, int* parentChildType);
+ncclResult_t ncclGetDtree(int nranks, int rank, int* u0, int* d0_0, int* d0_1, int* parentChildType0, int* u1, int* d1_0, int* d1_1, int* parentChildType1);
#endif
diff --git a/src/init.cc b/src/init.cc
index 585db4b..81831cf 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -158,9 +158,11 @@ void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) {
static ncclResult_t commFree(ncclComm_t comm) {
if (comm == NULL)
return ncclSuccess;
- free(comm->p2plist.peerlist);
- free(comm->p2plist.connect.recv);
- free(comm->p2plist.connect.send);
+ free(comm->connectSend);
+ free(comm->connectRecv);
+ free(comm->p2pSends);
+ free(comm->p2pRecvs);
+ free(comm->asyncOps);
free(comm->peerInfo);
ncclTopoFree(comm->topo);
@@ -191,7 +193,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
free(comm->intraCGMode);
free(comm->intraCC);
}
- CUDACHECK(cudaFreeHost((void *)comm->abortFlag));
+ NCCLCHECK(ncclCudaHostFree((void *)comm->abortFlag));
// Poison comm to try and catch a double free
commPoison(comm);
@@ -218,7 +220,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
struct ncclComm* comm;
NCCLCHECK(ncclCalloc(&comm, 1));
- comm->rank = comm->hostDevComm.rank =rank;
+ comm->rank = comm->hostDevComm.rank = rank;
comm->nRanks = comm->hostDevComm.nRanks = ndev;
cudaGetDevice(&comm->cudaDev);
NCCLCHECK(getBusId(comm->cudaDev, &comm->busId));
@@ -240,11 +242,19 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
comm->argsptr = &comm->args;
comm->collNetSupport = 0;
- comm->p2plist.count=0;
- NCCLCHECK(ncclCalloc(&comm->p2plist.peerlist, comm->nRanks));
- for (int r=0; r<comm->nRanks; r++) comm->p2plist.peerlist[r].sendbytes = comm->p2plist.peerlist[r].recvbytes = -1;
- NCCLCHECK(ncclCalloc(&comm->p2plist.connect.recv, MAXCHANNELS*comm->nRanks));
- NCCLCHECK(ncclCalloc(&comm->p2plist.connect.send, MAXCHANNELS*comm->nRanks));
+
+ NCCLCHECK(ncclCalloc(&comm->asyncOps, NCCL_MAX_OPS));
+ comm->asyncOpCount = 0;
+ comm->asyncTotalSize = 0;
+
+ static_assert(MAXCHANNELS <= sizeof(*comm->connectSend)*8, "comm->connectSend must have enough bits for all channels");
+ static_assert(MAXCHANNELS <= sizeof(*comm->connectRecv)*8, "comm->connectRecv must have enough bits for all channels");
+ NCCLCHECK(ncclCalloc(&comm->connectSend, comm->nRanks));
+ NCCLCHECK(ncclCalloc(&comm->connectRecv, comm->nRanks));
+
+ comm->p2pSendCount = comm->p2pRecvCount = 0;
+ NCCLCHECK(ncclCalloc(&comm->p2pSends, comm->nRanks));
+ NCCLCHECK(ncclCalloc(&comm->p2pRecvs, comm->nRanks));
// Mark channels as non initialized.
for (int c=0; c<MAXCHANNELS; c++) comm->channels[c].id = -1;
@@ -396,8 +406,8 @@ ncclResult_t ncclCommSetIntra(struct ncclComm* comm, int rank, int ranks, struct
#define DEFAULT_LL_BUFFSIZE (NCCL_LL_LINES_PER_THREAD*NCCL_LL_MAX_NTHREADS*NCCL_STEPS*sizeof(union ncclLLFifoLine))
#define DEFAULT_LL128_BUFFSIZE (NCCL_LL128_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS*NCCL_STEPS*sizeof(uint64_t))
-#define DEFAULT_BUFFSIZE (1LL << 22) /* 4MiB */
-#define DEFAULT_BUFFSIZE_ARM (1LL << 20) /* 1MiB */
+#define DEFAULT_BUFFSIZE (1 << 22) /* 4MiB */
+#define DEFAULT_BUFFSIZE_ARM (1 << 20) /* 1MiB */
NCCL_PARAM(BuffSize, "BUFFSIZE", -2);
NCCL_PARAM(LlBuffSize, "LL_BUFFSIZE", -2);
NCCL_PARAM(Ll128BuffSize, "LL128_BUFFSIZE", -2);
@@ -455,7 +465,7 @@ static int collNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGrap
// setup
struct ncclConnect myConnect;
if (isMaster && ret > 0) {
- NCCLCHECK(transportComm->setup(comm->topo, collNetGraph, myInfo, peerInfo, &myConnect, conn, channel->id));
+ NCCLCHECK(transportComm->setup(comm, collNetGraph, myInfo, peerInfo, &myConnect, conn, channel->id));
}
// prepare connect handles
ncclResult_t res;
@@ -485,7 +495,7 @@ static int collNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGrap
}
// connect
if (isMaster && ret > 0) {
- NCCLCHECKGOTO(transportComm->connect(masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
+ NCCLCHECKGOTO(transportComm->connect(comm, masterConnects, nMasters, rankInCollNet, conn), res, cleanup);
struct ncclPeer* devRoot = channel->devPeers+nranks;
struct ncclConnector* devConn = (type == 1) ? &devRoot->recv : &devRoot->send;
CUDACHECKGOTO(cudaMemcpy(devConn, conn, sizeof(struct ncclConnector), cudaMemcpyHostToDevice), res, cleanup);
@@ -543,10 +553,9 @@ NCCL_PARAM(CrossNic, "CROSS_NIC", 2);
NCCL_PARAM(GraphDumpFileRank, "GRAPH_DUMP_FILE_RANK", 0);
static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) {
- // We use 3 AllGathers
- // 1. { peerInfo, comm }
- // 2. ConnectTransport[nranks], ConnectValue[nranks]
- // 3. { nThreads, nrings, compCap, prev[MAXCHANNELS], next[MAXCHANNELS] }
+ // We use 2 AllGathers
+ // 1. { peerInfo, comm, compCap}
+ // 2. { nChannels, graphInfo, topoRanks }
int rank = comm->rank;
int nranks = comm->nRanks;
@@ -558,10 +567,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
struct {
struct ncclPeerInfo peerInfo;
struct ncclComm* comm;
+ int cudaCompCap;
} *allGather1Data;
NCCLCHECK(ncclCalloc(&allGather1Data, nranks));
allGather1Data[rank].comm = comm;
+ allGather1Data[rank].cudaCompCap = ncclCudaCompCap();
struct ncclPeerInfo* myInfo = &allGather1Data[rank].peerInfo;
NCCLCHECK(fillInfo(comm, myInfo, commHash));
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather1Data, sizeof(*allGather1Data)));
@@ -574,7 +585,42 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
return ncclInvalidUsage;
}
}
- // AllGather1 data is used again below
+
+ // Compute intra ranks and minimum CUDA Compute capabilities of intra-node GPUs and all GPUs
+ int intraRank0 = -1, intraRank = -1, intraRanks = 0;
+ int myCompCap = allGather1Data[rank].cudaCompCap;
+ int minCompCap = myCompCap, maxCompCap = myCompCap;
+ uint64_t otherHostHash;
+ int tmpNnodes = 1;
+ for (int i = 0; i < nranks; i++) {
+ if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
+ if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
+ if (intraRanks == 0) intraRank0 = i;
+ if (i == rank) intraRank = intraRanks;
+ intraRanks++;
+ }
+ } else { // Determine whether number of nodes is 2 (for use in tree pattern determination)
+ if (tmpNnodes == 1) {
+ otherHostHash = allGather1Data[i].peerInfo.hostHash;
+ tmpNnodes = 2;
+ } else if (tmpNnodes == 2 && otherHostHash != allGather1Data[i].peerInfo.hostHash) {
+ tmpNnodes = 3;
+ }
+ }
+ minCompCap = std::min(allGather1Data[i].cudaCompCap, minCompCap);
+ maxCompCap = std::max(allGather1Data[i].cudaCompCap, maxCompCap);
+ }
+ TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
+ rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
+ if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
+ WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
+ rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
+ return ncclInternalError;
+ }
+ struct ncclComm* intraRank0Comm = allGather1Data[intraRank0].comm;
+
+ free(allGather1Data);
+
// AllGather1 - end
// Topo detection / System graph creation
@@ -603,7 +649,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
struct ncclTopoGraph treeGraph;
treeGraph.id = 1;
- treeGraph.pattern = NCCL_TOPO_PATTERN_SPLIT_TREE;
+ treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
treeGraph.crossNic = ncclParamCrossNic();
treeGraph.collNet = 0;
treeGraph.minChannels = 1;
@@ -627,15 +673,16 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
// AllGather3 - begin
struct ncclGraphInfo {
+ int pattern;
int sameChannels;
float speedIntra;
float speedInter;
int typeIntra;
+ int typeInter;
};
struct {
int cudaCompCap;
- int fullCudaCompCap;
int nChannels;
struct ncclGraphInfo tree;
struct ncclGraphInfo ring;
@@ -644,29 +691,35 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
} *allGather3Data;
NCCLCHECK(ncclCalloc(&allGather3Data, nranks));
- allGather3Data[rank].cudaCompCap = ncclCudaCompCap();
allGather3Data[rank].nChannels = comm->nChannels = treeGraph.nChannels = ringGraph.nChannels =
std::min(treeGraph.nChannels, ringGraph.nChannels);
+ allGather3Data[rank].tree.pattern = treeGraph.pattern;
allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels;
allGather3Data[rank].tree.speedIntra = treeGraph.speedIntra;
allGather3Data[rank].tree.speedInter = treeGraph.speedInter;
allGather3Data[rank].tree.typeIntra = treeGraph.typeIntra;
+ allGather3Data[rank].tree.typeInter = treeGraph.typeInter;
+ allGather3Data[rank].ring.pattern = ringGraph.pattern;
allGather3Data[rank].ring.sameChannels = ringGraph.sameChannels;
allGather3Data[rank].ring.speedIntra = ringGraph.speedIntra;
allGather3Data[rank].ring.speedInter = ringGraph.speedInter;
allGather3Data[rank].ring.typeIntra = ringGraph.typeIntra;
+ allGather3Data[rank].ring.typeInter = ringGraph.typeInter;
+ allGather3Data[rank].collNet.pattern = collNetGraph.pattern;
allGather3Data[rank].collNet.sameChannels = collNetGraph.sameChannels;
allGather3Data[rank].collNet.speedIntra = collNetGraph.speedIntra;
allGather3Data[rank].collNet.speedInter = collNetGraph.speedInter;
allGather3Data[rank].collNet.typeIntra = collNetGraph.typeIntra;
+ allGather3Data[rank].collNet.typeInter = collNetGraph.typeInter;
NCCLCHECK(ncclTopoPreset(comm, &treeGraph, &ringGraph, &collNetGraph, &allGather3Data[rank].topoRanks));
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather3Data, sizeof(*allGather3Data)));
// Determine nNodes, firstRanks, ...
- int* nodesFirstRank;
+ int *nodesFirstRank, *nodesTreePatterns;
NCCLCHECK(ncclCalloc(&nodesFirstRank, nranks));
+ NCCLCHECK(ncclCalloc(&nodesTreePatterns, nranks));
for (int i=0; i<nranks; i++) {
int node = -1;
int firstRank = allGather3Data[i].topoRanks.ringRecv[0];
@@ -676,18 +729,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
if (node == -1) {
node = comm->nNodes++;
nodesFirstRank[node] = firstRank;
+ // Record tree pattern of each node as they can be different depending on sm arch
+ nodesTreePatterns[node] = allGather3Data[i].tree.pattern;
}
if (i == comm->rank) comm->node = node;
}
- // Determine the minimum CUDA Compute capability of all GPUs
- int myCompCap = allGather3Data[rank].cudaCompCap;
- int minCompCap = myCompCap, maxCompCap = myCompCap;
- for (int i = 0; i < nranks; i++) {
- minCompCap = std::min(allGather3Data[i].cudaCompCap, minCompCap);
- maxCompCap = std::max(allGather3Data[i].cudaCompCap, maxCompCap);
- }
-
int nChannelsOrig = comm->nChannels;
struct ncclTopoRanks** allTopoRanks;
NCCLCHECK(ncclCalloc(&allTopoRanks, comm->nRanks));
@@ -699,14 +746,17 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
treeGraph.speedIntra = std::min(allGather3Data[i].tree.speedIntra, treeGraph.speedIntra);
treeGraph.speedInter = std::min(allGather3Data[i].tree.speedInter, treeGraph.speedInter);
treeGraph.typeIntra = std::min(allGather3Data[i].tree.typeIntra, treeGraph.typeIntra);
+ treeGraph.typeInter = std::min(allGather3Data[i].tree.typeInter, treeGraph.typeInter);
ringGraph.sameChannels = std::min(allGather3Data[i].ring.sameChannels, ringGraph.sameChannels);
ringGraph.speedIntra = std::min(allGather3Data[i].ring.speedIntra, ringGraph.speedIntra);
ringGraph.speedInter = std::min(allGather3Data[i].ring.speedInter, ringGraph.speedInter);
ringGraph.typeIntra = std::min(allGather3Data[i].ring.typeIntra, ringGraph.typeIntra);
+ ringGraph.typeInter = std::min(allGather3Data[i].ring.typeInter, ringGraph.typeInter);
collNetGraph.sameChannels = std::min(allGather3Data[i].collNet.sameChannels, collNetGraph.sameChannels);
collNetGraph.speedIntra = std::min(allGather3Data[i].collNet.speedIntra, collNetGraph.speedIntra);
collNetGraph.speedInter = std::min(allGather3Data[i].collNet.speedInter, collNetGraph.speedInter);
collNetGraph.typeIntra = std::min(allGather3Data[i].collNet.typeIntra, collNetGraph.typeIntra);
+ collNetGraph.typeInter = std::min(allGather3Data[i].collNet.typeInter, collNetGraph.typeInter);
}
if (comm->nChannels < nChannelsOrig) {
@@ -718,7 +768,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
int *rings;
NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS));
- NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, allTopoRanks, rings));
+ NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings));
if (comm->nNodes > 1 &&
ncclParamCollNetEnable() == 1 &&
collNetSupport() && collNetGraph.nChannels) {
@@ -726,6 +776,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
}
free(allTopoRanks);
+ free(nodesTreePatterns);
free(nodesFirstRank);
free(allGather3Data);
@@ -733,16 +784,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels);
- NCCLCHECK(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph));
-
char line[1024];
line[0]='\0';
for (int c=0; c<comm->nChannels; c++) {
- struct ncclTree* treeUp = &comm->channels[c].treeUp;
- struct ncclTree* treeDn = &comm->channels[c].treeDn;
- snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d|%d->%d->%d/%d/%d",
- c, treeUp->down[0], treeUp->down[1], treeUp->down[2], rank, treeUp->up,
- treeDn->up, rank, treeDn->down[0], treeDn->down[1], treeDn->down[2]);
+ struct ncclTree* tree = &comm->channels[c].tree;
+ snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
+ c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
}
line[1023] = '\0';
INFO(NCCL_INIT, "Trees%s", line);
@@ -757,16 +804,24 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
NCCLCHECK(computeBuffSizes(comm));
// Connect with prev/next for each ring
- struct ncclConnect *connect;
- NCCLCHECKGOTO(ncclCalloc(&connect, 2), ret, affinity_restore);
for (int c=0; c<comm->nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
NCCLCHECKGOTO(setupChannel(comm, c, rank, nranks, rings+c*nranks), ret, affinity_restore);
if (comm->nRanks == 1) continue;
- NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, channel, 1, &channel->ring.prev, 1, &channel->ring.next), ret, affinity_restore);
- NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, channel, NCCL_MAX_TREE_ARITY, channel->treeUp.down, 1, &channel->treeUp.up), ret, affinity_restore);
- NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, channel, 1, &channel->treeDn.up, NCCL_MAX_TREE_ARITY, channel->treeDn.down), ret, affinity_restore);
+ NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->ring.prev, 1, &channel->ring.next), ret, affinity_restore);
}
+ NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph), ret, affinity_restore);
+ INFO(NCCL_INIT, "Connected all rings");
+
+ // Connect Trees
+ for (int c=0; c<comm->nChannels; c++) {
+ struct ncclChannel* channel = comm->channels+c;
+ if (comm->nRanks == 1) continue;
+ NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up), ret, affinity_restore);
+ NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down), ret, affinity_restore);
+ }
+ NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph), ret, affinity_restore);
+ INFO(NCCL_INIT, "Connected all trees");
// Check if we can setup CollNet
if (comm->nNodes > 1 &&
@@ -779,8 +834,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
for (int c=0; c<logicChannels; c++) {
struct ncclChannel* channelRecv = comm->channels+logicChannels+c;
struct ncclChannel* channelSend = comm->channels+c;
- NCCLCHECK(ncclTransportP2pSetup(comm, &collNetGraph, channelRecv, 1, &channelRecv->collTreeDn.up, 1, channelRecv->collTreeDn.down));
- NCCLCHECK(ncclTransportP2pSetup(comm, &collNetGraph, channelSend, 1, channelSend->collTreeUp.down, 1, &channelSend->collTreeUp.up));
+ NCCLCHECK(ncclTransportP2pConnect(comm, channelRecv, 1, &channelRecv->collTree.up, 1, channelRecv->collTree.down));
+ NCCLCHECK(ncclTransportP2pConnect(comm, channelSend, 1, channelSend->collTree.down, 1, &channelSend->collTree.up));
const int recvMaster = collNetGraph.intra[c*comm->localRanks+recvIndex];
const int sendMaster = collNetGraph.intra[c*comm->localRanks+sendIndex];
if (collNetSetup(comm, &collNetGraph, channelRecv, rank, nranks, recvMaster, sendMaster, comm->nNodes, 1) != 1)
@@ -788,39 +843,20 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
else if (collNetSetup(comm, &collNetGraph, channelSend, rank, nranks, sendMaster, recvMaster, comm->nNodes, 0) != 1)
collNetSetupFail = 1;
}
+ NCCLCHECK(ncclTransportP2pSetup(comm, &collNetGraph));
// Verify CollNet setup across ranks
NCCLCHECK(checkCollNetSetup(comm, rank, collNetSetupFail));
}
TRACE(NCCL_INIT, "rank %d nranks %d - CONNECTED %d RINGS AND TREES", rank, nranks, comm->nChannels);
- free(connect);
free(rings);
+ // Compute time models for algorithm and protocol combinations
+ NCCLCHECK(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph));
+
// Compute nChannels per peer for p2p
NCCLCHECK(ncclTopoComputeP2pChannels(comm));
- // Compute intra ranks (using AllGather1 data)
- do {
- int intraRank0 = -1, intraRank = -1, intraRanks = 0;
- for (int i = 0; i < nranks; i++) {
- if ((allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) &&
- (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash)) {
- if (intraRanks == 0) intraRank0 = i;
- if (i == rank) intraRank = intraRanks;
- intraRanks++;
- }
- }
- TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
- rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
- if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) {
- WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d",
- rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0);
- return ncclInternalError;
- }
- NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, allGather1Data[intraRank0].comm));
- } while(0);
-
- // Done with AllGather1 data
- free(allGather1Data);
+ NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, intraRank0Comm));
if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm));
@@ -884,6 +920,7 @@ end:
NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank);
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
int cudaDev;
CUDACHECK(cudaGetDevice(&cudaDev));
NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev));
@@ -892,6 +929,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId comm
NCCL_API(ncclResult_t, ncclCommInitAll, ncclComm_t* comms, int ndev, const int* devlist);
ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
NCCLCHECK(PtrCheck(comms, "CommInitAll", "comms"));
if (ndev < 0) {
WARN("Invalid device count requested : %d", ndev);
@@ -911,9 +949,6 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
static ncclResult_t commDestroy(ncclComm_t comm) {
int savedDevice;
-#ifdef ENABLE_TRACE
- int rank = comm->rank;
-#endif
CUDACHECK(cudaGetDevice(&savedDevice));
int commDevice = comm->cudaDev;
@@ -921,7 +956,7 @@ static ncclResult_t commDestroy(ncclComm_t comm) {
CUDACHECK(cudaSetDevice(commDevice));
}
- TRACE(NCCL_INIT, "Destroying comm %p rank %d abortFlag %d fatalError %d", comm, rank, *comm->abortFlag, comm->fatalError);
+ TRACE(NCCL_INIT, "Destroying comm %p rank %d abortFlag %d fatalError %d", comm, comm->rank, *comm->abortFlag, comm->fatalError);
CUDACHECK(cudaStreamSynchronize(comm->groupStream));
NCCLCHECK(ncclProxyDestroy(comm));
@@ -930,13 +965,14 @@ static ncclResult_t commDestroy(ncclComm_t comm) {
if (savedDevice != commDevice)
CUDACHECK(cudaSetDevice(savedDevice));
- TRACE(NCCL_INIT, "Destroyed comm %p rank %d", comm, rank);
+ TRACE(NCCL_INIT, "Destroyed comm %p rank %d", comm, comm->rank);
return ncclSuccess;
}
NCCL_API(ncclResult_t, ncclCommDestroy, ncclComm_t comm);
ncclResult_t ncclCommDestroy(ncclComm_t comm) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
if (comm == NULL)
return ncclSuccess;
@@ -953,6 +989,7 @@ ncclResult_t ncclCommDestroy(ncclComm_t comm) {
NCCL_API(ncclResult_t, ncclCommAbort, ncclComm_t comm);
ncclResult_t ncclCommAbort(ncclComm_t comm) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
if (comm == NULL)
return ncclSuccess;
@@ -985,6 +1022,7 @@ ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) {
NCCL_API(ncclResult_t, ncclCommCount, const ncclComm_t comm, int* count);
ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
NCCLCHECK(PtrCheck(comm, "CommCount", "comm"));
NCCLCHECK(PtrCheck(count, "CommCount", "count"));
*count = comm->nRanks;
@@ -993,6 +1031,7 @@ ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
NCCL_API(ncclResult_t, ncclCommCuDevice, const ncclComm_t comm, int* devid);
ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
NCCLCHECK(PtrCheck(comm, "CommCuDevice", "comm"));
NCCLCHECK(PtrCheck(devid, "CommCuDevice", "devid"));
*devid = comm->cudaDev;
@@ -1001,6 +1040,7 @@ ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* devid) {
NCCL_API(ncclResult_t, ncclCommUserRank, const ncclComm_t comm, int* rank);
ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
+ NVTX3_FUNC_RANGE_IN(nccl_domain);
NCCLCHECK(PtrCheck(comm, "CommUserRank", "comm"));
NCCLCHECK(PtrCheck(rank, "CommUserRank", "rank"));
*rank = comm->rank;
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc
index 27623b2..c262f8c 100644
--- a/src/misc/argcheck.cc
+++ b/src/misc/argcheck.cc
@@ -45,11 +45,11 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
}
// Type is OK, compute nbytes. Convert Allgather/Broadcast/P2P calls to chars.
info->nBytes = info->count * ncclTypeSize(info->datatype);
- if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) {
+ if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) {
info->count = info->nBytes;
info->datatype = ncclInt8;
}
- if (info->coll == ncclCollAllGather || info->coll == ncclCollReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank
+ if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank
if (info->op < 0 || info->op >= ncclNumOps) {
WARN("%s : invalid reduction operation %d", info->opName, info->op);
@@ -57,7 +57,7 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
}
if (info->comm->checkPointers) {
- if (info->coll == ncclCollSendRecv) {
+ if (info->coll == ncclFuncSendRecv) {
if (strcmp(info->opName, "Send") == 0) {
NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", "Send"));
} else {
@@ -65,10 +65,10 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
}
} else {
// Check CUDA device pointers
- if (info->coll != ncclCollBroadcast || info->comm->rank == info->root) {
+ if (info->coll != ncclFuncBroadcast || info->comm->rank == info->root) {
NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName));
}
- if (info->coll != ncclCollReduce || info->comm->rank == info->root) {
+ if (info->coll != ncclFuncReduce || info->comm->rank == info->root) {
NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName));
}
}
diff --git a/src/misc/nvmlwrap.cc b/src/misc/nvmlwrap.cc
index 34ed0aa..e83392d 100644
--- a/src/misc/nvmlwrap.cc
+++ b/src/misc/nvmlwrap.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -16,14 +16,11 @@ static nvmlReturn_t (*nvmlInternalInit)(void);
static nvmlReturn_t (*nvmlInternalShutdown)(void);
static nvmlReturn_t (*nvmlInternalDeviceGetHandleByPciBusId)(const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t (*nvmlInternalDeviceGetIndex)(nvmlDevice_t device, unsigned* index);
-static nvmlReturn_t (*nvmlInternalDeviceGetHandleByIndex)(unsigned int index, nvmlDevice_t* device);
static const char* (*nvmlInternalErrorString)(nvmlReturn_t r);
static nvmlReturn_t (*nvmlInternalDeviceGetNvLinkState)(nvmlDevice_t device, unsigned int link, nvmlEnableState_t *isActive);
-static nvmlReturn_t (*nvmlInternalDeviceGetPciInfo)(nvmlDevice_t device, nvmlPciInfo_t* pci);
static nvmlReturn_t (*nvmlInternalDeviceGetNvLinkRemotePciInfo)(nvmlDevice_t device, unsigned int link, nvmlPciInfo_t *pci);
static nvmlReturn_t (*nvmlInternalDeviceGetNvLinkCapability)(nvmlDevice_t device, unsigned int link,
nvmlNvLinkCapability_t capability, unsigned int *capResult);
-static nvmlReturn_t (*nvmlInternalDeviceGetMinorNumber)(nvmlDevice_t device, unsigned int* minorNumber);
static nvmlReturn_t (*nvmlInternalDeviceGetCudaComputeCapability)(nvmlDevice_t device, int* major, int* minor);
// Used to make the NVML library calls thread safe
@@ -74,10 +71,7 @@ ncclResult_t wrapNvmlSymbols(void) {
LOAD_SYM(nvmlhandle, "nvmlShutdown", nvmlInternalShutdown);
LOAD_SYM(nvmlhandle, "nvmlDeviceGetHandleByPciBusId", nvmlInternalDeviceGetHandleByPciBusId);
LOAD_SYM(nvmlhandle, "nvmlDeviceGetIndex", nvmlInternalDeviceGetIndex);
- LOAD_SYM(nvmlhandle, "nvmlDeviceGetHandleByIndex", nvmlInternalDeviceGetHandleByIndex);
LOAD_SYM(nvmlhandle, "nvmlErrorString", nvmlInternalErrorString);
- LOAD_SYM(nvmlhandle, "nvmlDeviceGetPciInfo", nvmlInternalDeviceGetPciInfo);
- LOAD_SYM(nvmlhandle, "nvmlDeviceGetMinorNumber", nvmlInternalDeviceGetMinorNumber);
LOAD_SYM_OPTIONAL(nvmlhandle, "nvmlDeviceGetNvLinkState", nvmlInternalDeviceGetNvLinkState);
LOAD_SYM_OPTIONAL(nvmlhandle, "nvmlDeviceGetNvLinkRemotePciInfo", nvmlInternalDeviceGetNvLinkRemotePciInfo);
LOAD_SYM_OPTIONAL(nvmlhandle, "nvmlDeviceGetNvLinkCapability", nvmlInternalDeviceGetNvLinkCapability);
@@ -91,9 +85,6 @@ teardown:
nvmlInternalShutdown = NULL;
nvmlInternalDeviceGetHandleByPciBusId = NULL;
nvmlInternalDeviceGetIndex = NULL;
- nvmlInternalDeviceGetHandleByIndex = NULL;
- nvmlInternalDeviceGetPciInfo = NULL;
- nvmlInternalDeviceGetMinorNumber = NULL;
nvmlInternalDeviceGetNvLinkState = NULL;
nvmlInternalDeviceGetNvLinkRemotePciInfo = NULL;
nvmlInternalDeviceGetNvLinkCapability = NULL;
@@ -162,51 +153,6 @@ ncclResult_t wrapNvmlDeviceGetIndex(nvmlDevice_t device, unsigned* index) {
return ncclSuccess;
}
-ncclResult_t wrapNvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t* device) {
- if (nvmlInternalDeviceGetHandleByIndex == NULL) {
- WARN("lib wrapper not initialized.");
- return ncclInternalError;
- }
- nvmlReturn_t ret;
- NVMLLOCKCALL(nvmlInternalDeviceGetHandleByIndex(index, device), ret);
- if (ret != NVML_SUCCESS) {
- WARN("nvmlDeviceGetHandleByIndex() failed: %s ",
- nvmlInternalErrorString(ret));
- return ncclSystemError;
- }
- return ncclSuccess;
-}
-
-ncclResult_t wrapNvmlDeviceGetPciInfo(nvmlDevice_t device, nvmlPciInfo_t* pci) {
- if (nvmlInternalDeviceGetPciInfo == NULL) {
- WARN("lib wrapper not initialized.");
- return ncclInternalError;
- }
- nvmlReturn_t ret;
- NVMLLOCKCALL(nvmlInternalDeviceGetPciInfo(device, pci), ret);
- if (ret != NVML_SUCCESS) {
- WARN("nvmlDeviceGetPciInfo() failed: %s ",
- nvmlInternalErrorString(ret));
- return ncclSystemError;
- }
- return ncclSuccess;
-}
-
-ncclResult_t wrapNvmlDeviceGetMinorNumber(nvmlDevice_t device, unsigned int* minorNumber) {
- if (nvmlInternalDeviceGetMinorNumber == NULL) {
- WARN("lib wrapper not initialized.");
- return ncclInternalError;
- }
- nvmlReturn_t ret;
- NVMLLOCKCALL(nvmlInternalDeviceGetMinorNumber(device, minorNumber), ret);
- if (ret != NVML_SUCCESS) {
- WARN("nvmlDeviceGetMinorNumber() failed: %s ",
- nvmlInternalErrorString(ret));
- return ncclSystemError;
- }
- return ncclSuccess;
-}
-
ncclResult_t wrapNvmlDeviceGetNvLinkState(nvmlDevice_t device, unsigned int link, nvmlEnableState_t *isActive) {
if (nvmlInternalDeviceGetNvLinkState == NULL) {
/* Do not warn, this symbol is optional. */
diff --git a/src/proxy.cc b/src/proxy.cc
index 19dbced..d3824f2 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -6,10 +6,10 @@
#include "comm.h"
#include "info.h"
+#include "graph.h"
#include "collectives.h"
-#define RECV 0
-#define SEND 1
+enum { proxyRecv=0, proxySend=1 };
static bool NeedProxy(int type, int pattern, int root, struct ncclRing* ring, int nranks) {
if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice) return true;
@@ -19,15 +19,13 @@ static bool NeedProxy(int type, int pattern, int root, struct ncclRing* ring, in
const int myrank = 0, nextrank = 1, prevrank = nranks-1;
int index = pattern == ncclPatternPipelineFrom ?
/* no recv / no send if root = */
- /* bcast */ (type == RECV ? myrank : nextrank ):
- /* reduce */ (type == RECV ? prevrank : myrank );
+ /* bcast */ (type == proxyRecv ? myrank : nextrank ):
+ /* reduce */ (type == proxyRecv ? prevrank : myrank );
int rank = ring->userRanks[index];
return (root != rank);
}
-enum { proxyRecv=0, proxySend=1 };
-
-#define PROXYARGS_ALLOCATE_SIZE 32
+#define PROXYARGS_ALLOCATE_SIZE 128
struct ncclProxyPool {
struct ncclProxyPool *next;
struct ncclProxyArgs elems[PROXYARGS_ALLOCATE_SIZE];
@@ -36,7 +34,7 @@ struct ncclProxyPool {
static ncclResult_t allocateArgs(struct ncclComm* comm, struct ncclProxyArgs** argsptr) {
struct ncclProxyState* state = &comm->proxyState;
struct ncclProxyArgs* elem;
- pthread_mutex_lock(&state->mutex);
+ pthread_mutex_lock(&state->poolMutex);
if (state->pool == NULL) {
// Allocate a new pool of elements
struct ncclProxyPool* newPool;
@@ -54,39 +52,113 @@ static ncclResult_t allocateArgs(struct ncclComm* comm, struct ncclProxyArgs** a
}
elem = state->pool;
state->pool = state->pool->next;
- pthread_mutex_unlock(&state->mutex);
- elem->next = elem->nextPeer = NULL;
+ pthread_mutex_unlock(&state->poolMutex);
+ elem->next = elem->nextPeer = elem->nextGroup = NULL;
*argsptr = elem;
return ncclSuccess;
}
-static void ProxyAppend(struct ncclConnector* connector, struct ncclProxyArgs* args) {
- struct ncclComm* comm = connector->comm;
- struct ncclProxyState* state = &comm->proxyState;
- pthread_mutex_lock(&state->mutex);
- if (connector->proxyAppend == NULL) {
- // Nothing running for that peer. Add to the circular list
+//#define DEBUG_PROXY 1
+#ifdef DEBUG_PROXY
+#define DEBUG_PROXY_PRINT printf
+#else
+#define DEBUG_PROXY_PRINT(...)
+#endif
+
+#define OP_INDEX(op) ((op) ? (op)-state->pools->elems : -1)
+#define OP_SEEN 0x100000
+ncclResult_t dumpProxyState(struct ncclProxyState* state) {
+#ifdef DEBUG_PROXY
+ struct ncclProxyArgs* op = state->ops;
+ while (op) {
+ if (op->idle & OP_SEEN) {
+ WARN("Active list loop at element %ld\n", OP_INDEX(op));
+ }
+ op->idle |= OP_SEEN;
+ printf("[%ld]", OP_INDEX(op));
+ if (op->nextPeer) {
+ printf("(%ld)", OP_INDEX(op->nextPeer));
+ struct ncclProxyArgs* n = op->nextPeer;
+ n->idle |= OP_SEEN;
+ while (n->nextGroup || n->nextPeer) {
+ n = n->nextGroup ? n->nextGroup : n->nextPeer;
+ n->idle |= OP_SEEN;
+ }
+ }
+ if (op->nextGroup) {
+ printf("--G->");
+ op = op->nextGroup;
+ } else {
+ printf("--N->");
+ op = op->next;
+ }
+ }
+ printf("[X]\n");
+
+ struct ncclProxyArgs* free = state->pool;
+ while (free) {
+ if (free->idle & OP_SEEN) {
+ WARN("Free list loop at element %ld\n", OP_INDEX(free));
+ }
+ free->idle |= OP_SEEN;
+ free = free->next;
+ }
+
+ struct ncclProxyPool* p = state->pools;
+ int i = 0;
+ while (p) {
+ for (int e=0; e<PROXYARGS_ALLOCATE_SIZE; e++) {
+ if ((p->elems[e].idle & OP_SEEN) == 0) {
+ WARN("Element %d of pool %d has been lost\n", e, i);
+ struct ncclProxyArgs* free = state->pool;
+ printf("Free list ");
+ while (free) {
+ printf("--> %ld ", OP_INDEX(free));
+ free = free->next;
+ }
+ printf("\n");
+ return ncclInternalError;
+ }
+ p->elems[e].idle -= OP_SEEN;
+ }
+ p = p->next;
+ i++;
+ }
+#endif
+ return ncclSuccess;
+}
+
+static ncclResult_t ProxyAppend(struct ncclProxyState* state, struct ncclProxyArgs* args, int shared) {
+ struct ncclProxyArgs* proxyAppend = *args->proxyAppendPtr;
+ if (proxyAppend) {
+ if (shared && proxyAppend->opCount == args->opCount) {
+ args->next = proxyAppend->next;
+ proxyAppend->next = NULL;
+ proxyAppend->nextGroup = args;
+ DEBUG_PROXY_PRINT("Insert %5ld (%d/%5ld/%5ld) as group, prevGroup %5ld, next %5ld : \n", OP_INDEX(args), shared, proxyAppend->opCount, args->opCount, OP_INDEX(proxyAppend), OP_INDEX(args->next));
+ } else {
+ proxyAppend->nextPeer = args;
+ DEBUG_PROXY_PRINT("Insert %5ld (%d/%5ld/%5ld) as nextPeer of %5ld : \n", OP_INDEX(args), shared, proxyAppend->opCount, args->opCount, OP_INDEX(proxyAppend));
+ }
+ } else {
+ // Nothing running for that peer. Add to the list
if (state->ops == NULL) {
// Create the list
- args->next = args;
+ DEBUG_PROXY_PRINT("Insert %5ld (%d/%5ld) as first element : \n", OP_INDEX(args), shared, args->opCount);
state->ops = args;
} else {
- // Insert element in the list
- args->next = state->ops->next;
- state->ops->next = args;
+ // Append element at the end of the list
+ struct ncclProxyArgs* last = state->ops;
+ while (last->nextGroup || last->next) last = last->nextGroup ? last->nextGroup : last->next;
+ last->next = args;
+ DEBUG_PROXY_PRINT("Insert %5ld (%d/%5ld) as last element : \n", OP_INDEX(args),shared, args->opCount);
}
- connector->proxyAppend = args;
- } else {
- // There is an active operation already for that peer.
- // Add it to the per-peer list
- connector->proxyAppend->nextPeer = args;
- connector->proxyAppend = args;
}
- pthread_mutex_unlock(&state->mutex);
+ *(args->proxyAppendPtr) = args;
+ return ncclSuccess;
}
-template <int type>
-static ncclResult_t SaveProxy(int peer, struct ncclProxyArgs* args) {
+static ncclResult_t SaveProxy(int type, int peer, struct ncclProxyArgs* args) {
if (peer < 0) return ncclSuccess;
struct ncclPeer* peerComm = args->channel->peers+peer;
@@ -98,69 +170,169 @@ static ncclResult_t SaveProxy(int peer, struct ncclProxyArgs* args) {
}
if (connector->transportComm->proxy == NULL) return ncclSuccess;
+ struct ncclProxyState* state = &connector->comm->proxyState;
struct ncclProxyArgs* op;
NCCLCHECK(allocateArgs(connector->comm, &op));
memcpy(op, args, sizeof(struct ncclProxyArgs));
op->connector = connector;
op->progress = connector->transportComm->proxy;
op->state = ncclProxyOpReady;
- ProxyAppend(connector, op);
+
+ op->proxyAppendPtr =
+ connector->conn.shared ?
+ state->sharedBuffs->proxyAppend+2*args->channel->id+type : // Shared buffers
+ &connector->proxyAppend; // Dedicated buffers
+
+ if (state->nextOps == NULL) state->nextOps = op;
+ else state->nextOpsEnd->next = op;
+ state->nextOpsEnd = op;
return ncclSuccess;
}
ncclResult_t ncclProxySaveColl(struct ncclProxyArgs* args, int pattern, int root, int nranks) {
if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice || pattern == ncclPatternPipelineFrom || pattern == ncclPatternPipelineTo) {
struct ncclRing* ring = &args->channel->ring;
- if (NeedProxy(RECV, pattern, root, ring, nranks)) NCCLCHECK(SaveProxy<proxyRecv>(ring->prev, args));
- if (NeedProxy(SEND, pattern, root, ring, nranks)) NCCLCHECK(SaveProxy<proxySend>(ring->next, args));
+ if (NeedProxy(proxyRecv, pattern, root, ring, nranks)) NCCLCHECK(SaveProxy(proxyRecv, ring->prev, args));
+ if (NeedProxy(proxySend, pattern, root, ring, nranks)) NCCLCHECK(SaveProxy(proxySend, ring->next, args));
}
if (pattern == ncclPatternTreeUp || pattern == ncclPatternTreeUpDown) {
// Tree up
- struct ncclTree* tree = &args->channel->treeUp;
- for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxyRecv>(tree->down[i], args));
- NCCLCHECK(SaveProxy<proxySend>(tree->up, args));
+ struct ncclTree* tree = &args->channel->tree;
+ for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy(proxyRecv, tree->down[i], args));
+ NCCLCHECK(SaveProxy(proxySend, tree->up, args));
}
if (pattern == ncclPatternTreeDown || pattern == ncclPatternTreeUpDown) {
// Tree down
- struct ncclTree* tree = &args->channel->treeDn;
- for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxySend>(tree->down[i], args));
- NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args));
+ struct ncclTree* tree = &args->channel->tree;
+ for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy(proxySend, tree->down[i], args));
+ NCCLCHECK(SaveProxy(proxyRecv, tree->up, args));
}
if (pattern == ncclPatternCollTreeUp) {
// CollTree up
- struct ncclTree* tree = &args->channel->collTreeUp;
- NCCLCHECK(SaveProxy<proxyRecv>(tree->down[0], args));
- NCCLCHECK(SaveProxy<proxySend>(tree->up, args));
+ struct ncclTree* tree = &args->channel->collTree;
+ NCCLCHECK(SaveProxy(proxyRecv, tree->down[0], args));
+ NCCLCHECK(SaveProxy(proxySend, tree->up, args));
}
if (pattern == ncclPatternCollTreeDown) {
// CollTree down
- struct ncclTree* tree = &args->channel->collTreeDn;
- NCCLCHECK(SaveProxy<proxySend>(tree->down[0], args));
- NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args));
+ struct ncclTree* tree = &args->channel->collTree;
+ NCCLCHECK(SaveProxy(proxySend, tree->down[0], args));
+ NCCLCHECK(SaveProxy(proxyRecv, tree->up, args));
}
return ncclSuccess;
}
-ncclResult_t ncclProxySaveP2p(struct ncclInfo* info, struct ncclChannel* channel) {
+ncclResult_t ncclProxySaveP2p(struct ncclInfo* info, struct ncclChannel* channel, int segment) {
struct ncclProxyArgs args;
memset(&args, 0, sizeof(struct ncclProxyArgs));
args.channel = channel;
args.sliceSteps = 1;
args.chunkSteps = 1;
args.protocol = NCCL_PROTO_SIMPLE;
- args.opCount = info->comm->opCount;
+ args.segment = segment;
+ args.opCount = channel->workFifoTail-1;
args.dtype = info->datatype;
+ if (info->delta > 0 && info->recvbytes >= 0) {
+ int peerrecv = (info->comm->nRanks+info->comm->rank-info->delta)%info->comm->nRanks;
+ args.nsteps = DIVUP(info->recvbytes, info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/SENDRECV_SLICEFACTOR);
+ if (args.nsteps == 0) args.nsteps = 1;
+ args.recvbytes = info->recvbytes;
+ args.sendbytes = 0;
+ NCCLCHECK(SaveProxy(proxyRecv, peerrecv, &args));
+ }
if (info->delta > 0 && info->sendbytes >= 0) {
int peersend = (info->comm->rank+info->delta)%info->comm->nRanks;
args.nsteps = DIVUP(info->sendbytes, info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/SENDRECV_SLICEFACTOR);
if (args.nsteps == 0) args.nsteps = 1;
- NCCLCHECK(SaveProxy<proxySend>(peersend, &args));
+ args.sendbytes = info->sendbytes;
+ args.recvbytes = 0;
+ NCCLCHECK(SaveProxy(proxySend, peersend, &args));
}
- if (info->delta > 0 && info->recvbytes >= 0) {
- int peerrecv = (info->comm->nRanks+info->comm->rank-info->delta)%info->comm->nRanks;
- args.nsteps = DIVUP(info->recvbytes, info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/SENDRECV_SLICEFACTOR);
- if (args.nsteps == 0) args.nsteps = 1;
- NCCLCHECK(SaveProxy<proxyRecv>(peerrecv, &args));
+ return ncclSuccess;
+}
+
+static ncclResult_t removeOp(struct ncclProxyState* state, struct ncclProxyArgs** opPtr, struct ncclProxyArgs** prevOpPtr, struct ncclProxyArgs** prevGroupPtr) {
+ struct ncclProxyArgs* freeOp = *opPtr;
+ DEBUG_PROXY_PRINT("Remove %ld/%ld -> %ld -> %ld/%ld\n", OP_INDEX(*prevOpPtr), OP_INDEX(*prevGroupPtr), OP_INDEX(freeOp), OP_INDEX(freeOp->next), OP_INDEX(freeOp->nextGroup));
+ if (*prevGroupPtr && *prevOpPtr) return ncclInternalError;
+ if (freeOp->nextGroup) {
+ // Part of a group : remove the element
+ struct ncclProxyArgs* next = freeOp->nextGroup;
+ *opPtr = next;
+ if (*prevGroupPtr) {
+ (*prevGroupPtr)->nextGroup = next;
+ } else if (*prevOpPtr) {
+ (*prevOpPtr)->next = next;
+ } else {
+ state->ops = next;
+ }
+ } else {
+ struct ncclProxyArgs* next = freeOp->next;
+ *opPtr = next;
+ if ((*prevGroupPtr)) {
+ (*prevGroupPtr)->next = next;
+ (*prevGroupPtr)->nextGroup = NULL;
+ (*prevGroupPtr)->nextPeer = freeOp->nextPeer;
+ if (*(freeOp->proxyAppendPtr) == freeOp) *(freeOp->proxyAppendPtr) = *prevGroupPtr;
+ (*prevOpPtr) = *prevGroupPtr;
+ (*prevGroupPtr) = NULL;
+ } else {
+ if (freeOp->nextPeer) {
+ // replace op by nextPeer
+ struct ncclProxyArgs* nextPeer = freeOp->nextPeer;
+ if (*prevOpPtr) {
+ (*prevOpPtr)->next = nextPeer;
+ } else {
+ state->ops = nextPeer;
+ }
+ struct ncclProxyArgs* lastGroup = nextPeer;
+ while (lastGroup->nextGroup) lastGroup = lastGroup->nextGroup;
+ lastGroup->next = next;
+ *(prevOpPtr) = lastGroup;
+ } else {
+ *(freeOp->proxyAppendPtr) = NULL;
+ if (*prevOpPtr) {
+ (*prevOpPtr)->next = next;
+ } else {
+ state->ops = next;
+ }
+ }
+ }
+ }
+ pthread_mutex_lock(&state->poolMutex);
+ freeOp->next = state->pool;
+ state->pool = freeOp;
+ pthread_mutex_unlock(&state->poolMutex);
+ DEBUG_PROXY_PRINT("Removed %5ld (%5ld) : ", OP_INDEX(freeOp), OP_INDEX(*freeOp->proxyAppendPtr));
+ NCCLCHECK(dumpProxyState(state));
+ return ncclSuccess;
+}
+
+static ncclResult_t progressOps(struct ncclProxyState* state, struct ncclProxyArgs** opsPtr, int* idle, struct ncclComm* comm) {
+ struct ncclProxyArgs* prevOp = NULL;
+ struct ncclProxyArgs* prevGroup = NULL;
+ struct ncclProxyArgs* op = *opsPtr;
+ while (op) {
+ if (op->state == ncclProxyOpNone) return ncclInternalError;
+ // opCount >= lastOpCount are part of an ongoing GroupStart/GroupEnd that hasn't started
+ // yet and might be cancelled before they even start. Hold on on those.
+ if (op->opCount < comm->lastOpCount) {
+ NCCLCHECK(op->progress(op));
+ *idle &= op->idle;
+ }
+ if (op->state == ncclProxyOpNone) {
+ NCCLCHECK(removeOp(state, &op, &prevOp, &prevGroup));
+ } else {
+ if (op->nextGroup) {
+ prevGroup = op;
+ prevOp = NULL;
+ op = op->nextGroup;
+ } else {
+ prevOp = op;
+ prevGroup = NULL;
+ op = op->next;
+ }
+ }
}
return ncclSuccess;
}
@@ -168,91 +340,170 @@ ncclResult_t ncclProxySaveP2p(struct ncclInfo* info, struct ncclChannel* channel
void* persistentThread(void *comm_) {
struct ncclComm* comm = (struct ncclComm*)comm_;
struct ncclProxyState* state = &comm->proxyState;
- struct ncclProxyArgs* op = NULL;
- ncclResult_t ret = ncclSuccess;
- int idle = 1;
- int idleSpin = 0;
+ char threadName[16];
+ sprintf(threadName, "NCCLproxy %5d", comm->rank);
+ nvtxNameOsThreadA(syscall(SYS_gettid), threadName);
+
+ pthread_mutex_lock(&state->opsMutex);
+ struct ncclProxyArgs** opsPtr = &state->ops;
while (1) {
- do {
- if (*comm->abortFlag) return NULL;
- if (op == NULL) {
- pthread_mutex_lock(&state->mutex);
- op = state->ops;
- if (op == NULL) {
- if (state->stop) {
- // No more commands to process and proxy has been requested to stop
- pthread_mutex_unlock(&state->mutex);
- return NULL;
- }
- pthread_cond_wait(&state->cond, &state->mutex);
- }
- pthread_mutex_unlock(&state->mutex);
+ if (*comm->abortFlag) {
+ pthread_mutex_unlock(&state->opsMutex);
+ return NULL;
+ }
+
+ while (*opsPtr == NULL) {
+ if (state->stop) {
+ // No more commands to process and proxy has been requested to stop
+ pthread_mutex_unlock(&state->opsMutex);
+ return NULL;
}
- } while (op == NULL);
- op->idle = 0;
- // opCount >= lastOpCount are part of an ongoing GroupStart/GroupEnd that hasn't started
- // yet and might be cancelled before they even start. Hold on on those.
- if (op->state != ncclProxyOpNone && op->opCount < comm->lastOpCount) ret = op->progress(op);
+ pthread_cond_wait(&state->cond, &state->opsMutex);
+ }
+ int idle = 1;
+ ncclResult_t ret = progressOps(state, opsPtr, &idle, comm);
if (ret != ncclSuccess) {
comm->fatalError = ret;
INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret);
+ pthread_mutex_unlock(&state->opsMutex);
return NULL;
}
- idle &= op->idle;
- pthread_mutex_lock(&state->mutex);
- if (!idle) idleSpin = 0;
- struct ncclProxyArgs *next = op->next;
- if (next->state == ncclProxyOpNone) {
- struct ncclProxyArgs *freeOp = next;
- if (next->nextPeer) {
- // Replace next by its next per-peer element.
- next = next->nextPeer;
- if (op != freeOp) {
- next->next = freeOp->next;
- op->next = next;
- } else {
- next->next = next;
- }
- } else {
- // Remove next from circular list
- next->connector->proxyAppend = NULL;
- if (op != freeOp) {
- next = next->next;
- op->next = next;
- } else {
- next = NULL;
- }
- }
- if (freeOp == state->ops) state->ops = next;
- freeOp->next = state->pool;
- state->pool = freeOp;
+ if (idle) {
+ pthread_mutex_unlock(&state->opsMutex);
+ sched_yield(); // No request progressed. Let others run.
+ pthread_mutex_lock(&state->opsMutex);
}
+ }
+}
+
+ncclResult_t ncclProxyStart(struct ncclComm* comm) {
+ struct ncclProxyState* state = &comm->proxyState;
+ pthread_mutex_lock(&state->opsMutex);
+
+ // Sort operations as we append them : collectives and
+ // receives first, then sends.
+ ncclProxyArgs* next, *prev = NULL, *op = state->nextOps;
+ while (op) {
+ next = op->next;
+ if (op->sendbytes) {
+ if (prev) prev->next = next;
+ else state->nextOps = next;
+ op->next = NULL;
+ NCCLCHECK(ProxyAppend(state, op, op->connector->conn.shared));
+ } else prev = op;
op = next;
- if (op == state->ops) {
- if (idle == 1) {
- if (++idleSpin == 10) {
- sched_yield();
- idleSpin = 0;
- }
- }
- idle = 1;
+ }
+ op = state->nextOps;
+ while (op) {
+ next = op->next;
+ op->next = NULL;
+ NCCLCHECK(ProxyAppend(state, op, op->connector->conn.shared));
+ op = next;
+ }
+ state->nextOps = state->nextOpsEnd = NULL;
+ NCCLCHECK(dumpProxyState(state));
+
+ if (state->ops != NULL)
+ pthread_cond_signal(&state->cond);
+ pthread_mutex_unlock(&state->opsMutex);
+ return ncclSuccess;
+}
+
+NCCL_PARAM(ProxySharedBuffersCount, "SHARED_BUFF_COUNT", -2);
+
+ncclResult_t ncclProxySharedBuffersInit(struct ncclComm* comm, int cuda, int* size, char** ptr) {
+ struct ncclProxySharedBuffers* state = comm->proxyState.sharedBuffs;
+ if (state == NULL) {
+ NCCLCHECK(ncclCalloc(&state, 1));
+ comm->proxyState.sharedBuffs = state;
+ state->nslots = ncclParamProxySharedBuffersCount();
+ if (state->nslots == -2) {
+ state->nslots = NCCL_STEPS*NCCL_MAX_WORK_ELEMENTS;
+ }
+ state->slotSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(NCCL_STEPS*SENDRECV_SLICEFACTOR);
+ }
+
+ char* buff;
+ int* used;
+ *size = 2*comm->p2pnChannels*state->slotSize*state->nslots;
+
+ if (cuda && state->cudaBuff[0] == NULL) {
+ NCCLCHECK(ncclCudaCalloc(&buff, *size));
+ NCCLCHECK(ncclCalloc(&used, 2*comm->p2pnChannels*state->nslots));
+ for (int i=0; i<2*comm->p2pnChannels; i++) {
+ state->cudaBuff[i] = buff + state->nslots*state->slotSize*i;
+ state->cudaUsed[i] = used + state->nslots*i;
+ }
+ } else if (state->hostBuff[0] == NULL) {
+ NCCLCHECK(ncclCudaHostCalloc(&buff, *size));
+ NCCLCHECK(ncclCalloc(&used, 2*comm->p2pnChannels*state->nslots));
+ for (int i=0; i<2*comm->p2pnChannels; i++) {
+ state->hostBuff[i] = buff + state->nslots*state->slotSize*i;
+ state->hostUsed[i] = used + state->nslots*i;
+ }
+ }
+ buff = cuda ? state->cudaBuff[0] : state->hostBuff[0];
+
+ *ptr = buff;
+ return ncclSuccess;
+}
+
+ncclResult_t ncclProxySharedBuffersAlloc(struct ncclComm* comm, int cuda, int type, int channel, int size, char** ptr) {
+ struct ncclProxySharedBuffers* state = comm->proxyState.sharedBuffs;
+ // Use different pools for different channels and also separate send/recv.
+ int p = 2*channel+type;
+ int* used = cuda ? state->cudaUsed[p] : state->hostUsed[p];
+ char* buff = cuda ? state->cudaBuff[p] : state->hostBuff[p];
+ if (buff == NULL) return ncclInternalError;
+ int nslots = 1;
+ while (nslots*state->slotSize < size) nslots *= 2;
+ for (int s=0; s<state->nslots; s+=nslots) {
+ int u = 0;
+ for (int i=0; i<nslots; i++) u += used[s+i];
+ if (u == 0) {
+ for (int i=0; i<nslots; i++) used[s+i] = 1;
+ *ptr = buff+state->slotSize*s;
+ return ncclSuccess;
}
- pthread_mutex_unlock(&state->mutex);
}
+ *ptr = NULL;
+ return ncclSuccess;
}
-ncclResult_t ncclProxyStart(struct ncclComm* comm) {
- pthread_mutex_lock(&comm->proxyState.mutex);
- if (comm->proxyState.ops != NULL)
- pthread_cond_signal(&comm->proxyState.cond);
- pthread_mutex_unlock(&comm->proxyState.mutex);
+ncclResult_t ncclProxySharedBuffersFree(struct ncclComm* comm, int cuda, int type, int channel, int size, char* ptr) {
+ struct ncclProxySharedBuffers* state = comm->proxyState.sharedBuffs;
+ int p = 2*channel+type;
+ int* used = cuda ? state->cudaUsed[p] : state->hostUsed[p];
+ char* buff = cuda ? state->cudaBuff[p] : state->hostBuff[p];
+ if (buff == NULL) return ncclInternalError;
+ int nslots = 1;
+ while (nslots*state->slotSize < size) nslots *= 2;
+ int s = (ptr-buff)/state->slotSize;
+ if (s < 0 || s+nslots > state->nslots) {
+ WARN("Error freeing shared buffer : freeing ptr %p size %d (start %p slot size %d nslots %d)\n", ptr, size, buff, state->slotSize, state->nslots);
+ return ncclInternalError;
+ }
+ for (int i=0; i<nslots; i++) used[s+i] = 0;
+ return ncclSuccess;
+}
+
+ncclResult_t ncclProxySharedBuffersDestroy(struct ncclComm* comm) {
+ struct ncclProxySharedBuffers* state = comm->proxyState.sharedBuffs;
+ if (state) {
+ CUDACHECK(cudaFree(state->cudaBuff[0]));
+ free(state->cudaUsed[0]);
+ NCCLCHECK(ncclCudaHostFree(state->hostBuff[0]));
+ free(state->hostUsed[0]);
+ free(state);
+ }
return ncclSuccess;
}
ncclResult_t ncclProxyCreate(struct ncclComm* comm) {
if (!comm->proxyThread) {
comm->proxyState.cond = PTHREAD_COND_INITIALIZER;
- comm->proxyState.mutex = PTHREAD_MUTEX_INITIALIZER;
+ comm->proxyState.opsMutex = PTHREAD_MUTEX_INITIALIZER;
+ comm->proxyState.poolMutex = PTHREAD_MUTEX_INITIALIZER;
comm->proxyState.ops = NULL;
pthread_create(&comm->proxyThread, NULL, persistentThread, comm);
}
@@ -263,21 +514,23 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) {
struct ncclProxyState* state = &comm->proxyState;
// Request the proxy to stop and then wake it
- pthread_mutex_lock(&state->mutex);
+ pthread_mutex_lock(&state->opsMutex);
state->stop = true;
pthread_cond_signal(&state->cond);
- pthread_mutex_unlock(&state->mutex);
+ pthread_mutex_unlock(&state->opsMutex);
if (comm->proxyThread) pthread_join(comm->proxyThread, NULL);
// Free off any memory allocated for the proxy arg pools
- pthread_mutex_lock(&state->mutex);
+ pthread_mutex_lock(&state->poolMutex);
struct ncclProxyState* proxyState = &comm->proxyState;
while (proxyState->pools != NULL) {
struct ncclProxyPool *next = proxyState->pools->next;
free(proxyState->pools);
proxyState->pools = next;
}
- pthread_mutex_unlock(&state->mutex);
+ pthread_mutex_unlock(&state->poolMutex);
+
+ NCCLCHECK(ncclProxySharedBuffersDestroy(comm));
return ncclSuccess;
}
diff --git a/src/transport.cc b/src/transport.cc
index 7219ea3..a5af541 100644
--- a/src/transport.cc
+++ b/src/transport.cc
@@ -19,15 +19,15 @@ struct ncclTransport ncclTransports[NTRANSPORTS] = {
};
template <int type>
-static ncclResult_t selectTransport(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connect, struct ncclConnector* connector, int channelId) {
+static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connect, struct ncclConnector* connector, int channelId) {
for (int t=0; t<NTRANSPORTS; t++) {
struct ncclTransport *transport = ncclTransports+t;
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
int ret = 0;
- NCCLCHECK(transport->canConnect(&ret, topo, graph, myInfo, peerInfo));
+ NCCLCHECK(transport->canConnect(&ret, comm->topo, graph, myInfo, peerInfo));
if (ret) {
connector->transportComm = transportComm;
- NCCLCHECK(transportComm->setup(topo, graph, myInfo, peerInfo, connect, connector, channelId));
+ NCCLCHECK(transportComm->setup(comm, graph, myInfo, peerInfo, connect, connector, channelId));
return ncclSuccess;
}
}
@@ -35,53 +35,87 @@ static ncclResult_t selectTransport(struct ncclTopoSystem* topo, struct ncclTopo
return ncclInternalError;
}
-ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend) {
+ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend) {
TRACE(NCCL_INIT, "nsend %d nrecv %d", nsend, nrecv);
- uint32_t nSkippedSend = 0, nSkippedRecv = 0; /* for tracing */
- struct ncclConnect connect;
- struct ncclConnector* conn;
+ uint32_t mask = 1 << channel->id;
for (int i=0; i<nrecv; i++) {
int peer = peerRecv[i];
- if (peer == -1 || peer >= comm->nRanks) continue;
- conn = &channel->peers[peer].recv;
- if (conn->connected) { ++nSkippedRecv; continue; }
- memset(&connect, 0, sizeof(connect));
- NCCLCHECK(selectTransport<0>(comm->topo, graph, comm->peerInfo+comm->rank, comm->peerInfo+peer, &connect, conn, channel->id));
- NCCLCHECK(bootstrapSend(comm->bootstrap, peer, &connect, sizeof(struct ncclConnect)));
+ if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].recv.connected) continue;
+ comm->connectRecv[peer] |= mask;
}
for (int i=0; i<nsend; i++) {
int peer = peerSend[i];
- if (peer == -1 || peer >= comm->nRanks) continue;
- conn = &channel->peers[peer].send;
- if (conn->connected) { ++nSkippedSend; continue; }
- memset(&connect, 0, sizeof(connect));
- NCCLCHECK(selectTransport<1>(comm->topo, graph, comm->peerInfo+comm->rank, comm->peerInfo+peer, &connect, conn, channel->id));
- NCCLCHECK(bootstrapSend(comm->bootstrap, peer, &connect, sizeof(struct ncclConnect)));
+ if (peer == -1 || peer >= comm->nRanks || peer == comm->rank || channel->peers[peer].send.connected) continue;
+ comm->connectSend[peer] |= mask;
}
- for (int i=0; i<nsend; i++) {
- int peer = peerSend[i];
- if (peer == -1 || peer >= comm->nRanks) continue;
- conn = &channel->peers[peer].send;
- if (conn->connected) {++nSkippedSend; continue; }
- memset(&connect, 0, sizeof(connect));
- NCCLCHECK(bootstrapRecv(comm->bootstrap, peer, &connect, sizeof(struct ncclConnect)));
- NCCLCHECK(conn->transportComm->connect(&connect, 1, comm->rank, conn));
- conn->connected = 1;
- CUDACHECK(cudaMemcpy(&channel->devPeers[peer].send, conn, sizeof(struct ncclConnector), cudaMemcpyHostToDevice));
+ return ncclSuccess;
+}
+
+void dumpData(struct ncclConnect* data, int ndata) {
+ for (int n=0; n<ndata; n++) {
+ printf("[%d] ", n);
+ uint8_t* d = (uint8_t*)data;
+ for (int i=0; i<sizeof(struct ncclConnect); i++) printf("%02x", d[i]);
+ printf("\n");
}
- for (int i=0; i<nrecv; i++) {
- int peer = peerRecv[i];
- if (peer == -1 || peer >= comm->nRanks) continue;
- conn = &channel->peers[peer].recv;
- if (conn->connected) {++nSkippedRecv; continue; }
- memset(&connect, 0, sizeof(connect));
- NCCLCHECK(bootstrapRecv(comm->bootstrap, peer, &connect, sizeof(struct ncclConnect)));
- NCCLCHECK(conn->transportComm->connect(&connect, 1, comm->rank, conn));
- conn->connected = 1;
- CUDACHECK(cudaMemcpy(&channel->devPeers[peer].recv, conn, sizeof(struct ncclConnector), cudaMemcpyHostToDevice));
+}
+
+ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph) {
+ struct ncclConnect data[2*MAXCHANNELS];
+ for (int i=1; i<comm->nRanks; i++) {
+ int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks;
+ int sendPeer = (comm->rank + i) % comm->nRanks;
+ uint32_t recvMask = comm->connectRecv[recvPeer];
+ uint32_t sendMask = comm->connectSend[sendPeer];
+
+ struct ncclConnect* recvData = data;
+ int sendChannels = 0, recvChannels = 0;
+ for (int c=0; c<MAXCHANNELS; c++) {
+ if (recvMask & (1<<c)) {
+ struct ncclConnector* conn = &comm->channels[c].peers[recvPeer].recv;
+ NCCLCHECK(selectTransport<0>(comm, graph, comm->peerInfo+comm->rank, comm->peerInfo+recvPeer, recvData+recvChannels++, conn, c));
+ }
+ }
+ struct ncclConnect* sendData = recvData+recvChannels;
+ for (int c=0; c<MAXCHANNELS; c++) {
+ if (sendMask & (1<<c)) {
+ struct ncclConnector* conn = &comm->channels[c].peers[sendPeer].send;
+ NCCLCHECK(selectTransport<1>(comm, graph, comm->peerInfo+comm->rank, comm->peerInfo+sendPeer, sendData+sendChannels++, conn, c));
+ }
+ }
+
+ if (sendPeer == recvPeer) {
+ if (recvChannels+sendChannels) {
+ NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
+ NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)));
+ sendData = data;
+ recvData = data+sendChannels;
+ }
+ } else {
+ if (recvChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, recvPeer, recvData, sizeof(struct ncclConnect)*recvChannels));
+ if (sendChannels) NCCLCHECK(bootstrapSend(comm->bootstrap, sendPeer, sendData, sizeof(struct ncclConnect)*sendChannels));
+ if (sendChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, sendPeer, sendData, sizeof(struct ncclConnect)*sendChannels));
+ if (recvChannels) NCCLCHECK(bootstrapRecv(comm->bootstrap, recvPeer, recvData, sizeof(struct ncclConnect)*recvChannels));
+ }
+
+ for (int c=0; c<MAXCHANNELS; c++) {
+ if (sendMask & (1<<c)) {
+ struct ncclConnector* conn = &comm->channels[c].peers[sendPeer].send;
+ NCCLCHECK(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn));
+ conn->connected = 1;
+ CUDACHECK(cudaMemcpy(&comm->channels[c].devPeers[sendPeer].send, conn, sizeof(struct ncclConnector), cudaMemcpyHostToDevice));
+ }
+ }
+ for (int c=0; c<MAXCHANNELS; c++) {
+ if (recvMask & (1<<c)) {
+ struct ncclConnector* conn = &comm->channels[c].peers[recvPeer].recv;
+ NCCLCHECK(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn));
+ conn->connected = 1;
+ CUDACHECK(cudaMemcpy(&comm->channels[c].devPeers[recvPeer].recv, conn, sizeof(struct ncclConnector), cudaMemcpyHostToDevice));
+ }
+ }
+ comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0;
}
- TRACE(NCCL_INIT, "nsend %d nrecv %d nSkippedSend %u nSkippedRecv %u - DONE", nsend, nrecv, nSkippedSend, nSkippedRecv);
return ncclSuccess;
}
-
diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc
index 132f4fa..af865ce 100644
--- a/src/transport/coll_net.cc
+++ b/src/transport/coll_net.cc
@@ -26,10 +26,8 @@ struct reqSlot {
struct collNetSendResources {
void* collNetSendComm;
- struct ncclSendMem* hostSendMem;
- struct ncclRecvMem* hostRecvMem;
- struct ncclSendMem* devHostSendMem;
- struct ncclRecvMem* devHostRecvMem;
+ struct ncclSendMem* sendMem;
+ struct ncclRecvMem* recvMem;
uint32_t* llData;
int netDev;
int useGdr;
@@ -45,10 +43,8 @@ struct collNetSendResources {
struct collNetRecvResources {
void* netListenComm;
void* collNetRecvComm;
- struct ncclSendMem* hostSendMem;
- struct ncclRecvMem* hostRecvMem;
- struct ncclSendMem* devHostSendMem;
- struct ncclRecvMem* devHostRecvMem;
+ struct ncclSendMem* sendMem;
+ struct ncclRecvMem* recvMem;
uint32_t* llData;
int netDev;
int useGdr;
@@ -67,16 +63,15 @@ ncclResult_t collNetCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncc
}
/* Setup send connector, and return connect information for others in the coll communicator to connect to me */
-ncclResult_t collNetSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
+ncclResult_t collNetSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
struct collNetSendResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
send->transportResources = resources;
- NCCLCHECK(ncclTopoGetNetDev(topo, myInfo->rank, graph, channelId, &resources->netDev));
- NCCLCHECK(ncclTopoCheckGdr(topo, myInfo->busId, resources->netDev, 1, &resources->useGdr));
+ NCCLCHECK(ncclTopoGetNetDev(comm->topo, myInfo->rank, graph, channelId, &resources->netDev));
+ NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, resources->netDev, 1, &resources->useGdr));
- NCCLCHECK(ncclCudaHostCalloc(&resources->hostSendMem, 1));
- resources->devHostSendMem = resources->hostSendMem;
+ NCCLCHECK(ncclCudaHostCalloc(&resources->sendMem, 1));
int recvSize = offsetof(struct ncclRecvMem, buff);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) recvSize += send->comm->buffSizes[p];
@@ -84,8 +79,7 @@ ncclResult_t collNetSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph*
if (resources->useGdr) {
NCCLCHECK(ncclCudaCalloc((char**)(&resources->devRecvMem), recvSize));
}
- NCCLCHECK(ncclCudaHostCalloc((char**)&resources->hostRecvMem, recvSize));
- resources->devHostRecvMem = resources->hostRecvMem;
+ NCCLCHECK(ncclCudaHostCalloc((char**)&resources->recvMem, recvSize));
NCCLCHECK(ncclIbMalloc((void**)&(resources->llData), send->comm->buffSizes[NCCL_PROTO_LL]/2));
INFO(NCCL_INIT|NCCL_NET,"Coll %02d : %d [send] via COLLNET/%s/%d%s", channelId, myInfo->rank, collNetName(), resources->netDev,
@@ -94,16 +88,15 @@ ncclResult_t collNetSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph*
}
/* Setup recv connector */
-ncclResult_t collNetRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
+ncclResult_t collNetRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
struct collNetRecvResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources;
- NCCLCHECK(ncclTopoGetNetDev(topo, myInfo->rank, graph, channelId, &resources->netDev));
- NCCLCHECK(ncclTopoCheckGdr(topo, myInfo->busId, resources->netDev, 0, &resources->useGdr));
+ NCCLCHECK(ncclTopoGetNetDev(comm->topo, myInfo->rank, graph, channelId, &resources->netDev));
+ NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, resources->netDev, 0, &resources->useGdr));
- NCCLCHECK(ncclCudaHostCalloc(&resources->hostSendMem, 1));
- resources->devHostSendMem = resources->hostSendMem;
+ NCCLCHECK(ncclCudaHostCalloc(&resources->sendMem, 1));
int recvSize = offsetof(struct ncclRecvMem, buff);
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) recvSize += recv->comm->buffSizes[p];
@@ -111,8 +104,7 @@ ncclResult_t collNetRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph*
if (resources->useGdr) {
NCCLCHECK(ncclCudaCalloc((char**)(&resources->devRecvMem), recvSize));
}
- NCCLCHECK(ncclCudaHostCalloc((char**)&resources->hostRecvMem, recvSize));
- resources->devHostRecvMem = resources->hostRecvMem;
+ NCCLCHECK(ncclCudaHostCalloc((char**)&resources->recvMem, recvSize));
NCCLCHECK(ncclIbMalloc((void**)&(resources->llData), recv->comm->buffSizes[NCCL_PROTO_LL]/2));
@@ -123,25 +115,25 @@ ncclResult_t collNetRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph*
return ncclSuccess;
}
-ncclResult_t collNetSendConnect(struct ncclConnect* connectInfos, int nranks, int rank, struct ncclConnector* send) {
+ncclResult_t collNetSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfos, int nranks, int rank, struct ncclConnector* send) {
// Setup device pointers
struct collNetSendResources* resources = (struct collNetSendResources*)send->transportResources;
struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(connectInfos+rank);
// Intermediate buffering on GPU for GPU Direct RDMA, but LL buffer is always on host
- struct ncclRecvMem* recvMem = resources->useGdr ? resources->devRecvMem : resources->devHostRecvMem;
+ struct ncclRecvMem* recvMem = resources->useGdr ? resources->devRecvMem : resources->recvMem;
int offset = 0;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- send->conn.buffs[p] = (p == NCCL_PROTO_LL ? resources->devHostRecvMem->buff : recvMem->buff) + offset;
+ send->conn.buffs[p] = (p == NCCL_PROTO_LL ? resources->recvMem->buff : recvMem->buff) + offset;
offset += send->comm->buffSizes[p];
}
send->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
// Head/Tail/Opcount/Fifos are always on host
- send->conn.tail = &resources->devHostRecvMem->tail;
- send->conn.fifo = resources->devHostRecvMem->sizesFifo;
- send->conn.head = &resources->devHostSendMem->head;
- for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
+ send->conn.tail = &resources->recvMem->tail;
+ send->conn.sizesFifo = resources->recvMem->sizesFifo;
+ send->conn.head = &resources->sendMem->head;
+ for (int i=0; i<NCCL_STEPS; i++) send->conn.sizesFifo[i] = -1;
// Get info from recv side
resources->collNetRank = rank;
@@ -159,24 +151,24 @@ ncclResult_t collNetSendConnect(struct ncclConnect* connectInfos, int nranks, in
return ncclSuccess;
}
-ncclResult_t collNetRecvConnect(struct ncclConnect* connectInfos, int nranks, int rank, struct ncclConnector* recv) {
+ncclResult_t collNetRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfos, int nranks, int rank, struct ncclConnector* recv) {
// Setup device pointers
struct collNetRecvResources* resources = (struct collNetRecvResources*)recv->transportResources;
struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(connectInfos+rank);
resources->collNetRank = rank;
// Intermediate buffering on GPU for GPU Direct RDMA
- struct ncclRecvMem* recvMem = resources->useGdr ? resources->devRecvMem : resources->devHostRecvMem;
+ struct ncclRecvMem* recvMem = resources->useGdr ? resources->devRecvMem : resources->recvMem;
int offset = 0;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- recv->conn.buffs[p] = (p == NCCL_PROTO_LL ? resources->devHostRecvMem->buff : recvMem->buff) + offset;
+ recv->conn.buffs[p] = (p == NCCL_PROTO_LL ? resources->recvMem->buff : recvMem->buff) + offset;
offset += recv->comm->buffSizes[p];
}
recv->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
// Head/Tail/Opcount are always on host
- recv->conn.tail = &resources->devHostRecvMem->tail;
- recv->conn.head = &resources->devHostSendMem->head;
+ recv->conn.tail = &resources->recvMem->tail;
+ recv->conn.head = &resources->sendMem->head;
// Connect to coll comm
collNetHandle_t** handlePtrs = NULL;
@@ -213,8 +205,8 @@ cleanup:
ncclResult_t collNetSendFree(void* sendTransportResources) {
struct collNetSendResources* resources = (struct collNetSendResources*)sendTransportResources;
- NCCLCHECK(ncclCudaHostFree(resources->hostSendMem));
- NCCLCHECK(ncclCudaHostFree(resources->hostRecvMem));
+ NCCLCHECK(ncclCudaHostFree(resources->sendMem));
+ NCCLCHECK(ncclCudaHostFree(resources->recvMem));
if (resources->collNetSendComm) {
NCCLCHECK(collNetDeregMr(resources->collNetSendComm, resources->sendMhandles[NCCL_PROTO_LL]));
NCCLCHECK(collNetDeregMr(resources->collNetSendComm, resources->sendMhandles[NCCL_PROTO_SIMPLE]));
@@ -228,12 +220,12 @@ ncclResult_t collNetSendFree(void* sendTransportResources) {
ncclResult_t collNetRecvFree(void* recvTransportResources) {
struct collNetRecvResources* resources = (struct collNetRecvResources*)recvTransportResources;
- NCCLCHECK(ncclCudaHostFree(resources->hostSendMem));
+ NCCLCHECK(ncclCudaHostFree(resources->sendMem));
if (resources->collNetRecvComm) {
NCCLCHECK(collNetDeregMr(resources->collNetRecvComm, resources->mhandles[NCCL_PROTO_LL]));
NCCLCHECK(collNetDeregMr(resources->collNetRecvComm, resources->mhandles[NCCL_PROTO_SIMPLE]));
}
- NCCLCHECK(ncclCudaHostFree(resources->hostRecvMem));
+ NCCLCHECK(ncclCudaHostFree(resources->recvMem));
if (resources->useGdr)
CUDACHECK(cudaFree(resources->devRecvMem));
free(resources->llData);
@@ -256,97 +248,85 @@ ncclResult_t collNetSendProxy(struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) {
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
- args->head = resources->step;
- args->tail = resources->step;
- args->end = args->head + args->nsteps;
+ args->posted = args->transmitted = args->done = resources->step;
+ args->end = resources->step + args->nsteps;
args->state = ncclProxyOpProgress;
}
+ args->idle = 1;
if (args->state == ncclProxyOpProgress) {
int p = args->protocol;
int stepSize = args->connector->comm->buffSizes[p] / NCCL_STEPS;
char* localBuff = args->connector->conn.buffs[p];
void* sendMhandle = resources->sendMhandles[p];
void* recvMhandle = resources->recvMhandles[p];
- args->idle = 1;
struct reqSlot* reqFifo = resources->reqFifo;
- if (args->head < args->end) {
- int buffSlot = args->tail%NCCL_STEPS;
- if (args->tail < args->end && args->tail < args->head + NCCL_STEPS
- && reqFifo[buffSlot].recvBuff != NULL) {
- volatile int* sizesFifo = resources->hostRecvMem->sizesFifo;
- volatile uint64_t* recvTail = &resources->hostRecvMem->tail;
+ int buffSlot = args->transmitted%NCCL_STEPS;
+ if (args->transmitted < args->end && args->transmitted < args->done + NCCL_STEPS
+ && reqFifo[buffSlot].recvBuff != NULL) {
+ volatile int* sizesFifo = resources->recvMem->sizesFifo;
+ volatile uint64_t* recvTail = &resources->recvMem->tail;
+ if (sizesFifo[buffSlot] != -1 && (*recvTail > args->transmitted || args->protocol == NCCL_PROTO_LL)) {
+ // We have something to receive, let's check if it's completely ready.
+ int size = sizesFifo[buffSlot];
+ char* buff = localBuff+buffSlot*stepSize;
+ int ready = 1;
if (args->protocol == NCCL_PROTO_LL) {
- int size = sizesFifo[buffSlot];
- if (size != -1) {
- uint32_t flag = NCCL_LL_FLAG(args->tail + 1);
- int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine));
- union ncclLLFifoLine* lines = (union ncclLLFifoLine*)(localBuff+buffSlot*stepSize);
- int ready = 1;
- for (int i=0; i<nFifoLines; i++) {
- volatile uint32_t *f1 = &lines[i].flag1;
- volatile uint32_t *f2 = &lines[i].flag2;
- if (f1[0] != flag || f2[0] != flag) { ready = 0; break; }
- }
- if (ready) {
- int stepLines = stepSize / sizeof(union ncclLLFifoLine);
- //separate data from flag
- uint32_t* sendBuff = resources->llData+buffSlot*2*stepLines; // each line has two data elements
- for (int i=0; i<nFifoLines; i++) {
- volatile uint32_t *d1 = &lines[i].data1;
- volatile uint32_t *d2 = &lines[i].data2;
- sendBuff[2*i] = d1[0];
- sendBuff[2*i+1] = d2[0];
- }
- int count = nFifoLines*2*sizeof(uint32_t) / ncclTypeSize(args->dtype);
- NCCLCHECK(collNetIallreduce(resources->collNetSendComm, (void*)sendBuff, (void*)(reqFifo[buffSlot].recvBuff), count, args->dtype, args->redOp, sendMhandle, recvMhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- TRACE(NCCL_NET, "sendProxy [%d/%d] Iallreduce (LL) posted, req %p", args->head, buffSlot, args->requests[buffSlot]);
- sizesFifo[buffSlot] = -1;
- // Make sure size is reset to zero before we update the head.
- __sync_synchronize();
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
- }
+ uint32_t flag = NCCL_LL_FLAG(args->transmitted + 1);
+ int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine));
+ union ncclLLFifoLine* lines = (union ncclLLFifoLine*)buff;
+ // Pack data into another buffer
+ int stepLines = stepSize / sizeof(union ncclLLFifoLine);
+ uint32_t* sendBuff = resources->llData+buffSlot*2*stepLines; // each line has two data elements
+ buff = (char*)sendBuff;
+ for (int i=0; i<nFifoLines; i++) {
+ volatile uint32_t *f1 = &lines[i].flag1;
+ volatile uint32_t *d1 = &lines[i].data1;
+ volatile uint32_t *f2 = &lines[i].flag2;
+ volatile uint32_t *d2 = &lines[i].data2;
+ if (f1[0] != flag || f2[0] != flag) { ready = 0; break; }
+ sendBuff[2*i] = d1[0];
+ sendBuff[2*i+1] = d2[0];
}
- } else if (args->tail < *recvTail) {
- // Send through network
- if (sizesFifo[buffSlot] != -1) {
- int count = sizesFifo[buffSlot]/ncclTypeSize(args->dtype);
- NCCLCHECK(collNetIallreduce(resources->collNetSendComm, localBuff+buffSlot*stepSize, (void*)(reqFifo[buffSlot].recvBuff), count, args->dtype, args->redOp, sendMhandle, recvMhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- TRACE(NCCL_NET, "sendProxy [%d/%d] Iallreduce posted, req %p count %d", args->head, buffSlot, args->requests[buffSlot], count);
- sizesFifo[buffSlot] = -1;
- // Make sure size is reset to zero before we update the head.
- __sync_synchronize();
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
+ size = nFifoLines*2*sizeof(uint32_t);
+ }
+ if (ready) {
+ // Data is ready, try to send.
+ int count = size/ncclTypeSize(args->dtype);
+ NCCLCHECK(collNetIallreduce(resources->collNetSendComm, (void*) buff, (void*)(reqFifo[buffSlot].recvBuff), count, args->dtype, args->redOp, sendMhandle, recvMhandle, args->requests+buffSlot));
+ if (args->requests[buffSlot] != NULL) {
+ TRACE(NCCL_NET, "sendProxy [%d/%d] Iallreduce posted, req %p", args->transmitted, buffSlot, args->requests[buffSlot]);
+ sizesFifo[buffSlot] = -1;
+ // Make sure size is reset to zero before we update the head.
+ __sync_synchronize();
+ args->transmitted += args->sliceSteps;
+ args->idle = 0;
+ return ncclSuccess;
}
}
}
- if (args->head < args->tail) {
- int done, size;
- int buffSlot = args->head%NCCL_STEPS;
- NCCLCHECK(collNetTest((void*)(args->requests[buffSlot]), &done, &size));
- if (done) {
- TRACE(NCCL_NET, "sendProxy [%d/%d] request %p done, size %d", args->head, buffSlot, args->requests[buffSlot], size);
- reqFifo[buffSlot].size = size;
- // Make sure size is updated before we set recvBuff to NULL (from the view of recv proxy, concerning the flush)
- // (reordered store after store is possible on POWER, though not on x86)
- __sync_synchronize();
- reqFifo[buffSlot].recvBuff = NULL; // Notify recvProxy
- args->head += args->sliceSteps;
- resources->hostSendMem->head = args->head;
- args->idle = 0;
+ }
+ // Check whether the network has completed some send operations.
+ if (args->done < args->transmitted) {
+ int done, size;
+ int buffSlot = args->done%NCCL_STEPS;
+ NCCLCHECK(collNetTest((void*)(args->requests[buffSlot]), &done, &size));
+ if (done) {
+ TRACE(NCCL_NET, "sendProxy [%d/%d] request %p done, size %d", args->done, buffSlot, args->requests[buffSlot], size);
+ reqFifo[buffSlot].size = size;
+ // Make sure size is updated before we set recvBuff to NULL (from the view of recv proxy, concerning the flush)
+ // (reordered store after store is possible on POWER, though not on x86)
+ __sync_synchronize();
+ reqFifo[buffSlot].recvBuff = NULL; // Notify recvProxy
+ args->done += args->sliceSteps;
+ resources->sendMem->head = args->done;
+ args->idle = 0;
+ if (args->done == args->end) {
+ resources->step = args->end;
+ args->state = ncclProxyOpNone;
}
+ return ncclSuccess;
}
}
- if (args->head == args->end) {
- resources->step = args->end;
- args->idle = 0;
- args->state = ncclProxyOpNone;
- }
}
return ncclSuccess;
}
@@ -360,56 +340,79 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) {
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
- args->head = resources->step;
- args->tail = resources->step;
- args->end = args->head + args->nsteps;
+ args->posted = args->received = args->transmitted = args->done = resources->step;
+ args->end = resources->step + args->nsteps;
args->state = ncclProxyOpProgress;
}
+ args->idle = 1;
if (args->state == ncclProxyOpProgress) {
- args->idle = 1;
int p = args->protocol;
int stepSize = args->connector->comm->buffSizes[p] / NCCL_STEPS;
char* localBuff = args->connector->conn.buffs[p];
void* mhandle = resources->mhandles[p];
struct reqSlot* reqFifo = resources->reqFifo;
- if (args->head < args->end) {
- if ((args->tail < args->head + NCCL_STEPS) && (args->tail < (resources->hostSendMem->head) + NCCL_STEPS) && (args->tail < args->end)) {
- int buffSlot = args->tail%NCCL_STEPS;
- char* recvBuff = p == NCCL_PROTO_LL ? (char*)resources->llData : localBuff;
- int recvStepSize = p == NCCL_PROTO_LL ? stepSize/2 : stepSize;
- reqFifo[buffSlot].recvBuff = recvBuff+buffSlot*recvStepSize;
- TRACE(NCCL_NET, "recvProxy [%d/%d] posted buffer %p", args->tail, buffSlot, reqFifo[buffSlot].recvBuff);
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
- if (args->tail > args->head) {
- int buffSlot = args->head%NCCL_STEPS;
- if (reqFifo[buffSlot].recvBuff == NULL) { // Buffer is cleared : coll is complete
- TRACE(NCCL_NET, "recvProxy [%d/%d] done, size %d", args->head, buffSlot, reqFifo[buffSlot].size);
- args->head += args->sliceSteps;
- if (args->protocol == NCCL_PROTO_LL) { // ll
- // re-attach flag
- uint32_t flag = args->head;
- int stepLines = stepSize / sizeof(union ncclLLFifoLine);
- union ncclLLFifoLine* lines = (union ncclLLFifoLine*)(localBuff+buffSlot*stepSize);
- uint32_t* recvData = resources->llData+buffSlot*2*stepLines;
- int nFifoLines = DIVUP(reqFifo[buffSlot].size, 2*sizeof(uint32_t));
- for (int i=0; i<nFifoLines; i++) {
- lines[i].v[0] = ((uint64_t)flag << 32) + recvData[2*i];
- lines[i].v[1] = ((uint64_t)flag << 32) + recvData[2*i+1];
- }
- } else if (args->protocol == NCCL_PROTO_SIMPLE) {
- if (resources->useGdr) NCCLCHECK(collNetFlush(resources->collNetRecvComm, localBuff+buffSlot*stepSize, reqFifo[buffSlot].size, mhandle));
- resources->hostRecvMem->tail = args->head;
+ if ((args->posted < args->done + NCCL_STEPS) && (args->posted < args->end)) {
+ int buffSlot = args->posted%NCCL_STEPS;
+ char* recvBuff = p == NCCL_PROTO_LL ? (char*)resources->llData : localBuff;
+ int recvStepSize = p == NCCL_PROTO_LL ? stepSize/2 : stepSize;
+ reqFifo[buffSlot].recvBuff = recvBuff+buffSlot*recvStepSize;
+ TRACE(NCCL_NET, "recvProxy [%d/%d] posted buffer %p", args->posted, buffSlot, reqFifo[buffSlot].recvBuff);
+ args->posted += args->sliceSteps;
+ args->idle = 0;
+ return ncclSuccess;
+ }
+ if (args->posted > args->received) {
+ int buffSlot = args->received%NCCL_STEPS;
+ if (reqFifo[buffSlot].recvBuff == NULL) { // Buffer is cleared : coll is complete
+ TRACE(NCCL_NET, "recvProxy [%d/%d] done, size %d", args->received, buffSlot, reqFifo[buffSlot].size);
+ if (args->protocol == NCCL_PROTO_LL) { // ll
+ // re-attach flag
+ uint32_t flag = NCCL_LL_FLAG(args->received + 1);
+ int stepLines = stepSize / sizeof(union ncclLLFifoLine);
+ union ncclLLFifoLine* lines = (union ncclLLFifoLine*)(localBuff+buffSlot*stepSize);
+ uint32_t* recvData = resources->llData+buffSlot*2*stepLines;
+ int nFifoLines = DIVUP(reqFifo[buffSlot].size, 2*sizeof(uint32_t));
+ for (int i=0; i<nFifoLines; i++) {
+ lines[i].v[0] = ((uint64_t)flag << 32) + recvData[2*i];
+ lines[i].v[1] = ((uint64_t)flag << 32) + recvData[2*i+1];
}
- args->idle = 0;
}
+ args->received += args->sliceSteps;
+ if (reqFifo[buffSlot].size > 0 && args->protocol == NCCL_PROTO_SIMPLE && resources->useGdr) {
+ NCCLCHECK(collNetIflush(resources->collNetRecvComm, localBuff+buffSlot*stepSize, reqFifo[buffSlot].size, mhandle, args->requests+buffSlot));
+ } else {
+ args->requests[buffSlot] = NULL;
+ }
+ args->idle = 0;
+ return ncclSuccess;
}
}
- if (args->head == args->end) {
- resources->step = args->end;
- args->idle = 0;
- args->state = ncclProxyOpNone;
+ if (args->received > args->transmitted) {
+ // Progress flush operations
+ int buffSlot = args->transmitted%NCCL_STEPS;
+ int done = 1;
+ if (args->requests[buffSlot]) NCCLCHECK(collNetTest(args->requests[buffSlot], &done, NULL));
+ if (done) {
+ args->transmitted += args->sliceSteps;
+ __sync_synchronize();
+ resources->recvMem->tail = args->transmitted;
+ args->idle = 0;
+ return ncclSuccess;
+ }
+ }
+ if (args->transmitted > args->done) {
+ volatile uint64_t* sendHead = &resources->sendMem->head;
+ uint64_t done = *sendHead;
+ while (done > args->done &&
+ // LL and LL128 can acknowledge 0-bytes send before they even happen. Don't go past what we transmitted.
+ args->transmitted > args->done) {
+ args->done += args->sliceSteps;
+ args->idle = 0;
+ if (args->done == args->end) {
+ resources->step = args->end;
+ args->state = ncclProxyOpNone;
+ }
+ }
}
}
return ncclSuccess;
diff --git a/src/transport/net.cc b/src/transport/net.cc
index 49cd8d2..86c43f8 100644
--- a/src/transport/net.cc
+++ b/src/transport/net.cc
@@ -7,6 +7,7 @@
#include "comm.h"
#include "net.h"
#include "graph.h"
+#include "collectives.h"
struct netConnectInfo {
ncclNetHandle_t netHandle;
@@ -22,6 +23,7 @@ struct netSendResources {
struct ncclRecvMem* recvMem;
int netDev;
int useGdr;
+ int shared;
char* buffers[LOC_COUNT];
int buffSizes[LOC_COUNT];
void* mhandles[LOC_COUNT];
@@ -37,6 +39,7 @@ struct netRecvResources {
struct ncclRecvMem* recvMem;
int netDev;
int useGdr;
+ int shared;
char* buffers[LOC_COUNT];
int buffSizes[LOC_COUNT];
void* mhandles[LOC_COUNT];
@@ -51,108 +54,118 @@ ncclResult_t netCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
return ncclSuccess;
}
+NCCL_PARAM(NetSharedBuffers, "NET_SHARED_BUFFERS", -2);
+
/* Determine if we will use this transport for this peer and return connect
* information for this peer */
-ncclResult_t netSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
+ncclResult_t netSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
struct netSendResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
send->transportResources = resources;
+ send->conn.shared = resources->shared = ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : graph ? 0 : 1;
- NCCLCHECK(ncclTopoGetNetDev(topo, myInfo->rank, graph, channelId, &resources->netDev));
- NCCLCHECK(ncclTopoCheckGdr(topo, myInfo->busId, resources->netDev, 1, &resources->useGdr));
+ NCCLCHECK(ncclTopoGetNetDev(comm->topo, myInfo->rank, graph, channelId, &resources->netDev));
+ NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, resources->netDev, 1, &resources->useGdr));
NCCLCHECK(ncclCudaHostCalloc(&resources->sendMem, 1));
NCCLCHECK(ncclCudaHostCalloc(&resources->recvMem, 1));
send->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
send->conn.tail = &resources->recvMem->tail;
- send->conn.fifo = resources->recvMem->sizesFifo;
+ send->conn.sizesFifo = resources->recvMem->sizesFifo;
+ // Only fuse P2P buffers, continue to allocate dedicated buffers for ring/tree
+ send->conn.ptrsFifo = resources->shared ? resources->recvMem->ptrsFifo : NULL;
send->conn.head = &resources->sendMem->head;
- for (int i=0; i<NCCL_STEPS; i++) send->conn.fifo[i] = -1;
-
- int protoLoc[NCCL_NUM_PROTOCOLS];
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- protoLoc[p] = p != NCCL_PROTO_LL && resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
- }
+ resources->sendMem->head = resources->shared ? -NCCL_STEPS : 0; // Don't give any credit yet when sharing buffers
+ for (int i=0; i<NCCL_STEPS; i++) send->conn.sizesFifo[i] = -1;
- int buffSizes[NCCL_NUM_PROTOCOLS];
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- // Only allocate buffers for simple for p2p connections
- buffSizes[p] = graph == NULL && p != NCCL_PROTO_SIMPLE ? 0 : send->comm->buffSizes[p];
- resources->buffSizes[protoLoc[p]] += buffSizes[p];
- }
+ if (resources->shared == 0) {
+ int protoLoc[NCCL_NUM_PROTOCOLS];
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ protoLoc[p] = p != NCCL_PROTO_LL && resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
+ }
+ int buffSizes[NCCL_NUM_PROTOCOLS];
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ buffSizes[p] = send->comm->buffSizes[p];
+ resources->buffSizes[protoLoc[p]] += buffSizes[p];
+ }
- if (resources->buffSizes[LOC_DEVMEM]) {
- NCCLCHECK(ncclCudaCalloc(resources->buffers+LOC_DEVMEM, resources->buffSizes[LOC_DEVMEM]));
- }
- if (resources->buffSizes[LOC_HOSTMEM]) {
- NCCLCHECK(ncclCudaHostCalloc(resources->buffers+LOC_HOSTMEM, resources->buffSizes[LOC_HOSTMEM]));
- }
+ if (resources->buffSizes[LOC_DEVMEM]) {
+ NCCLCHECK(ncclCudaCalloc(resources->buffers+LOC_DEVMEM, resources->buffSizes[LOC_DEVMEM]));
+ }
+ if (resources->buffSizes[LOC_HOSTMEM]) {
+ NCCLCHECK(ncclCudaHostCalloc(resources->buffers+LOC_HOSTMEM, resources->buffSizes[LOC_HOSTMEM]));
+ }
- int offsets[LOC_COUNT];
- offsets[LOC_HOSTMEM] = offsets[LOC_DEVMEM] = 0;
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- resources->mhandlesProto[p] = resources->mhandles+protoLoc[p];
- send->conn.buffs[p] = resources->buffers[protoLoc[p]] + offsets[protoLoc[p]];
- offsets[protoLoc[p]] += buffSizes[p];
+ int offsets[LOC_COUNT];
+ offsets[LOC_HOSTMEM] = offsets[LOC_DEVMEM] = 0;
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ resources->mhandlesProto[p] = resources->mhandles+protoLoc[p];
+ send->conn.buffs[p] = resources->buffers[protoLoc[p]] + offsets[protoLoc[p]];
+ offsets[protoLoc[p]] += buffSizes[p];
+ }
}
- INFO(NCCL_INIT|NCCL_NET,"Channel %02d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(), resources->netDev,
- resources->useGdr ? "/GDRDMA" : "");
+ INFO(NCCL_INIT|NCCL_NET,"Channel %02d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(), resources->netDev,
+ resources->useGdr ? "/GDRDMA" : "", resources->shared ? "/Shared" : "");
return ncclSuccess;
}
-ncclResult_t netRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
+ncclResult_t netRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
struct netRecvResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources;
+ recv->conn.shared = resources->shared = ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : graph ? 0 : 1;
- NCCLCHECK(ncclTopoGetNetDev(topo, myInfo->rank, graph, channelId, &resources->netDev));
- NCCLCHECK(ncclTopoCheckGdr(topo, myInfo->busId, resources->netDev, 0, &resources->useGdr));
+ NCCLCHECK(ncclTopoGetNetDev(comm->topo, myInfo->rank, graph, channelId, &resources->netDev));
+ NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, resources->netDev, 0, &resources->useGdr));
NCCLCHECK(ncclCudaHostCalloc(&resources->sendMem, 1));
NCCLCHECK(ncclCudaHostCalloc(&resources->recvMem, 1));
recv->conn.direct |= resources->useGdr ? NCCL_DIRECT_NIC : 0;
recv->conn.tail = &resources->recvMem->tail;
+ // Only fuse P2P buffers, continue to allocate dedicated buffers for ring/tree
+ recv->conn.ptrsFifo = resources->shared ? resources->recvMem->ptrsFifo : NULL;
recv->conn.head = &resources->sendMem->head;
- int protoLoc[NCCL_NUM_PROTOCOLS];
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- protoLoc[p] = resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
- }
+ if (resources->shared == 0) { // Only allocate dedicated buffers for ring/tree not for p2p
+ int protoLoc[NCCL_NUM_PROTOCOLS];
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ protoLoc[p] = resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
+ }
- int buffSizes[NCCL_NUM_PROTOCOLS];
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- // Only allocate buffers for simple for p2p connections
- buffSizes[p] = graph == NULL && p != NCCL_PROTO_SIMPLE ? 0 : recv->comm->buffSizes[p];
- resources->buffSizes[protoLoc[p]] += buffSizes[p];
- }
+ int buffSizes[NCCL_NUM_PROTOCOLS];
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ buffSizes[p] = recv->comm->buffSizes[p];
+ resources->buffSizes[protoLoc[p]] += buffSizes[p];
+ }
- if (resources->buffSizes[LOC_DEVMEM]) {
- NCCLCHECK(ncclCudaCalloc(resources->buffers+LOC_DEVMEM, resources->buffSizes[LOC_DEVMEM]));
- }
- if (resources->buffSizes[LOC_HOSTMEM]) {
- NCCLCHECK(ncclCudaHostCalloc(resources->buffers+LOC_HOSTMEM, resources->buffSizes[LOC_HOSTMEM]));
- }
+ if (resources->buffSizes[LOC_DEVMEM]) {
+ NCCLCHECK(ncclCudaCalloc(resources->buffers+LOC_DEVMEM, resources->buffSizes[LOC_DEVMEM]));
+ }
+ if (resources->buffSizes[LOC_HOSTMEM]) {
+ NCCLCHECK(ncclCudaHostCalloc(resources->buffers+LOC_HOSTMEM, resources->buffSizes[LOC_HOSTMEM]));
+ }
- int offsets[LOC_COUNT];
- offsets[LOC_HOSTMEM] = offsets[LOC_DEVMEM] = 0;
- for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
- resources->mhandlesProto[p] = resources->mhandles+protoLoc[p];
- recv->conn.buffs[p] = resources->buffers[protoLoc[p]] + offsets[protoLoc[p]];
- offsets[protoLoc[p]] += buffSizes[p];
+ int offsets[LOC_COUNT];
+ offsets[LOC_HOSTMEM] = offsets[LOC_DEVMEM] = 0;
+ for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
+ resources->mhandlesProto[p] = resources->mhandles+protoLoc[p];
+ recv->conn.buffs[p] = resources->buffers[protoLoc[p]] + offsets[protoLoc[p]];
+ offsets[protoLoc[p]] += buffSizes[p];
+ }
}
- INFO(NCCL_INIT|NCCL_NET,"Channel %02d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s", channelId, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, ncclNetName(), resources->netDev,
- resources->useGdr ? "/GDRDMA" : "");
+ INFO(NCCL_INIT|NCCL_NET,"Channel %02d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s", channelId, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, ncclNetName(), resources->netDev,
+ resources->useGdr ? "/GDRDMA" : "", resources->shared ? "/Shared" : "");
struct netConnectInfo* info = (struct netConnectInfo*) connectInfo;
NCCLCHECK(ncclNetListen(resources->netDev, &info->netHandle, &resources->netListenComm));
return ncclSuccess;
}
-ncclResult_t netSendConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
+ncclResult_t netSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
// Setup device pointers
struct netSendResources* resources = (struct netSendResources*)send->transportResources;
struct netConnectInfo* info = (struct netConnectInfo*)connectInfo;
@@ -160,6 +173,13 @@ ncclResult_t netSendConnect(struct ncclConnect* connectInfo, int nranks, int ran
// Connect to remote peer
NCCLCHECK(ncclNetConnect(resources->netDev, info->netHandle, &resources->netSendComm));
+ if (resources->shared) {
+ // Get shared buffers
+ int loc = resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
+ NCCLCHECK(ncclProxySharedBuffersInit(send->comm, resources->useGdr, resources->buffSizes+loc, resources->buffers+loc));
+ resources->mhandlesProto[NCCL_PROTO_SIMPLE] = resources->mhandles+loc;
+ }
+
if (resources->buffSizes[LOC_DEVMEM]) {
NCCLCHECK(ncclNetRegMr(resources->netSendComm, resources->buffers[LOC_DEVMEM], resources->buffSizes[LOC_DEVMEM], NCCL_PTR_CUDA, &resources->mhandles[LOC_DEVMEM]));
}
@@ -170,7 +190,7 @@ ncclResult_t netSendConnect(struct ncclConnect* connectInfo, int nranks, int ran
}
/* Connect to this peer */
-ncclResult_t netRecvConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
+ncclResult_t netRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
// Setup device pointers
struct netRecvResources* resources = (struct netRecvResources*)recv->transportResources;
@@ -178,6 +198,13 @@ ncclResult_t netRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
NCCLCHECK(ncclNetAccept(resources->netListenComm, &resources->netRecvComm));
NCCLCHECK(ncclNetCloseListen(resources->netListenComm));
+ if (resources->shared) {
+ // Get shared buffers
+ int loc = resources->useGdr ? LOC_DEVMEM : LOC_HOSTMEM;
+ NCCLCHECK(ncclProxySharedBuffersInit(recv->comm, resources->useGdr, resources->buffSizes+loc, resources->buffers+loc));
+ resources->mhandlesProto[NCCL_PROTO_SIMPLE] = resources->mhandles+loc;
+ }
+
if (resources->buffSizes[LOC_DEVMEM]) {
NCCLCHECK(ncclNetRegMr(resources->netRecvComm, resources->buffers[LOC_DEVMEM], resources->buffSizes[LOC_DEVMEM], NCCL_PTR_CUDA, &resources->mhandles[LOC_DEVMEM]));
}
@@ -195,8 +222,10 @@ ncclResult_t netSendFree(void* transportResources) {
if (resources->buffers[l])
NCCLCHECK(ncclNetDeregMr(resources->netSendComm, resources->mhandles[l]));
}
- NCCLCHECK(ncclCudaHostFree(resources->buffers[LOC_HOSTMEM]));
- CUDACHECK(cudaFree(resources->buffers[LOC_DEVMEM]));
+ if (resources->shared == 0) {
+ NCCLCHECK(ncclCudaHostFree(resources->buffers[LOC_HOSTMEM]));
+ CUDACHECK(cudaFree(resources->buffers[LOC_DEVMEM]));
+ }
NCCLCHECK(ncclNetCloseSend(resources->netSendComm));
free(resources);
return ncclSuccess;
@@ -210,116 +239,122 @@ ncclResult_t netRecvFree(void* transportResources) {
if (resources->buffers[l])
NCCLCHECK(ncclNetDeregMr(resources->netRecvComm, resources->mhandles[l]));
}
- NCCLCHECK(ncclCudaHostFree(resources->buffers[LOC_HOSTMEM]));
- CUDACHECK(cudaFree(resources->buffers[LOC_DEVMEM]));
+ if (resources->shared == 0) {
+ NCCLCHECK(ncclCudaHostFree(resources->buffers[LOC_HOSTMEM]));
+ CUDACHECK(cudaFree(resources->buffers[LOC_DEVMEM]));
+ }
NCCLCHECK(ncclNetCloseRecv(resources->netRecvComm));
free(resources);
return ncclSuccess;
}
+static_assert(NCCL_STEPS <= NCCL_NET_MAX_REQUESTS, "Not enough net requests to cover for steps");
+
ncclResult_t netSendProxy(struct ncclProxyArgs* args) {
struct netSendResources* resources = (struct netSendResources*) (args->connector->transportResources);
if (args->state == ncclProxyOpReady) {
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
- args->head = resources->step;
- args->tail = resources->step;
- args->end = args->head + args->nsteps;
+ args->posted = args->transmitted = args->done = resources->step;
+ args->end = resources->step + args->nsteps;
args->state = ncclProxyOpProgress;
}
+ args->idle = 1;
if (args->state == ncclProxyOpProgress) {
int p = args->protocol;
int stepSize = args->connector->comm->buffSizes[p] / NCCL_STEPS;
char* localBuff = args->connector->conn.buffs[p];
void* mhandle = *(resources->mhandlesProto[p]);
- args->idle = 1;
- if (args->head < args->end) {
- int buffSlot = args->tail%NCCL_STEPS;
- if (args->tail < args->end && args->tail < args->head + NCCL_STEPS) {
- volatile int* sizesFifo = resources->recvMem->sizesFifo;
- volatile uint64_t* recvTail = &resources->recvMem->tail;
+ int buffSize = stepSize*args->sliceSteps;
+ if (resources->shared) buffSize /= SENDRECV_SLICEFACTOR;
+ if (args->sendbytes < buffSize) buffSize = args->sendbytes;
+ // Post buffers to the GPU
+ if (args->posted < args->end && args->posted < args->done + NCCL_STEPS) {
+ if (resources->shared) {
+ char* ptr;
+ NCCLCHECK(ncclProxySharedBuffersAlloc(args->connector->comm, resources->useGdr, 0, args->channel->id, buffSize, &ptr));
+ if (ptr == NULL) return ncclInternalError;
+ resources->recvMem->ptrsFifo[args->posted%NCCL_STEPS] = ptr;
+ __sync_synchronize();
+ volatile uint64_t* sendHead = &resources->sendMem->head;
+ args->posted += args->sliceSteps;
+ *sendHead = args->posted - NCCL_STEPS;
+ } else args->posted += args->sliceSteps;
+ args->idle = 0;
+ return ncclSuccess;
+ }
+ // Check whether we received data from the GPU and send it to the network
+ int buffSlot = args->transmitted%NCCL_STEPS;
+ if (args->transmitted < args->posted && args->transmitted < args->done + NCCL_STEPS) {
+ volatile int* sizesFifo = resources->recvMem->sizesFifo;
+ volatile uint64_t* recvTail = &resources->recvMem->tail;
+ if (sizesFifo[buffSlot] != -1 && (*recvTail > args->transmitted || args->protocol == NCCL_PROTO_LL)) {
+ // We have something to receive, let's check if it's completely ready.
+ int size = sizesFifo[buffSlot];
+ char* buff = resources->shared ? (char*)resources->recvMem->ptrsFifo[buffSlot] : localBuff+buffSlot*stepSize;
+ int ready = 1;
if (args->protocol == NCCL_PROTO_LL128) {
- if (args->tail < *recvTail) {
- if (sizesFifo[buffSlot] != -1) {
- int ready = resources->useGdr;
- if (!ready) {
- // When data is in sysmem, we need to wait until all flags are correct since the GPU only
- // called threadfence()
- uint64_t flag = args->tail + 1;
- int nFifoLines = DIVUP(sizesFifo[buffSlot], sizeof(uint64_t)*NCCL_LL128_LINEELEMS);
- volatile uint64_t* lines = (volatile uint64_t*)(localBuff+buffSlot*stepSize);
- ready = 1;
- for (int i=0; i<nFifoLines; i++) {
- if (lines[i*NCCL_LL128_LINEELEMS+NCCL_LL128_DATAELEMS] != flag) { ready = 0; break; }
- }
- }
- if (ready) {
- // Send through network
- NCCLCHECK(ncclNetIsend(resources->netSendComm, localBuff+buffSlot*stepSize, sizesFifo[buffSlot], mhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- sizesFifo[buffSlot] = -1;
- // Make sure size is reset to zero before we update the head.
- __sync_synchronize();
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
- }
+ int ready = resources->useGdr;
+ if (!ready) {
+ // When data is in sysmem, we need to wait until all flags are correct since the GPU only
+ // called threadfence()
+ uint64_t flag = args->transmitted + 1;
+ int nFifoLines = DIVUP(sizesFifo[buffSlot], sizeof(uint64_t)*NCCL_LL128_LINEELEMS);
+ volatile uint64_t* lines = (volatile uint64_t*)buff;
+ ready = 1;
+ for (int i=0; i<nFifoLines; i++) {
+ if (lines[i*NCCL_LL128_LINEELEMS+NCCL_LL128_DATAELEMS] != flag) { ready = 0; break; }
}
}
} else if (args->protocol == NCCL_PROTO_LL) {
- int size = sizesFifo[buffSlot];
- if (size != -1) {
- uint32_t flag = NCCL_LL_FLAG(args->tail + 1);
- int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine));
- size = nFifoLines * sizeof(union ncclLLFifoLine);
- union ncclLLFifoLine* lines = (union ncclLLFifoLine*)(localBuff+buffSlot*stepSize);
- int ready = 1;
- for (int i=0; i<nFifoLines; i++) {
- volatile uint32_t *f1 = &lines[i].flag1;
- volatile uint32_t *f2 = &lines[i].flag2;
- if (f1[0] != flag || f2[0] != flag) { ready = 0; break; }
- }
- if (ready) {
- NCCLCHECK(ncclNetIsend(resources->netSendComm, lines, size, mhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- sizesFifo[buffSlot] = -1;
- // Make sure size is reset to zero before we update the head.
- __sync_synchronize();
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
- }
+ uint32_t flag = NCCL_LL_FLAG(args->transmitted + 1);
+ int nFifoLines = DIVUP(size, sizeof(union ncclLLFifoLine));
+ union ncclLLFifoLine* lines = (union ncclLLFifoLine*)buff;
+ for (int i=0; i<nFifoLines; i++) {
+ volatile uint32_t *f1 = &lines[i].flag1;
+ volatile uint32_t *f2 = &lines[i].flag2;
+ if (f1[0] != flag || f2[0] != flag) { ready = 0; break; }
}
- } else if (args->tail < *recvTail) {
- // Send through network
- if (sizesFifo[buffSlot] != -1) {
- NCCLCHECK(ncclNetIsend(resources->netSendComm, localBuff+buffSlot*stepSize, sizesFifo[buffSlot], mhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- sizesFifo[buffSlot] = -1;
- // Make sure size is reset to zero before we update the head.
- __sync_synchronize();
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
+ }
+ if (ready) {
+ // Data is ready, try to send.
+ NCCLCHECK(ncclNetIsend(resources->netSendComm, buff, size, mhandle, args->requests+buffSlot));
+ if (args->requests[buffSlot] != NULL) {
+ TRACE(NCCL_NET, "sendProxy [%d/%d] Isend (LL) posted, req %p", args->transmitted, buffSlot, args->requests[buffSlot]);
+ sizesFifo[buffSlot] = -1;
+ // Make sure size is reset to zero before we update the head.
+ __sync_synchronize();
+ args->transmitted += args->sliceSteps;
+ args->idle = 0;
+ return ncclSuccess;
}
}
}
- if (args->head < args->tail) {
- int done;
- int buffSlot = args->head%NCCL_STEPS;
- NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, NULL));
- if (done) {
- args->head += args->sliceSteps;
- resources->sendMem->head = args->head;
- args->idle = 0;
+ }
+ // Check whether the network has completed some send operations.
+ if (args->done < args->transmitted) {
+ int done;
+ int buffSlot = args->done%NCCL_STEPS;
+ NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, NULL));
+ if (done) {
+ TRACE(NCCL_NET, "sendProxy [%d/%d] request %p done, size %d", args->done, buffSlot, args->requests[buffSlot]);
+ if (resources->shared) {
+ char* ptr = (char*)resources->recvMem->ptrsFifo[args->done%NCCL_STEPS];
+ NCCLCHECK(ncclProxySharedBuffersFree(args->connector->comm, resources->useGdr, 0, args->channel->id, buffSize, ptr));
+ }
+ args->done += args->sliceSteps;
+
+ if (resources->shared == 0) {
+ resources->sendMem->head = args->done;
+ }
+ args->idle = 0;
+ if (args->done == args->end) {
+ resources->step = args->end;
+ args->state = ncclProxyOpNone;
}
+ return ncclSuccess;
}
}
- if (args->head == args->end) {
- resources->step = args->end;
- args->idle = 0;
- args->state = ncclProxyOpNone;
- }
}
return ncclSuccess;
}
@@ -329,46 +364,88 @@ ncclResult_t netRecvProxy(struct ncclProxyArgs* args) {
if (args->state == ncclProxyOpReady) {
// Round to next multiple of sliceSteps
resources->step = ROUNDUP(resources->step, args->chunkSteps);
- args->head = resources->step;
- args->tail = resources->step;
- args->end = args->head + args->nsteps;
+ args->posted = args->received = args->transmitted = args->done = resources->step;
+ args->end = resources->step + args->nsteps;
args->state = ncclProxyOpProgress;
}
+ args->idle = 1;
if (args->state == ncclProxyOpProgress) {
- args->idle = 1;
int p = args->protocol;
int stepSize = args->connector->comm->buffSizes[p] / NCCL_STEPS;
char* localBuff = args->connector->conn.buffs[p];
void* mhandle = *(resources->mhandlesProto[p]);
- if (args->head < args->end) {
- volatile uint64_t* sendHead = &resources->sendMem->head;
- if ((args->tail < args->head + NCCL_STEPS) && (args->tail < *sendHead + NCCL_STEPS) && (args->tail < args->end)) {
- int buffSlot = args->tail%NCCL_STEPS;
- int sliceSize = stepSize * args->sliceSteps;
- NCCLCHECK(ncclNetIrecv(resources->netRecvComm, localBuff+buffSlot*stepSize, sliceSize, mhandle, args->requests+buffSlot));
- if (args->requests[buffSlot] != NULL) {
- args->tail += args->sliceSteps;
- args->idle = 0;
- }
+ int buffSize = stepSize*args->sliceSteps;
+ if (resources->shared) buffSize /= SENDRECV_SLICEFACTOR;
+ if (args->recvbytes < buffSize) buffSize = args->recvbytes;
+ if ((args->posted < args->done + NCCL_STEPS) && (args->posted < args->end)) {
+ int buffSlot = args->posted%NCCL_STEPS;
+ char* ptr;
+ if (resources->shared) {
+ NCCLCHECK(ncclProxySharedBuffersAlloc(args->connector->comm, resources->useGdr, 1, args->channel->id, buffSize, &ptr));
+ if (ptr == NULL) return ncclInternalError;
+ volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
+ ptrsFifo[buffSlot] = ptr;
+ } else {
+ ptr = localBuff+buffSlot*stepSize;
}
- if (args->tail > args->head) {
- int buffSlot = args->head%NCCL_STEPS;
- int done, size;
- NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, &size));
- if (done) {
- args->head += args->sliceSteps;
- if (args->protocol == NCCL_PROTO_SIMPLE) {
- if (resources->useGdr) NCCLCHECK(ncclNetFlush(resources->netRecvComm, localBuff+buffSlot*stepSize, size, mhandle));
- resources->recvMem->tail = args->head;
- }
- args->idle = 0;
+ NCCLCHECK(ncclNetIrecv(resources->netRecvComm, ptr, buffSize, mhandle, args->requests+buffSlot));
+ if (args->requests[buffSlot] != NULL) {
+ TRACE(NCCL_NET, "recvProxy [%d/%d] posted recv request %p", args->posted, buffSlot, args->requests[buffSlot]);
+ args->posted += args->sliceSteps;
+ args->idle = 0;
+ return ncclSuccess;
+ } else if (resources->shared) {
+ NCCLCHECK(ncclProxySharedBuffersFree(args->connector->comm, resources->useGdr, 1, args->channel->id, buffSize, ptr));
+ }
+ }
+ if (args->posted > args->received) {
+ int buffSlot = args->received%NCCL_STEPS;
+ int done, size;
+ NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, &size));
+ if (done) {
+ args->received += args->sliceSteps;
+ if (size > 0 && args->protocol == NCCL_PROTO_SIMPLE && resources->useGdr) {
+ // Don't pass data to the GPU yet, flush first.
+ volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
+ char* ptr = resources->shared ? (char*)(ptrsFifo[buffSlot]) : localBuff+buffSlot*stepSize;
+ NCCLCHECK(ncclNetIflush(resources->netRecvComm, ptr, size, mhandle, args->requests+buffSlot));
+ } else {
+ args->requests[buffSlot] = NULL;
}
+ args->idle = 0;
+ return ncclSuccess;
}
}
- if (args->head == args->end) {
- resources->step = args->end;
- args->idle = 0;
- args->state = ncclProxyOpNone;
+ if (args->received > args->transmitted) {
+ // Progress flush operations
+ int buffSlot = args->transmitted%NCCL_STEPS;
+ int done = 1;
+ if (args->requests[buffSlot]) NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, NULL));
+ if (done) {
+ args->transmitted += args->sliceSteps;
+ __sync_synchronize();
+ resources->recvMem->tail = args->transmitted;
+ args->idle = 0;
+ return ncclSuccess;
+ }
+ }
+ if (args->transmitted > args->done) {
+ volatile uint64_t* sendHead = &resources->sendMem->head;
+ uint64_t done = *sendHead;
+ while (done > args->done &&
+ // LL and LL128 can acknowledge 0-bytes send before they even happen. Don't go past what we transmitted.
+ args->transmitted > args->done) {
+ if (resources->shared) {
+ char* ptr = (char*)resources->recvMem->ptrsFifo[args->done%NCCL_STEPS];
+ NCCLCHECK(ncclProxySharedBuffersFree(args->connector->comm, resources->useGdr, 1, args->channel->id, buffSize, ptr));
+ }
+ args->done += args->sliceSteps;
+ args->idle = 0;
+ if (args->done == args->end) {
+ resources->step = args->end;
+ args->state = ncclProxyOpNone;
+ }
+ }
}
}
return ncclSuccess;
diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc
index 97eca9f..f8242a5 100644
--- a/src/transport/net_ib.cc
+++ b/src/transport/net_ib.cc
@@ -24,9 +24,8 @@
#include "ibvwrap.h"
#define USE_RDMA_WRITE 1
-#define USE_RDMA_SEND_INLINE 0
#define MAXNAMESIZE 64
-static char ncclIbIfName[MAX_IF_NAME_SIZE];
+static char ncclIbIfName[MAX_IF_NAME_SIZE+1];
static union socketAddress ncclIbIfAddr;
static int ncclNIbDevs = -1;
@@ -57,6 +56,8 @@ pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0);
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 14);
NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7);
+NCCL_PARAM(IbPkey, "IB_PKEY", 0);
+NCCL_PARAM(IbUseInline, "IB_USE_INLINE", 0);
NCCL_PARAM(IbSl, "IB_SL", 0);
NCCL_PARAM(IbTc, "IB_TC", 0);
NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192);
@@ -199,7 +200,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
ncclIbDevs[d].port, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE");
}
line[1023] = '\0';
- char addrline[1024];
+ char addrline[SOCKET_NAME_MAXLEN+1];
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s ; OOB %s:%s", line, ncclIbIfName, socketToString(&ncclIbIfAddr.sa, addrline));
}
pthread_mutex_unlock(&ncclIbLock);
@@ -246,7 +247,7 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) {
return ncclSuccess;
}
-#define MAX_REQUESTS 128
+#define MAX_REQUESTS NCCL_NET_MAX_REQUESTS
struct ncclIbQpInfo {
uint32_t lid;
@@ -267,18 +268,19 @@ struct ncclIbHandle {
union socketAddress connectAddr;
};
-struct ncclIbVerbs {
- struct ibv_pd* pd;
- struct ibv_cq* cq;
-};
-
struct ncclIbRequest {
int used;
int type;
struct ncclIbVerbs* verbs;
- int done;
+ int events;
int size;
- int free;
+};
+
+struct ncclIbVerbs {
+ struct ibv_pd* pd;
+ struct ibv_cq* cq;
+ uint64_t pad[2];
+ struct ncclIbRequest reqs[MAX_REQUESTS];
};
struct ncclIbListenComm {
@@ -292,18 +294,23 @@ struct ncclIbSendFifo {
uint32_t seq;
uint32_t rkey;
uint32_t ready;
+ uint64_t pad[1]; // Pad FIFO element size to be 32-bytes
};
struct ncclIbSendComm {
struct ncclIbVerbs verbs;
struct ncclIbSendFifo fifo[MAX_REQUESTS];
- struct ncclIbRequest reqs[MAX_REQUESTS];
uint32_t fifoHead;
int fd;
int ready;
struct ibv_qp* qp;
struct ibv_mr* fifoMr;
};
+// The SendFifo needs to be 32-byte aligned and each element needs
+// to be a 32-byte multiple, so that an entry does not get split and
+// written out of order when IB Relaxed Ordering is enabled
+static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
+static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples");
struct ncclIbGpuFlush {
int enabled;
@@ -326,16 +333,17 @@ struct ncclIbRemFifo {
struct ncclIbRecvComm {
struct ncclIbVerbs verbs;
struct ncclIbRemFifo remFifo;
- struct ncclIbRequest reqs[MAX_REQUESTS];
int fd;
int ready;
struct ibv_qp* qp;
struct ncclIbGpuFlush gpuFlush;
};
+static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
ncclResult_t ncclIbInitVerbs(ibv_context* ctx, struct ncclIbVerbs* verbs) {
NCCLCHECK(wrap_ibv_alloc_pd(&verbs->pd, ctx));
- NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, MAX_REQUESTS, NULL, NULL, 0));
+ // Recv requests can generate 2 completions (one for the post FIFO, one for the Recv).
+ NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS, NULL, NULL, 0));
return ncclSuccess;
}
@@ -351,17 +359,17 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce
qpInitAttr.send_cq = verbs->cq;
qpInitAttr.recv_cq = verbs->cq;
qpInitAttr.qp_type = IBV_QPT_RC;
- // We might send 2 requests per send (RDMA_WRITE+RDMA_WRITE_WITH_IMM)
+ // We might send 2 messages per send (RDMA and RDMA_WITH_IMM)
qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS;
qpInitAttr.cap.max_recv_wr = MAX_REQUESTS;
qpInitAttr.cap.max_send_sge = 1;
qpInitAttr.cap.max_recv_sge = 1;
- qpInitAttr.cap.max_inline_data = 0;
+ qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0;
NCCLCHECK(wrap_ibv_create_qp(qp, verbs->pd, &qpInitAttr));
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_INIT;
- qpAttr.pkey_index = 0;
+ qpAttr.pkey_index = ncclParamIbPkey();
qpAttr.port_num = ib_port;
qpAttr.qp_access_flags = access_flags;
NCCLCHECK(wrap_ibv_modify_qp(*qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS));
@@ -476,7 +484,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
socklen_t socklen = sizeof(struct sockaddr_in);
SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
struct ncclIbQpInfo remQpInfo;
- NCCLCHECK(socketReceive(rComm->fd, &remQpInfo, sizeof(remQpInfo)));
+ NCCLCHECK(socketRecv(rComm->fd, &remQpInfo, sizeof(remQpInfo)));
// IB setup
ibv_context* ctx = ncclIbDevs[lComm->dev].context;
@@ -504,14 +512,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
NCCLCHECK(wrap_ibv_reg_mr(&rComm->remFifo.mr, rComm->verbs.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ));
rComm->remFifo.sge.length = sizeof(struct ncclIbSendFifo);
rComm->remFifo.sge.lkey = rComm->remFifo.mr->lkey;
-
-#if USE_RDMA_SEND_INLINE
- // Determine whether the remFifo element data can be sent INLINE
- struct ibv_qp_attr attr;
- struct ibv_qp_init_attr init_attr;
- NCCLCHECK(wrap_ibv_query_qp(qp, &attr, IBV_QP_CAP, &init_attr));
- if (init_attr.cap.max_inline_data >= rComm->remFifo.sge.length) rComm->remFifo.flags = IBV_SEND_INLINE;
-#endif
+ if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE;
// Allocate Flush dummy buffer for GPU Direct RDMA
rComm->gpuFlush.enabled = (ncclIbGdrSupport(lComm->dev) == 0) && (ncclParamIbGdrFlushDisable() == 0) ? 1 : 0;
@@ -548,16 +549,15 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) {
return ncclSuccess;
}
-ncclResult_t ncclIbGetRequest(struct ncclIbRequest* reqs, struct ncclIbRequest** req) {
+ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** req) {
for (int i=0; i<MAX_REQUESTS; i++) {
- struct ncclIbRequest* r = reqs+i;
+ struct ncclIbRequest* r = verbs->reqs+i;
if (r->used == 0) {
r->used = 1;
r->type = 0;
- r->verbs = NULL;
- r->done = 0;
+ r->verbs = verbs;
+ r->events = 1;
r->size = -1;
- r->free = 0;
*req = r;
return ncclSuccess;
}
@@ -566,6 +566,10 @@ ncclResult_t ncclIbGetRequest(struct ncclIbRequest* reqs, struct ncclIbRequest**
*req = NULL;
return ncclInternalError;
}
+ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) {
+ r->used = 0;
+ return ncclSuccess;
+}
ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
struct ncclIbQpInfo remQpInfo;
@@ -580,7 +584,6 @@ ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) {
NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo));
NCCLCHECK(ncclIbRtsQp(qp));
comm->ready = 1;
-
// Block until this is done. It *should* not block indefinitely.
NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int)));
@@ -601,6 +604,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size);
#define REG_ALIGN (4096)
ncclResult_t ncclIbRegMr(void* comm, void* data, int size, int type, void** mhandle) {
+ static_assert(offsetof(struct ncclIbSendComm, verbs) == offsetof(struct ncclIbRecvComm, verbs), "Send and recv comms must have verbs at the same offset");
struct ncclIbVerbs* verbs = (struct ncclIbVerbs*)comm;
uint64_t addr = (uint64_t)data;
assert(size > 0);
@@ -634,8 +638,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
if (*readyPtr == 0) { *request = NULL; return ncclSuccess; }
struct ncclIbRequest* req;
- NCCLCHECK(ncclIbGetRequest(comm->reqs, &req));
- req->verbs = &comm->verbs;
+ NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
req->size = size;
struct ibv_send_wr wr;
@@ -651,23 +654,24 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
wr.sg_list = &sge;
wr.num_sge = 1;
}
+#if USE_RDMA_WRITE == 0
wr.opcode = IBV_WR_SEND;
wr.send_flags = IBV_SEND_SIGNALED;
-
- int useAr = 0;
- if (size > ncclParamIbArThreshold()) {
- useAr = 1;
- }
-#if USE_RDMA_WRITE
+#else
__sync_synchronize(); // order the readyPtr load against rkey load below
// Sanity checks to catch user collective call count/size mismatches
// plus any potential programming errors
- if (size > slot->size || slot->size <= 0 || slot->addr == 0 || slot->rkey == 0 || slot->seq != comm->fifoHead) {
+ if (size > slot->size || slot->size < 0 || slot->addr == 0 || slot->rkey == 0 || slot->seq != comm->fifoHead) {
WARN("NET/IB : collective mismatch error local size %d remote %d addr %lx rkey %x seq %x/%x",
size, slot->size, slot->addr, slot->rkey, slot->seq, comm->fifoHead);
return ncclInternalError;
}
+ int useAr = 0;
+ if (size > ncclParamIbArThreshold()) {
+ useAr = 1;
+ }
wr.opcode = useAr ? IBV_WR_RDMA_WRITE : IBV_WR_RDMA_WRITE_WITH_IMM;
+ wr.send_flags = useAr ? 0 : IBV_SEND_SIGNALED;
wr.wr.rdma.remote_addr = slot->addr;
wr.wr.rdma.rkey = slot->rkey;
wr.imm_data = size; // Send the message size via imm_data
@@ -691,7 +695,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.sg_list = NULL;
wr.num_sge = 0;
- wr.send_flags &= ~IBV_SEND_SIGNALED;
+ wr.send_flags |= IBV_SEND_SIGNALED;
NCCLCHECK(wrap_ibv_post_send(comm->qp, &wr, &bad_wr));
}
#endif
@@ -699,28 +703,51 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, void* mhandle, vo
return ncclSuccess;
}
-ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t addr, int size) {
+ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t addr, int size, struct ncclIbRequest* req) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- struct ncclIbRequest* req;
- NCCLCHECK(ncclIbGetRequest(comm->reqs, &req));
- req->verbs = &comm->verbs;
- req->free = 1; // Not a user req ; free as soon as it is complete.
- wr.wr_id = (uint64_t)req;
- struct ncclIbSendFifo* localElem = comm->remFifo.elems + (comm->remFifo.tail % MAX_REQUESTS);
+ int slot = comm->remFifo.tail%MAX_REQUESTS;
+ struct ncclIbSendFifo* localElem = comm->remFifo.elems + slot;
localElem->addr = addr;
localElem->rkey = rkey;
localElem->ready = 1;
localElem->size = size; // Sanity/Debugging
localElem->seq = comm->remFifo.tail; // Sanity/Debugging
- wr.wr.rdma.remote_addr = comm->remFifo.addr + (comm->remFifo.tail % MAX_REQUESTS) * sizeof(struct ncclIbSendFifo);
+ wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*sizeof(struct ncclIbSendFifo);
wr.wr.rdma.rkey = comm->remFifo.rkey;
comm->remFifo.sge.addr = (uint64_t)localElem;
wr.sg_list = &comm->remFifo.sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE;
- wr.send_flags = IBV_SEND_SIGNALED | comm->remFifo.flags; // IBV_SEND_INLINE
+ wr.send_flags = comm->remFifo.flags; // IBV_SEND_INLINE
+
+ // We need to occasionally post a request with the IBV_SEND_SIGNALED flag, otherwise
+ // the send queue will never empty.
+ //
+ // From https://www.rdmamojo.com/2014/06/30/working-unsignaled-completions/
+ // "How to use Unsignaled Completion?" / "Gotchas and Pitfalls"
+ // All posted Send Requested, Signaled and Unsignaled, are considered outstanding until
+ // a Work Completion that they, or Send Requests that were posted after them, was polled
+ // from the Completion Queue associated with the Send Queue. This means if one works with
+ // a Queue Pair that was configured to work with Unsignaled Completions, he must make
+ // sure that occasionally (before the Send Queue is full with outstanding Send Requests)
+ // a Send Request that generate Work Completion will be posted.
+ //
+ // Not following this rule may lead to a case that the Send Queue is full with Send
+ // Requests that won't generate Work Completion:
+ //
+ // - The Send Queue is full, so no new Send Requests can be posted to it
+ // - The Send Queue can't be emptied, since no Work Completion can be generated anymore
+ // (the reason is that no Work Completion, that can generate Work Completion that
+ // polling it will empty the Send Queue, can be posted)
+ // - The status of all posted Send Request is considered unknown
+ //
+ if (slot == 0) {
+ wr.send_flags |= IBV_SEND_SIGNALED;
+ wr.wr_id = (uint64_t)req;
+ req->events++;
+ }
struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->qp, &wr, &bad_wr));
@@ -737,8 +764,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, void* mhandle, vo
struct ibv_mr* mr = (struct ibv_mr*)mhandle;
struct ncclIbRequest* req;
- NCCLCHECK(ncclIbGetRequest(comm->reqs, &req));
- req->verbs = &comm->verbs;
+ NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
req->size = size;
struct ibv_recv_wr wr;
@@ -760,17 +786,16 @@ ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, void* mhandle, vo
*request = req;
// Post to FIFO to notify sender
- NCCLCHECK(ncclIbPostFifo(comm, mr->rkey, (uint64_t)data, size));
+ NCCLCHECK(ncclIbPostFifo(comm, mr->rkey, (uint64_t)data, size, req));
return ncclSuccess;
}
-ncclResult_t ncclIbFlush(void* recvComm, void* data, int size, void* mhandle) {
+ncclResult_t ncclIbIflush(void* recvComm, void* data, int size, void* mhandle, void** request) {
struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
if (comm->gpuFlush.enabled == 0 || size == 0) return ncclSuccess;
struct ncclIbRequest* req;
- NCCLCHECK(ncclIbGetRequest(comm->reqs, &req));
- req->verbs = &comm->verbs;
+ NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req));
struct ibv_mr* mr = (struct ibv_mr*)mhandle;
struct ibv_send_wr wr;
@@ -787,11 +812,7 @@ ncclResult_t ncclIbFlush(void* recvComm, void* data, int size, void* mhandle) {
struct ibv_send_wr* bad_wr;
NCCLCHECK(wrap_ibv_post_send(comm->gpuFlush.qp, &wr, &bad_wr));
- int done = 0;
- while (done == 0) {
- NCCLCHECK((ncclResult_t)ncclIbTest(req, &done, NULL));
- }
-
+ *request = req;
return ncclSuccess;
}
@@ -800,10 +821,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
*done = 0;
while (1) {
- if (r->done == 1) {
+ if (r->events == 0) {
*done = 1;
if (size) *size = r->size;
- r->used = 0;
+ NCCLCHECK(ncclIbFreeRequest(r));
return ncclSuccess;
}
@@ -828,11 +849,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* size) {
doneReq->size = wc->imm_data;
#endif
}
- doneReq->done = 1;
- if (doneReq->free == 1) {
- // This is an internal (FIFO post) req. Free it immediately.
- doneReq->used = 0;
- }
+ doneReq->events--;
}
}
}
@@ -887,7 +904,7 @@ ncclNet_t ncclNetIb = {
ncclIbDeregMr,
ncclIbIsend,
ncclIbIrecv,
- ncclIbFlush,
+ ncclIbIflush,
ncclIbTest,
ncclIbCloseSend,
ncclIbCloseRecv,
diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc
index 5bc22c3..272d8cd 100644
--- a/src/transport/net_socket.cc
+++ b/src/transport/net_socket.cc
@@ -48,17 +48,19 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
WARN("NET/Socket : no interface found");
return ncclInternalError;
} else {
- char line[1024];
- char addrline[1024];
+ #define MAX_LINE_LEN (2047)
+ char line[MAX_LINE_LEN+1];
+ char addrline[SOCKET_NAME_MAXLEN+1];
line[0] = '\0';
+ addrline[SOCKET_NAME_MAXLEN] = '\0';
for (int i=0; i<ncclNetIfs; i++) {
strcpy(ncclSocketDevs[i].devName, names+i*MAX_IF_NAME_SIZE);
memcpy(&ncclSocketDevs[i].addr, addrs+i, sizeof(union socketAddress));
NCCLCHECK(ncclSocketGetPciPath(ncclSocketDevs[i].devName, &ncclSocketDevs[i].pciPath));
- snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
+ snprintf(line+strlen(line), MAX_LINE_LEN-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
socketToString(&addrs[i].sa, addrline));
}
- line[1023] = '\0';
+ line[MAX_LINE_LEN] = '\0';
INFO(NCCL_INIT|NCCL_NET,"NET/Socket : Using%s", line);
}
}
@@ -112,8 +114,7 @@ ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
#define MAX_SOCKETS 64
#define MAX_THREADS 16
-#define MAX_REQUESTS 128
-#define MAX_QUEUE_LEN MAX_REQUESTS
+#define MAX_REQUESTS NCCL_NET_MAX_REQUESTS
#define MIN_CHUNKSIZE (64*1024)
NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
@@ -149,6 +150,7 @@ struct ncclSocketRequest {
struct ncclSocketTaskQueue {
int next;
+ int len;
struct ncclSocketTask* tasks;
};
@@ -188,7 +190,7 @@ void* persistentSocketThread(void *args_) {
while (1) {
int idle = 1;
int mark = myQueue->next; // mark newest task seen
- for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) {
+ for (int i=0; i<myQueue->len; i+=nSocksPerThread) {
int repeat;
do {
repeat = 0;
@@ -363,7 +365,11 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
struct ncclSocketTaskQueue* queue = &res->threadTaskQueue;
// create helper threads and prepare per-thread task queue
if (queue->tasks == NULL) {
- NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN));
+ // each request can be divided up to nSocks tasks, and
+ // these tasks are distributed to nThreads threads,
+ // we need to make sure each thread queue has enough slots for MAX_REQUESTS
+ queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
+ NCCLCHECK(ncclCalloc(&queue->tasks, queue->len));
queue->next = 0;
res->comm = comm;
pthread_mutex_init(&res->threadLock, NULL);
@@ -382,7 +388,7 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
r->used = 1;
*req = r;
pthread_mutex_lock(&res->threadLock);
- queue->next = (queue->next+1)%MAX_QUEUE_LEN;
+ queue->next = (queue->next+1)%queue->len;
res->state = start;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
@@ -420,6 +426,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
// divide into subtasks
int chunkOffset = 0, i = 0;
if (r->comm->nSocks > 0) {
+ // each request can be divided up to nSocks tasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
while (chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size-chunkOffset);
@@ -477,7 +484,7 @@ ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle
return ncclSuccess;
}
-ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle) {
+ncclResult_t ncclSocketIflush(void* recvComm, void* data, int size, void* mhandle, void** request) {
// We don't support CUDA pointers, so we don't need a flush operation
return ncclInternalError;
}
@@ -526,7 +533,7 @@ ncclNet_t ncclNetSocket = {
ncclSocketDeregMr,
ncclSocketIsend,
ncclSocketIrecv,
- ncclSocketFlush,
+ ncclSocketIflush,
ncclSocketTest,
ncclSocketClose,
ncclSocketClose,
diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc
index 2cbe390..e05a433 100644
--- a/src/transport/p2p.cc
+++ b/src/transport/p2p.cc
@@ -7,24 +7,29 @@
#include "comm.h"
#include "graph.h"
#include "utils.h"
+#include "bootstrap.h"
struct p2pConnectInfo {
- int direct;
+ int rank;
int read;
- union {
- void* directPtr;
- cudaIpcMemHandle_t devIpc;
- };
+ void* directPtr;
+ cudaIpcMemHandle_t devIpc;
};
struct p2pSendResources {
struct ncclSendMem* devMem;
void* ipcPtr;
+ int remoteId;
+ int memRank;
+ void* bootstrap;
};
struct p2pRecvResources {
struct ncclRecvMem* devMem;
void* ipcPtr;
+ int remoteId;
+ int memRank;
+ void* bootstrap;
};
#include <sys/types.h>
@@ -55,9 +60,10 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
}
// Check topology / p2p level.
- int read;
- NCCLCHECK(ncclTopoCheckP2p(topo, info1->busId, info2->busId, ret, &read));
+ int intermediateRank;
+ NCCLCHECK(ncclTopoCheckP2p(topo, info1->busId, info2->busId, ret, NULL, &intermediateRank));
if (*ret == 0) return ncclSuccess;
+ if (intermediateRank != -1) return ncclSuccess;
// Convert the peer's busId into a local cudaDev index (cf. CUDA_VISIBLE_DEVICES)
int cudaDev1 = busIdToCudaDev(info1->busId);
@@ -100,145 +106,134 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
// Setting this to non zero causes P2P to use Reads rather than Writes
NCCL_PARAM(P2pReadEnable, "P2P_READ_ENABLE", -2);
-static int p2pUseRead(struct ncclTopoSystem* topo, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
- int readEnable = ncclParamP2pReadEnable();
- if (readEnable != -2) return readEnable;
-
- int p2p, read;
+static ncclResult_t p2pGetInfo(struct ncclTopoSystem* topo, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* read, int* intermediateRank) {
+ int p2p;
// Queries the topology to see if the GPUs are Ampere and
// connected via NVLink, if so we enable P2P Read by default
- NCCLCHECK(ncclTopoCheckP2p(topo, info1->busId, info2->busId, &p2p, &read));
+ NCCLCHECK(ncclTopoCheckP2p(topo, info1->busId, info2->busId, &p2p, read, intermediateRank));
+
+ int readEnable = ncclParamP2pReadEnable();
+ if (readEnable != -2) *read = readEnable;
+ return ncclSuccess;
+}
- return read;
+static ncclResult_t p2pMap(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct p2pConnectInfo* p2pInfo, void** devMem, void** ipcPtr) {
+ if (myInfo->pidHash == peerInfo->pidHash) {
+ if (peerInfo->cudaDev != myInfo->cudaDev) {
+ // Enable P2P access
+ cudaError_t err = cudaDeviceEnablePeerAccess(peerInfo->cudaDev, 0);
+ if (err == cudaErrorPeerAccessAlreadyEnabled) {
+ cudaGetLastError();
+ } else if (err != cudaSuccess) {
+ WARN("failed to peer with device %d(=%lx): %d %s",
+ peerInfo->cudaDev, peerInfo->busId, err, cudaGetErrorString(err));
+ return ncclInternalError;
+ }
+ }
+ *devMem = p2pInfo->directPtr;
+ *ipcPtr = NULL;
+ } else {
+ CUDACHECK(cudaIpcOpenMemHandle(devMem, p2pInfo->devIpc, cudaIpcMemLazyEnablePeerAccess));
+ *ipcPtr = *devMem;
+ }
+ return ncclSuccess;
}
/* Send: Create and return connect structures for this peer to connect to me */
-ncclResult_t p2pSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
+ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
struct p2pSendResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
send->transportResources = resources;
- int useRead = p2pUseRead(topo, myInfo, peerInfo);
+ int useRead, intermediateRank;
+ NCCLCHECK(p2pGetInfo(comm->topo, myInfo, peerInfo, &useRead, &intermediateRank));
int sendSize = sizeof(struct ncclSendMem);
// For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure
if (useRead) sendSize += send->comm->buffSizes[NCCL_PROTO_SIMPLE];
ALIGN_SIZE(sendSize, CUDA_IPC_MIN);
- NCCLCHECK(ncclCudaCalloc((char**)&resources->devMem, sendSize));
struct p2pConnectInfo info;
info.read = useRead;
const char* useReadStr = info.read ? "/read" : "";
- if (myInfo->pidHash == peerInfo->pidHash) {
- info.direct = 1;
- info.directPtr = resources->devMem;
- if (myInfo->cudaDev == peerInfo->cudaDev) {
- INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%d] -> %d[%d] via P2P/common device%s",
- channelId, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev, useReadStr);
- return ncclInternalError;
+
+ resources->remoteId = -1;
+ resources->bootstrap = comm->bootstrap;
+ if (intermediateRank == -1) {
+ NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, sendSize));
+ info.rank = myInfo->rank;
+ if (myInfo->pidHash == peerInfo->pidHash) {
+ if (useRead == 0) send->conn.direct |= NCCL_DIRECT_GPU;
+ INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s",
+ channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
} else {
- // Enable P2P access
- cudaError_t err = cudaDeviceEnablePeerAccess(peerInfo->cudaDev, 0);
- if (err == cudaErrorPeerAccessAlreadyEnabled) {
- cudaGetLastError();
- } else if (err != cudaSuccess) {
- WARN("failed to peer with device %d(=%lx): %d %s",
- peerInfo->cudaDev, peerInfo->busId, err, cudaGetErrorString(err));
- return ncclInternalError;
- }
- INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s",
+ CUDACHECK(cudaIpcGetMemHandle(&info.devIpc, info.directPtr));
+ INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s",
channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
}
} else {
- // Convert the peer's busId into a local cudaDev index (cf. CUDA_VISIBLE_DEVICES)
- int peerCudaDev = busIdToCudaDev(peerInfo->busId);
- info.direct = 0;
- // Map IPC and enable P2P access
- cudaError_t err = cudaIpcGetMemHandle(&info.devIpc, (void*)resources->devMem);
- if (err != cudaSuccess) {
- WARN("rank %d failed to get CUDA IPC handle to device %d(=%lx) : %d %s",
- myInfo->rank, peerCudaDev, peerInfo->busId, err, cudaGetErrorString(err));
- return ncclInternalError;
- }
- INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s",
- channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
- //TRACE_DUMP_IPC(&info.devIpc);
+ NCCLCHECK(bootstrapRemAlloc(sendSize, intermediateRank, resources->bootstrap, &resources->remoteId, &info.devIpc, &info.directPtr));
+ info.rank = intermediateRank;
+ INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/indirect/%d[%lx]%s",
+ channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, intermediateRank,
+ comm->peerInfo[intermediateRank].busId, useReadStr);
}
+ resources->memRank = info.rank;
+
+ NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info.rank, &info, (void**)&resources->devMem, &resources->ipcPtr));
+
static_assert(sizeof(struct p2pConnectInfo) <= sizeof(struct ncclConnect), "p2p Connect Info is too big");
memcpy(connectInfo, &info, sizeof(struct p2pConnectInfo));
return ncclSuccess;
}
/* Create and return connect structures for this peer to connect to me */
-ncclResult_t p2pRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
+ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo,
struct ncclConnect* connectInfo, struct ncclConnector * recv, int channelId) {
struct p2pRecvResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources;
- int useRead = p2pUseRead(topo, myInfo, peerInfo);
+ int useRead, intermediateRank;
+ NCCLCHECK(p2pGetInfo(comm->topo, myInfo, peerInfo, &useRead, &intermediateRank));
int recvSize = offsetof(struct ncclRecvMem, buff);
// For P2P Read the SIMPLE buffer is tagged on the end of the ncclSendMem structure
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) if (!(useRead && p == NCCL_PROTO_SIMPLE)) recvSize += recv->comm->buffSizes[p];
ALIGN_SIZE(recvSize, CUDA_IPC_MIN);
- NCCLCHECK(ncclCudaCalloc((char**)&resources->devMem, recvSize));
struct p2pConnectInfo info;
info.read = useRead;
- if (myInfo->pidHash == peerInfo->pidHash) {
- info.direct = 1;
- info.directPtr = resources->devMem;
- if (myInfo->cudaDev == peerInfo->cudaDev) {
- TRACE(NCCL_INIT|NCCL_P2P,"%d <- %d via P2P/common device", myInfo->rank, peerInfo->rank);
+
+ resources->remoteId = -1;
+ resources->bootstrap = comm->bootstrap;
+ if (intermediateRank == -1) {
+ NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, recvSize));
+ info.rank = myInfo->rank;
+ if (myInfo->pidHash == peerInfo->pidHash) {
+ if (useRead == 0) recv->conn.direct |= NCCL_DIRECT_GPU;
} else {
- // Enable P2P access
- cudaError_t err = cudaDeviceEnablePeerAccess(peerInfo->cudaDev, 0);
- if (err == cudaErrorPeerAccessAlreadyEnabled) {
- cudaGetLastError();
- } else if (err != cudaSuccess) {
- WARN("failed to peer with device %d(=%lx): %d %s",
- peerInfo->cudaDev, peerInfo->busId, err, cudaGetErrorString(err));
- return ncclInternalError;
- }
- TRACE(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] <- %d[%lx] via P2P/direct pointer", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId);
+ CUDACHECK(cudaIpcGetMemHandle(&info.devIpc, info.directPtr));
}
} else {
- // Convert the peer's busId into a local cudaDev index (cf. CUDA_VISIBLE_DEVICES)
- int peerCudaDev = busIdToCudaDev(peerInfo->busId);
- info.direct = 0;
- // Map IPC and enable P2P access
- cudaError_t err = cudaIpcGetMemHandle(&info.devIpc, (void*)resources->devMem);
- if (err != cudaSuccess) {
- WARN("rank %d failed to get CUDA IPC handle to device %d(=%lx) : %d %s",
- myInfo->rank, peerCudaDev, peerInfo->busId, err, cudaGetErrorString(err));
- return ncclInternalError;
- }
- TRACE(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] <- %d[%lx] via P2P/IPC", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId);
- //TRACE_DUMP_IPC(&info.devIpc);
+ NCCLCHECK(bootstrapRemAlloc(recvSize, intermediateRank, resources->bootstrap, &resources->remoteId, &info.devIpc, &info.directPtr));
+ info.rank = intermediateRank;
}
+ resources->memRank = info.rank;
+
+ NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info.rank, &info, (void**)&resources->devMem, &resources->ipcPtr));
+
static_assert(sizeof(struct p2pConnectInfo) <= sizeof(struct ncclConnect), "p2p Connect Info is too big");
memcpy(connectInfo, &info, sizeof(struct p2pConnectInfo));
return ncclSuccess;
}
/* Connect/Send to this peer */
-static ncclResult_t p2pSendConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
+static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
struct p2pSendResources* resources = (struct p2pSendResources*)send->transportResources;
struct ncclRecvMem* remDevMem;
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
- if (info->direct) {
- remDevMem = (struct ncclRecvMem*)(info->directPtr);
- if (info->read == 0) send->conn.direct |= NCCL_DIRECT_GPU;
- } else {
- //TRACE_DUMP_IPC(&info->devIpc);
- cudaError_t err = cudaIpcOpenMemHandle(&resources->ipcPtr, info->devIpc, cudaIpcMemLazyEnablePeerAccess);
- remDevMem = (struct ncclRecvMem*)resources->ipcPtr;
- if (err != cudaSuccess) {
- WARN("failed to open CUDA IPC handle : %d %s",
- err, cudaGetErrorString(err));
- return ncclUnhandledCudaError;
- }
- }
+
+ NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
int offset = 0;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@@ -257,26 +252,12 @@ static ncclResult_t p2pSendConnect(struct ncclConnect* connectInfo, int nranks,
}
/* Connect/Recv from this peer */
-ncclResult_t p2pRecvConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
+ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
struct p2pRecvResources* resources = (struct p2pRecvResources*)recv->transportResources;
struct ncclSendMem* remDevMem;
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
- if (info->direct) {
- remDevMem = (struct ncclSendMem*)(info->directPtr);
- if (info->read == 0) {
- recv->conn.direct |= NCCL_DIRECT_GPU;
- recv->conn.ptrExchange = &remDevMem->ptrExchange;
- }
- } else {
- //TRACE_DUMP_IPC(&info->devIpc);
- cudaError_t err = cudaIpcOpenMemHandle(&resources->ipcPtr, info->devIpc, cudaIpcMemLazyEnablePeerAccess);
- remDevMem = (struct ncclSendMem*)resources->ipcPtr;
- if (err != cudaSuccess) {
- WARN("failed to open CUDA IPC handle : %d %s",
- err, cudaGetErrorString(err));
- return ncclUnhandledCudaError;
- }
- }
+
+ NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
int offset = 0;
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@@ -290,6 +271,7 @@ ncclResult_t p2pRecvConnect(struct ncclConnect* connectInfo, int nranks, int ran
}
recv->conn.tail = &resources->devMem->tail;
recv->conn.head = &remDevMem->head;
+ recv->conn.ptrExchange = &remDevMem->ptrExchange;
return ncclSuccess;
}
@@ -297,6 +279,10 @@ ncclResult_t p2pSendFree(void* resources) {
struct p2pSendResources* sendRes = (struct p2pSendResources*)resources;
if (sendRes->ipcPtr)
CUDACHECK(cudaIpcCloseMemHandle(sendRes->ipcPtr));
+ if (sendRes->remoteId != -1) {
+ NCCLCHECK(bootstrapRemFree(sendRes->remoteId, sendRes->memRank, sendRes->bootstrap));
+ sendRes->devMem = NULL;
+ }
CUDACHECK(cudaFree(sendRes->devMem));
free(sendRes);
return ncclSuccess;
@@ -306,6 +292,10 @@ ncclResult_t p2pRecvFree(void* resources) {
struct p2pRecvResources* recvRes = (struct p2pRecvResources*)resources;
if (recvRes->ipcPtr)
CUDACHECK(cudaIpcCloseMemHandle(recvRes->ipcPtr));
+ if (recvRes->remoteId != -1) {
+ NCCLCHECK(bootstrapRemFree(recvRes->remoteId, recvRes->memRank, recvRes->bootstrap));
+ recvRes->devMem = NULL;
+ }
CUDACHECK(cudaFree(recvRes->devMem));
free(recvRes);
return ncclSuccess;
diff --git a/src/transport/shm.cc b/src/transport/shm.cc
index 488f456..d8a5b52 100644
--- a/src/transport/shm.cc
+++ b/src/transport/shm.cc
@@ -57,7 +57,7 @@ ncclResult_t shmCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
#define MAX_SHM_NAME_LEN 1024
/* Create and return connect structures for this peer to connect to me */
-ncclResult_t shmSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
+ncclResult_t shmSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId) {
struct shmSendResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
@@ -81,7 +81,7 @@ ncclResult_t shmSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
return ncclSuccess;
}
-ncclResult_t shmRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
+ncclResult_t shmRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId) {
struct shmRecvResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources;
@@ -106,7 +106,7 @@ ncclResult_t shmRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* gra
}
/* Connect to this peer */
-ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
+ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) {
// Setup device pointers
struct shmConnectInfo* info = (struct shmConnectInfo*)connectInfo;
struct shmSendResources* resources = (struct shmSendResources*)send->transportResources;
@@ -131,7 +131,7 @@ ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, int nranks, int ran
return ncclSuccess;
}
-ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
+ncclResult_t shmRecvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) {
// Setup device pointers
struct shmRecvResources* resources = (struct shmRecvResources*)recv->transportResources;
struct shmConnectInfo* info = (struct shmConnectInfo*)connectInfo;