Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/bootstrap.cc')
-rw-r--r--src/bootstrap.cc450
1 files changed, 256 insertions, 194 deletions
diff --git a/src/bootstrap.cc b/src/bootstrap.cc
index e90dd66..bd6ec99 100644
--- a/src/bootstrap.cc
+++ b/src/bootstrap.cc
@@ -13,144 +13,77 @@
#include <unistd.h>
#include <sys/types.h>
-struct bootstrapNetComm {
- int fd;
-};
-
/* Init functions */
-static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
-static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
-static int bootstrapNetIfs = -1;
+static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1];
+static union socketAddress bootstrapNetIfAddr;
+static int bootstrapNetInitDone = 0;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
ncclResult_t bootstrapNetInit() {
- if (bootstrapNetIfs == -1) {
+ if (bootstrapNetInitDone == 0) {
pthread_mutex_lock(&bootstrapNetLock);
- if (bootstrapNetIfs == -1) {
- bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
- if (bootstrapNetIfs <= 0) {
- WARN("Bootstrap : no socket interface found");
- return ncclInternalError;
+ if (bootstrapNetInitDone == 0) {
+ char* env = getenv("NCCL_COMM_ID");
+ if (env) {
+ union socketAddress remoteAddr;
+ if (GetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) {
+ WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
+ return ncclInvalidArgument;
+ }
+ if (findInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
+ WARN("NET/Socket : No usable listening interface found");
+ return ncclSystemError;
+ }
} else {
- char line[1024];
- char addrline[1024];
- line[0] = '\0';
- for (int i=0; i<bootstrapNetIfs; i++) {
- snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
- socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
+ int nIfs = findInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
+ if (nIfs <= 0) {
+ WARN("Bootstrap : no socket interface found");
+ return ncclInternalError;
}
- line[1023] = '\0';
- INFO(NCCL_INIT, "Bootstrap : Using%s", line);
}
+ char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2];
+ sprintf(line, " %s:", bootstrapNetIfName);
+ socketToString(&bootstrapNetIfAddr.sa, line+strlen(line));
+ INFO(NCCL_INIT, "Bootstrap : Using%s", line);
+ bootstrapNetInitDone = 1;
}
pthread_mutex_unlock(&bootstrapNetLock);
}
return ncclSuccess;
}
-static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
- NCCLCHECK(ncclCalloc(comm, 1));
- (*comm)->fd = -1;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
- if (dev >= bootstrapNetIfs) return ncclInternalError;
- memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr));
- return ncclSuccess;
-}
-
/* Socket Interface Selection type */
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
-static ncclResult_t bootstrapNetListen(int dev, ncclNetHandle_t* netHandle, void** listenComm) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- static_assert(sizeof(union socketAddress) < NCCL_NET_HANDLE_MAXSIZE, "union socketAddress size is too large");
- // if dev >= 0, listen based on dev
- if (dev >= 0) {
- NCCLCHECK(bootstrapNetGetSocketAddr(dev, connectAddr));
- } else if (dev == findSubnetIf) {
- // handle stores a remote address
- // need to find a local addr that is in the same network as the remote addr
- union socketAddress localAddr;
- char ifName[MAX_IF_NAME_SIZE];
- if (findInterfaceMatchSubnet(ifName, &localAddr, connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
- WARN("NET/Socket : No usable listening interface found");
- return ncclSystemError;
- }
- // pass the local address back
- memcpy(connectAddr, &localAddr, sizeof(localAddr));
- } // Otherwise, handle stores a local address
- struct bootstrapNetComm* comm;
- NCCLCHECK(bootstrapNetNewComm(&comm));
- NCCLCHECK(createListenSocket(&comm->fd, connectAddr));
- *listenComm = comm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetConnect(int dev, ncclNetHandle_t* netHandle, void** sendComm) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- struct bootstrapNetComm* comm;
- NCCLCHECK(bootstrapNetNewComm(&comm));
- NCCLCHECK(connectAddress(&comm->fd, connectAddr));
- *sendComm = comm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
- struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
- struct bootstrapNetComm* rComm;
- NCCLCHECK(bootstrapNetNewComm(&rComm));
+static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) {
struct sockaddr_in sockaddr;
socklen_t socklen = sizeof(struct sockaddr_in);
- SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
- *recvComm = rComm;
- return ncclSuccess;
-}
-
-static ncclResult_t bootstrapNetClose(void* opaqueComm) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm;
- if (comm) {
- close(comm->fd);
- free(comm);
- }
+ SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd);
return ncclSuccess;
}
-static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; }
-
// Additional sync functions
-static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
- NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
- NCCLCHECK(socketSend(comm->fd, data, size));
+static ncclResult_t bootstrapNetSend(int fd, void* data, int size) {
+ NCCLCHECK(socketSend(fd, &size, sizeof(int)));
+ NCCLCHECK(socketSend(fd, data, size));
return ncclSuccess;
}
-static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
- struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
+static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) {
int recvSize;
- NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
+ NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int)));
if (recvSize > size) {
WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
return ncclInternalError;
}
- NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
- return ncclSuccess;
-}
-
-ncclResult_t bootstrapNetCreateHandle(ncclNetHandle_t* netHandle, const char* str) {
- union socketAddress* connectAddr = (union socketAddress*) netHandle;
- NCCLCHECK(GetSocketAddrFromString(connectAddr, str));
+ NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size)));
return ncclSuccess;
}
struct extInfo {
int rank;
int nranks;
- ncclNetHandle_t extHandleListenRoot;
- ncclNetHandle_t extHandleListen;
+ union socketAddress extAddressListenRoot;
+ union socketAddress extAddressListen;
};
#include <sys/resource.h>
@@ -163,27 +96,29 @@ static ncclResult_t setFilesLimit() {
return ncclSuccess;
}
-static void *bootstrapRoot(void* listenComm) {
+static void *bootstrapRoot(void* args) {
+ int listenFd = (uint64_t)args;
+ ncclResult_t res = ncclSuccess;
+ int nranks = 0, c = 0;
struct extInfo info;
- ncclNetHandle_t *rankHandles = NULL;
- ncclNetHandle_t *rankHandlesRoot = NULL; // for initial rank <-> root information exchange
- ncclNetHandle_t zero = { 0 }; // for sanity checking
- void* tmpComm;
- ncclResult_t res;
+ union socketAddress *rankAddresses = NULL;
+ union socketAddress *rankAddressesRoot = NULL; // for initial rank <-> root information exchange
+ union socketAddress *zero = NULL;
+ NCCLCHECKGOTO(ncclCalloc(&zero, 1), res, out);
setFilesLimit();
TRACE(NCCL_INIT, "BEGIN");
/* Receive addresses from all ranks */
- int nranks = 0, c = 0;
do {
- NCCLCHECKGOTO(bootstrapNetAccept(listenComm, &tmpComm), res, out);
- NCCLCHECKGOTO(bootstrapNetRecv(tmpComm, &info, sizeof(info)), res, out);
- NCCLCHECKGOTO(bootstrapNetCloseRecv(tmpComm), res, out);
+ int tmpFd;
+ NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out);
+ NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out);
+ close(tmpFd);
if (c == 0) {
nranks = info.nranks;
- NCCLCHECKGOTO(ncclCalloc(&rankHandles, nranks), res, out);
- NCCLCHECKGOTO(ncclCalloc(&rankHandlesRoot, nranks), res, out);
+ NCCLCHECKGOTO(ncclCalloc(&rankAddresses, nranks), res, out);
+ NCCLCHECKGOTO(ncclCalloc(&rankAddressesRoot, nranks), res, out);
}
if (nranks != info.nranks) {
@@ -191,14 +126,14 @@ static void *bootstrapRoot(void* listenComm) {
goto out;
}
- if (memcmp(&zero, &rankHandlesRoot[info.rank], sizeof(ncclNetHandle_t)) != 0) {
+ if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union socketAddress)) != 0) {
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
goto out;
}
// Save the connection handle for that rank
- memcpy(rankHandlesRoot+info.rank, info.extHandleListenRoot, sizeof(ncclNetHandle_t));
- memcpy(rankHandles+info.rank, info.extHandleListen, sizeof(ncclNetHandle_t));
+ memcpy(rankAddressesRoot+info.rank, &info.extAddressListenRoot, sizeof(union socketAddress));
+ memcpy(rankAddresses+info.rank, &info.extAddressListen, sizeof(union socketAddress));
++c;
TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
@@ -208,44 +143,46 @@ static void *bootstrapRoot(void* listenComm) {
// Send the connect handle for the next rank in the AllGather ring
for (int r=0; r<nranks; ++r) {
int next = (r+1) % nranks;
- void *tmpSendComm;
- NCCLCHECKGOTO(bootstrapNetConnect(0, rankHandlesRoot+r, &tmpSendComm), res, out);
- NCCLCHECKGOTO(bootstrapNetSend(tmpSendComm, rankHandles+next, sizeof(ncclNetHandle_t)), res, out);
- NCCLCHECKGOTO(bootstrapNetCloseSend(tmpSendComm), res, out);
+ int tmpSendFd;
+ NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out);
+ NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out);
+ close(tmpSendFd);
}
TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks);
out:
- bootstrapNetCloseListen(listenComm);
- if (rankHandles) free(rankHandles);
- if (rankHandlesRoot) free(rankHandlesRoot);
+ close(listenFd);
+ if (rankAddresses) free(rankAddresses);
+ if (rankAddressesRoot) free(rankAddressesRoot);
+ if (zero) free(zero);
TRACE(NCCL_INIT, "DONE");
return NULL;
}
ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) {
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
- void* listenComm;
- NCCLCHECK(bootstrapNetListen(idFromEnv ? dontCareIf : 0, netHandle, &listenComm));
+ union socketAddress* connectAddr = (union socketAddress*) id;
+ int listenFd;
+ NCCLCHECK(createListenSocket(&listenFd, connectAddr));
pthread_t thread;
- pthread_create(&thread, NULL, bootstrapRoot, listenComm);
+ pthread_create(&thread, NULL, bootstrapRoot, (void*)(uint64_t)listenFd);
return ncclSuccess;
}
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) {
- static_assert(sizeof(ncclNetHandle_t) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId");
+ static_assert(sizeof(union socketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId");
memset(id, 0, sizeof(ncclUniqueId));
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
+ union socketAddress* connectAddr = (union socketAddress*) id;
char* env = getenv("NCCL_COMM_ID");
if (env) {
INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env);
- if (bootstrapNetCreateHandle(netHandle, env) != 0) {
+ if (GetSocketAddrFromString(connectAddr, env) != ncclSuccess) {
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
return ncclInvalidArgument;
}
} else {
+ memcpy(id, &bootstrapNetIfAddr, sizeof(union socketAddress));
NCCLCHECK(bootstrapCreateRoot(id, false));
}
@@ -254,24 +191,135 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) {
struct unexConn {
int peer;
- void* comm;
+ int fd;
struct unexConn* next;
};
+// Remote allocator state
+struct remAllocState {
+ int cudaDev;
+ int listenFd;
+ int stop;
+};
+
struct extState {
- void* extBstrapListenComm;
- void* extBstrapRingRecvComm;
- void* extBstrapRingSendComm;
- ncclNetHandle_t* peerBstrapHandles;
+ int extListenFd;
+ int extRingRecvFd;
+ int extRingSendFd;
+ union socketAddress* peerCommAddresses;
+ union socketAddress* peerAllocAddresses;
struct unexConn* unexpectedConnections;
+ int cudaDev;
int rank;
int nranks;
- int dev;
+
+ // Intermediate memory allocation service
+ struct remAllocState* allocState;
+ pthread_t allocThread;
};
+#define MAX_SEGMENTS 128
+
+static ncclResult_t remoteAlloc(void** ptr, int fd) {
+ size_t size;
+ NCCLCHECK(socketRecv(fd, &size, sizeof(size_t)));
+ cudaIpcMemHandle_t devIpc;
+ NCCLCHECK(ncclCudaCalloc((char**)ptr, size));
+ cudaError_t res = cudaIpcGetMemHandle(&devIpc, *ptr);
+ if (res != cudaSuccess) {
+ WARN("[Rem Allocator] cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res));
+ cudaFree(*ptr);
+ CUDACHECK(res);
+ }
+ // The CUDA IPC
+ NCCLCHECK(socketSend(fd, &devIpc, sizeof(cudaIpcMemHandle_t)));
+ // And the direct pointer
+ NCCLCHECK(socketSend(fd, ptr, sizeof(void*)));
+ return ncclSuccess;
+}
+
+#include <poll.h>
+
+// Service thread to allocate memory for other GPUs, used as intermediate step.
+void* ncclRemoteMemAllocationService(void* args) {
+ struct remAllocState* state = (struct remAllocState *) args;
+ if (cudaSetDevice(state->cudaDev) != cudaSuccess) {
+ WARN("[Rem Allocator] Failed to set CUDA device %d\n", state->cudaDev);
+ }
+
+ // Prepare poll descriptor
+ void* segments[MAX_SEGMENTS];
+ struct pollfd pollfds[MAX_SEGMENTS+1];
+ for (int s=0; s<MAX_SEGMENTS; s++) segments[s] = NULL;
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ pollfds[s].fd = -1;
+ pollfds[s].events = POLLHUP;
+ }
+ pollfds[MAX_SEGMENTS].fd = state->listenFd;
+ pollfds[MAX_SEGMENTS].events = POLLIN;
+
+ int nbuffers = 0;
+ while (state->stop == 0 || (state->stop == 1 && nbuffers > 0)) {
+ if (int error = poll(pollfds, MAX_SEGMENTS+1, 100/*ms*/) < 0) {
+ WARN("[Rem Allocator] Poll failed with error %d", error);
+ return NULL;
+ }
+ if (pollfds[MAX_SEGMENTS].revents) {
+ int s = 0;
+ while (segments[s] != NULL && s < MAX_SEGMENTS) s++;
+ if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) {
+ pollfds[s].fd = -1;
+ } else {
+ if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) {
+ WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd);
+ close(pollfds[s].fd);
+ pollfds[s].fd = -1;
+ } else {
+ nbuffers++;
+ }
+ }
+ }
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ if (pollfds[s].revents & POLLHUP) {
+ if (cudaFree(segments[s]) != cudaSuccess) {
+ WARN("[Rem Allocator] cudaFree %p failed", segments[s]);
+ }
+ segments[s] = NULL;
+ close(pollfds[s].fd);
+ pollfds[s].fd = -1;
+ nbuffers--;
+ }
+ }
+ }
+ for (int s=0; s<MAX_SEGMENTS; s++) {
+ if (segments[s]) cudaFree(segments[s]);
+ close(pollfds[s].fd);
+ }
+ close(state->listenFd);
+ free(state);
+ return NULL;
+}
+
+ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr) {
+ struct extState* state = (struct extState*)commState;
+ int fd;
+ ncclResult_t res;
+ *id = -1;
+ NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank));
+ NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end);
+ NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(cudaIpcMemHandle_t)), res, end);
+ NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end);
+ *id = fd;
+end:
+ return res;
+}
+
+ncclResult_t bootstrapRemFree(int id, int rank, void* commState) {
+ SYSCHECK(close(id), "close");
+ return ncclSuccess;
+}
+
ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
- ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
- bool idFromEnv = getenv("NCCL_COMM_ID") != NULL;
struct extState* state;
NCCLCHECK(ncclCalloc(&state, 1));
state->rank = rank;
@@ -283,19 +331,15 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
struct extInfo info = { 0 };
info.rank = rank;
info.nranks = nranks;
- void *tmpSendComm, *tmpRecvComm;
- // Pass the remote address to listen via info
- if (idFromEnv) {
- memcpy(&info.extHandleListen, netHandle, sizeof(ncclNetHandle_t));
- memcpy(&info.extHandleListenRoot, netHandle, sizeof(ncclNetHandle_t));
- }
- // listen will return the local address via info (specify interface type 'findSubnetIf')
- state->dev = idFromEnv ? findSubnetIf : 0;
- void* extBstrapListenCommRoot;
- NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListen, &state->extBstrapListenComm));
- NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListenRoot, &extBstrapListenCommRoot));
+ int tmpSendFd, tmpRecvFd;
- // stagger connection times to avoid an overload of the root at very high rank counts
+ int extListenFdRoot;
+ memcpy(&info.extAddressListen, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ memcpy(&info.extAddressListenRoot, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ NCCLCHECK(createListenSocket(&state->extListenFd, &info.extAddressListen));
+ NCCLCHECK(createListenSocket(&extListenFdRoot, &info.extAddressListenRoot));
+
+ // stagger connection times to avoid an overload of the root
if (nranks > 128) {
long msec = rank;
struct timespec tv;
@@ -306,25 +350,35 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS
}
// send info on my listening socket to root
- NCCLCHECK(bootstrapNetConnect(state->dev, netHandle, &tmpSendComm));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, &info, sizeof(info)));
- NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));
+ union socketAddress* rootAddr = (union socketAddress*)id;
+ NCCLCHECK(connectAddress(&tmpSendFd, rootAddr));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info)));
+ close(tmpSendFd);
// get info on my "next" rank in the bootstrap ring from root
- ncclNetHandle_t extHandleNext;
- NCCLCHECK(bootstrapNetAccept(extBstrapListenCommRoot, &tmpRecvComm));
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext)));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
- NCCLCHECK(bootstrapNetCloseListen(extBstrapListenCommRoot));
+ union socketAddress extAddressNext;
+ NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext)));
+ close(tmpRecvFd);
+ close(extListenFdRoot);
- NCCLCHECK(bootstrapNetConnect(state->dev, &extHandleNext, &state->extBstrapRingSendComm));
+ NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext));
// Accept the connect request from the previous rank in the AllGather ring
- NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &state->extBstrapRingRecvComm));
+ NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd));
// AllGather all listen handlers
- NCCLCHECK(ncclCalloc(&state->peerBstrapHandles, nranks));
- memcpy(state->peerBstrapHandles+rank, info.extHandleListen, sizeof(ncclNetHandle_t));
- NCCLCHECK(bootstrapAllGather(state, state->peerBstrapHandles, sizeof(ncclNetHandle_t)));
+ NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks));
+ memcpy(state->peerCommAddresses+rank, &info.extAddressListen, sizeof(union socketAddress));
+ NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union socketAddress)));
+
+ // Create the memory allocation service
+ NCCLCHECK(ncclCalloc(&state->peerAllocAddresses, nranks));
+ memcpy(state->peerAllocAddresses+rank, &bootstrapNetIfAddr, sizeof(union socketAddress));
+ NCCLCHECK(ncclCalloc(&state->allocState, 1));
+ CUDACHECK(cudaGetDevice(&state->allocState->cudaDev));
+ NCCLCHECK(createListenSocket(&state->allocState->listenFd, state->peerAllocAddresses+rank));
+ pthread_create(&state->allocThread, NULL, ncclRemoteMemAllocationService, state->allocState);
+ NCCLCHECK(bootstrapAllGather(state, state->peerAllocAddresses, sizeof(union socketAddress)));
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
@@ -348,9 +402,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
size_t sslice = (rank - i + nranks) % nranks;
// Send slice to the right
- NCCLCHECK(bootstrapNetSend(state->extBstrapRingSendComm, data+sslice*size, size));
+ NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size));
// Recv slice from the left
- NCCLCHECK(bootstrapNetRecv(state->extBstrapRingRecvComm, data+rslice*size, size));
+ NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size));
}
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
@@ -359,20 +413,20 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
ncclResult_t bootstrapSend(void* commState, int peer, void* data, int size) {
struct extState* state = (struct extState*)commState;
- void* tmpSendComm;
- NCCLCHECK(bootstrapNetConnect(state->dev, state->peerBstrapHandles+peer, &tmpSendComm));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, &state->rank, sizeof(int)));
- NCCLCHECK(bootstrapNetSend(tmpSendComm, data, size));
- NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));
+ int tmpSendFd;
+ NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int)));
+ NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size));
+ close(tmpSendFd);
return ncclSuccess;
}
-ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) {
+ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int fd) {
// New unex
struct unexConn* unex;
NCCLCHECK(ncclCalloc(&unex, 1));
unex->peer = peer;
- unex->comm = comm;
+ unex->fd = fd;
// Enqueue
struct unexConn* list = state->unexpectedConnections;
@@ -385,7 +439,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) {
return ncclSuccess;
}
-void* unexpectedDequeue(struct extState* state, int peer) {
+int unexpectedDequeue(struct extState* state, int peer) {
struct unexConn* elem = state->unexpectedConnections;
struct unexConn* prev = NULL;
while (elem) {
@@ -395,41 +449,41 @@ void* unexpectedDequeue(struct extState* state, int peer) {
} else {
prev->next = elem->next;
}
- void* comm = elem->comm;
+ int fd = elem->fd;
free(elem);
- return comm;
+ return fd;
}
prev = elem;
elem = elem->next;
}
- return NULL;
+ return -1;
}
// We can't know who we'll receive from, so we need to receive everything at once
ncclResult_t bootstrapRecv(void* commState, int peer, void* data, int size) {
struct extState* state = (struct extState*)commState;
- void* tmpRecvComm;
+ int tmpRecvFd;
// Search unexpected connections first
- if ((tmpRecvComm = unexpectedDequeue(state, peer)) != NULL) {
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
+ if ((tmpRecvFd = unexpectedDequeue(state, peer)) != -1) {
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
+ close(tmpRecvFd);
return ncclSuccess;
}
// Then look for new connections
while (1) {
- NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &tmpRecvComm));
+ NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd));
int newPeer;
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &newPeer, sizeof(int)));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int)));
if (newPeer == peer) {
- NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size));
- NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
+ NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size));
+ close(tmpRecvFd);
return ncclSuccess;
}
// Unexpected connection. Save for later.
- NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvComm));
+ NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvFd));
}
}
@@ -439,11 +493,17 @@ ncclResult_t bootstrapClose(void* commState) {
WARN("Unexpected connections are not empty.\n");
return ncclInternalError;
}
- NCCLCHECK(bootstrapNetCloseListen(state->extBstrapListenComm));
- NCCLCHECK(bootstrapNetCloseSend(state->extBstrapRingSendComm));
- NCCLCHECK(bootstrapNetCloseRecv(state->extBstrapRingRecvComm));
+ close(state->extListenFd);
+ close(state->extRingSendFd);
+ close(state->extRingRecvFd);
+
+ state->allocState->stop = 1;
+
+ // Join the allocThread so we catch resource leaks as being hung here
+ // pthread_join(state->allocThread, nullptr);
- free(state->peerBstrapHandles);
+ free(state->peerCommAddresses);
+ free(state->peerAllocAddresses);
free(state);
return ncclSuccess;
@@ -451,10 +511,12 @@ ncclResult_t bootstrapClose(void* commState) {
ncclResult_t bootstrapAbort(void* commState) {
struct extState* state = (struct extState*)commState;
- bootstrapNetCloseListen(state->extBstrapListenComm);
- bootstrapNetCloseSend(state->extBstrapRingSendComm);
- bootstrapNetCloseRecv(state->extBstrapRingRecvComm);
- free(state->peerBstrapHandles);
+ close(state->extListenFd);
+ close(state->extRingSendFd);
+ close(state->extRingRecvFd);
+ state->allocState->stop = 2;
+ free(state->peerCommAddresses);
+ free(state->peerAllocAddresses);
free(state);
return ncclSuccess;
}