diff options
Diffstat (limited to 'src/bootstrap.cc')
-rw-r--r-- | src/bootstrap.cc | 450 |
1 files changed, 256 insertions, 194 deletions
diff --git a/src/bootstrap.cc b/src/bootstrap.cc index e90dd66..bd6ec99 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -13,144 +13,77 @@ #include <unistd.h> #include <sys/types.h> -struct bootstrapNetComm { - int fd; -}; - /* Init functions */ -static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS]; -static union socketAddress bootstrapNetIfAddrs[MAX_IFS]; -static int bootstrapNetIfs = -1; +static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1]; +static union socketAddress bootstrapNetIfAddr; +static int bootstrapNetInitDone = 0; pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; ncclResult_t bootstrapNetInit() { - if (bootstrapNetIfs == -1) { + if (bootstrapNetInitDone == 0) { pthread_mutex_lock(&bootstrapNetLock); - if (bootstrapNetIfs == -1) { - bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS); - if (bootstrapNetIfs <= 0) { - WARN("Bootstrap : no socket interface found"); - return ncclInternalError; + if (bootstrapNetInitDone == 0) { + char* env = getenv("NCCL_COMM_ID"); + if (env) { + union socketAddress remoteAddr; + if (GetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) { + WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>"); + return ncclInvalidArgument; + } + if (findInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + WARN("NET/Socket : No usable listening interface found"); + return ncclSystemError; + } } else { - char line[1024]; - char addrline[1024]; - line[0] = '\0'; - for (int i=0; i<bootstrapNetIfs; i++) { - snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE, - socketToString(&bootstrapNetIfAddrs[i].sa, addrline)); + int nIfs = findInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1); + if (nIfs <= 0) { + WARN("Bootstrap : no socket interface found"); + return ncclInternalError; } - line[1023] = '\0'; - INFO(NCCL_INIT, "Bootstrap : Using%s", line); } + char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2]; + sprintf(line, " %s:", bootstrapNetIfName); + socketToString(&bootstrapNetIfAddr.sa, line+strlen(line)); + INFO(NCCL_INIT, "Bootstrap : Using%s", line); + bootstrapNetInitDone = 1; } pthread_mutex_unlock(&bootstrapNetLock); } return ncclSuccess; } -static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) { - NCCLCHECK(ncclCalloc(comm, 1)); - (*comm)->fd = -1; - return ncclSuccess; -} - -static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) { - if (dev >= bootstrapNetIfs) return ncclInternalError; - memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr)); - return ncclSuccess; -} - /* Socket Interface Selection type */ enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; -static ncclResult_t bootstrapNetListen(int dev, ncclNetHandle_t* netHandle, void** listenComm) { - union socketAddress* connectAddr = (union socketAddress*) netHandle; - static_assert(sizeof(union socketAddress) < NCCL_NET_HANDLE_MAXSIZE, "union socketAddress size is too large"); - // if dev >= 0, listen based on dev - if (dev >= 0) { - NCCLCHECK(bootstrapNetGetSocketAddr(dev, connectAddr)); - } else if (dev == findSubnetIf) { - // handle stores a remote address - // need to find a local addr that is in the same network as the remote addr - union socketAddress localAddr; - char ifName[MAX_IF_NAME_SIZE]; - if (findInterfaceMatchSubnet(ifName, &localAddr, connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) { - WARN("NET/Socket : No usable listening interface found"); - return ncclSystemError; - } - // pass the local address back - memcpy(connectAddr, &localAddr, sizeof(localAddr)); - } // Otherwise, handle stores a local address - struct bootstrapNetComm* comm; - NCCLCHECK(bootstrapNetNewComm(&comm)); - NCCLCHECK(createListenSocket(&comm->fd, connectAddr)); - *listenComm = comm; - return ncclSuccess; -} - -static ncclResult_t bootstrapNetConnect(int dev, ncclNetHandle_t* netHandle, void** sendComm) { - union socketAddress* connectAddr = (union socketAddress*) netHandle; - struct bootstrapNetComm* comm; - NCCLCHECK(bootstrapNetNewComm(&comm)); - NCCLCHECK(connectAddress(&comm->fd, connectAddr)); - *sendComm = comm; - return ncclSuccess; -} - -static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { - struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm; - struct bootstrapNetComm* rComm; - NCCLCHECK(bootstrapNetNewComm(&rComm)); +static ncclResult_t bootstrapNetAccept(int listenFd, int* recvFd) { struct sockaddr_in sockaddr; socklen_t socklen = sizeof(struct sockaddr_in); - SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd); - *recvComm = rComm; - return ncclSuccess; -} - -static ncclResult_t bootstrapNetClose(void* opaqueComm) { - struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm; - if (comm) { - close(comm->fd); - free(comm); - } + SYSCHECKVAL(accept(listenFd, (struct sockaddr*)&sockaddr, &socklen), "accept", *recvFd); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; } - // Additional sync functions -static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) { - struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm; - NCCLCHECK(socketSend(comm->fd, &size, sizeof(int))); - NCCLCHECK(socketSend(comm->fd, data, size)); +static ncclResult_t bootstrapNetSend(int fd, void* data, int size) { + NCCLCHECK(socketSend(fd, &size, sizeof(int))); + NCCLCHECK(socketSend(fd, data, size)); return ncclSuccess; } -static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) { - struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm; +static ncclResult_t bootstrapNetRecv(int fd, void* data, int size) { int recvSize; - NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int))); + NCCLCHECK(socketRecv(fd, &recvSize, sizeof(int))); if (recvSize > size) { WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size); return ncclInternalError; } - NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size))); - return ncclSuccess; -} - -ncclResult_t bootstrapNetCreateHandle(ncclNetHandle_t* netHandle, const char* str) { - union socketAddress* connectAddr = (union socketAddress*) netHandle; - NCCLCHECK(GetSocketAddrFromString(connectAddr, str)); + NCCLCHECK(socketRecv(fd, data, std::min(recvSize, size))); return ncclSuccess; } struct extInfo { int rank; int nranks; - ncclNetHandle_t extHandleListenRoot; - ncclNetHandle_t extHandleListen; + union socketAddress extAddressListenRoot; + union socketAddress extAddressListen; }; #include <sys/resource.h> @@ -163,27 +96,29 @@ static ncclResult_t setFilesLimit() { return ncclSuccess; } -static void *bootstrapRoot(void* listenComm) { +static void *bootstrapRoot(void* args) { + int listenFd = (uint64_t)args; + ncclResult_t res = ncclSuccess; + int nranks = 0, c = 0; struct extInfo info; - ncclNetHandle_t *rankHandles = NULL; - ncclNetHandle_t *rankHandlesRoot = NULL; // for initial rank <-> root information exchange - ncclNetHandle_t zero = { 0 }; // for sanity checking - void* tmpComm; - ncclResult_t res; + union socketAddress *rankAddresses = NULL; + union socketAddress *rankAddressesRoot = NULL; // for initial rank <-> root information exchange + union socketAddress *zero = NULL; + NCCLCHECKGOTO(ncclCalloc(&zero, 1), res, out); setFilesLimit(); TRACE(NCCL_INIT, "BEGIN"); /* Receive addresses from all ranks */ - int nranks = 0, c = 0; do { - NCCLCHECKGOTO(bootstrapNetAccept(listenComm, &tmpComm), res, out); - NCCLCHECKGOTO(bootstrapNetRecv(tmpComm, &info, sizeof(info)), res, out); - NCCLCHECKGOTO(bootstrapNetCloseRecv(tmpComm), res, out); + int tmpFd; + NCCLCHECKGOTO(bootstrapNetAccept(listenFd, &tmpFd), res, out); + NCCLCHECKGOTO(bootstrapNetRecv(tmpFd, &info, sizeof(info)), res, out); + close(tmpFd); if (c == 0) { nranks = info.nranks; - NCCLCHECKGOTO(ncclCalloc(&rankHandles, nranks), res, out); - NCCLCHECKGOTO(ncclCalloc(&rankHandlesRoot, nranks), res, out); + NCCLCHECKGOTO(ncclCalloc(&rankAddresses, nranks), res, out); + NCCLCHECKGOTO(ncclCalloc(&rankAddressesRoot, nranks), res, out); } if (nranks != info.nranks) { @@ -191,14 +126,14 @@ static void *bootstrapRoot(void* listenComm) { goto out; } - if (memcmp(&zero, &rankHandlesRoot[info.rank], sizeof(ncclNetHandle_t)) != 0) { + if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union socketAddress)) != 0) { WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks); goto out; } // Save the connection handle for that rank - memcpy(rankHandlesRoot+info.rank, info.extHandleListenRoot, sizeof(ncclNetHandle_t)); - memcpy(rankHandles+info.rank, info.extHandleListen, sizeof(ncclNetHandle_t)); + memcpy(rankAddressesRoot+info.rank, &info.extAddressListenRoot, sizeof(union socketAddress)); + memcpy(rankAddresses+info.rank, &info.extAddressListen, sizeof(union socketAddress)); ++c; TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks); @@ -208,44 +143,46 @@ static void *bootstrapRoot(void* listenComm) { // Send the connect handle for the next rank in the AllGather ring for (int r=0; r<nranks; ++r) { int next = (r+1) % nranks; - void *tmpSendComm; - NCCLCHECKGOTO(bootstrapNetConnect(0, rankHandlesRoot+r, &tmpSendComm), res, out); - NCCLCHECKGOTO(bootstrapNetSend(tmpSendComm, rankHandles+next, sizeof(ncclNetHandle_t)), res, out); - NCCLCHECKGOTO(bootstrapNetCloseSend(tmpSendComm), res, out); + int tmpSendFd; + NCCLCHECKGOTO(connectAddress(&tmpSendFd, rankAddressesRoot+r), res, out); + NCCLCHECKGOTO(bootstrapNetSend(tmpSendFd, rankAddresses+next, sizeof(union socketAddress)), res, out); + close(tmpSendFd); } TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks); out: - bootstrapNetCloseListen(listenComm); - if (rankHandles) free(rankHandles); - if (rankHandlesRoot) free(rankHandlesRoot); + close(listenFd); + if (rankAddresses) free(rankAddresses); + if (rankAddressesRoot) free(rankAddressesRoot); + if (zero) free(zero); TRACE(NCCL_INIT, "DONE"); return NULL; } ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) { - ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id; - void* listenComm; - NCCLCHECK(bootstrapNetListen(idFromEnv ? dontCareIf : 0, netHandle, &listenComm)); + union socketAddress* connectAddr = (union socketAddress*) id; + int listenFd; + NCCLCHECK(createListenSocket(&listenFd, connectAddr)); pthread_t thread; - pthread_create(&thread, NULL, bootstrapRoot, listenComm); + pthread_create(&thread, NULL, bootstrapRoot, (void*)(uint64_t)listenFd); return ncclSuccess; } ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) { - static_assert(sizeof(ncclNetHandle_t) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId"); + static_assert(sizeof(union socketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId"); memset(id, 0, sizeof(ncclUniqueId)); - ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id; + union socketAddress* connectAddr = (union socketAddress*) id; char* env = getenv("NCCL_COMM_ID"); if (env) { INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env); - if (bootstrapNetCreateHandle(netHandle, env) != 0) { + if (GetSocketAddrFromString(connectAddr, env) != ncclSuccess) { WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>"); return ncclInvalidArgument; } } else { + memcpy(id, &bootstrapNetIfAddr, sizeof(union socketAddress)); NCCLCHECK(bootstrapCreateRoot(id, false)); } @@ -254,24 +191,135 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) { struct unexConn { int peer; - void* comm; + int fd; struct unexConn* next; }; +// Remote allocator state +struct remAllocState { + int cudaDev; + int listenFd; + int stop; +}; + struct extState { - void* extBstrapListenComm; - void* extBstrapRingRecvComm; - void* extBstrapRingSendComm; - ncclNetHandle_t* peerBstrapHandles; + int extListenFd; + int extRingRecvFd; + int extRingSendFd; + union socketAddress* peerCommAddresses; + union socketAddress* peerAllocAddresses; struct unexConn* unexpectedConnections; + int cudaDev; int rank; int nranks; - int dev; + + // Intermediate memory allocation service + struct remAllocState* allocState; + pthread_t allocThread; }; +#define MAX_SEGMENTS 128 + +static ncclResult_t remoteAlloc(void** ptr, int fd) { + size_t size; + NCCLCHECK(socketRecv(fd, &size, sizeof(size_t))); + cudaIpcMemHandle_t devIpc; + NCCLCHECK(ncclCudaCalloc((char**)ptr, size)); + cudaError_t res = cudaIpcGetMemHandle(&devIpc, *ptr); + if (res != cudaSuccess) { + WARN("[Rem Allocator] cudaIpcGetMemHandle failed : %s", cudaGetErrorString(res)); + cudaFree(*ptr); + CUDACHECK(res); + } + // The CUDA IPC + NCCLCHECK(socketSend(fd, &devIpc, sizeof(cudaIpcMemHandle_t))); + // And the direct pointer + NCCLCHECK(socketSend(fd, ptr, sizeof(void*))); + return ncclSuccess; +} + +#include <poll.h> + +// Service thread to allocate memory for other GPUs, used as intermediate step. +void* ncclRemoteMemAllocationService(void* args) { + struct remAllocState* state = (struct remAllocState *) args; + if (cudaSetDevice(state->cudaDev) != cudaSuccess) { + WARN("[Rem Allocator] Failed to set CUDA device %d\n", state->cudaDev); + } + + // Prepare poll descriptor + void* segments[MAX_SEGMENTS]; + struct pollfd pollfds[MAX_SEGMENTS+1]; + for (int s=0; s<MAX_SEGMENTS; s++) segments[s] = NULL; + for (int s=0; s<MAX_SEGMENTS; s++) { + pollfds[s].fd = -1; + pollfds[s].events = POLLHUP; + } + pollfds[MAX_SEGMENTS].fd = state->listenFd; + pollfds[MAX_SEGMENTS].events = POLLIN; + + int nbuffers = 0; + while (state->stop == 0 || (state->stop == 1 && nbuffers > 0)) { + if (int error = poll(pollfds, MAX_SEGMENTS+1, 100/*ms*/) < 0) { + WARN("[Rem Allocator] Poll failed with error %d", error); + return NULL; + } + if (pollfds[MAX_SEGMENTS].revents) { + int s = 0; + while (segments[s] != NULL && s < MAX_SEGMENTS) s++; + if (bootstrapNetAccept(pollfds[MAX_SEGMENTS].fd, &pollfds[s].fd) != ncclSuccess) { + pollfds[s].fd = -1; + } else { + if (s == MAX_SEGMENTS || (remoteAlloc(segments+s, pollfds[s].fd) != ncclSuccess)) { + WARN("[Rem Allocator] Allocation failed (segment %d, fd %d)", s, pollfds[s].fd); + close(pollfds[s].fd); + pollfds[s].fd = -1; + } else { + nbuffers++; + } + } + } + for (int s=0; s<MAX_SEGMENTS; s++) { + if (pollfds[s].revents & POLLHUP) { + if (cudaFree(segments[s]) != cudaSuccess) { + WARN("[Rem Allocator] cudaFree %p failed", segments[s]); + } + segments[s] = NULL; + close(pollfds[s].fd); + pollfds[s].fd = -1; + nbuffers--; + } + } + } + for (int s=0; s<MAX_SEGMENTS; s++) { + if (segments[s]) cudaFree(segments[s]); + close(pollfds[s].fd); + } + close(state->listenFd); + free(state); + return NULL; +} + +ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, cudaIpcMemHandle_t* ipc, void** ptr) { + struct extState* state = (struct extState*)commState; + int fd; + ncclResult_t res; + *id = -1; + NCCLCHECK(connectAddress(&fd, state->peerAllocAddresses+rank)); + NCCLCHECKGOTO(socketSend(fd, &size, sizeof(size_t)), res, end); + NCCLCHECKGOTO(socketRecv(fd, ipc, sizeof(cudaIpcMemHandle_t)), res, end); + NCCLCHECKGOTO(socketRecv(fd, ptr, sizeof(void*)), res, end); + *id = fd; +end: + return res; +} + +ncclResult_t bootstrapRemFree(int id, int rank, void* commState) { + SYSCHECK(close(id), "close"); + return ncclSuccess; +} + ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) { - ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id; - bool idFromEnv = getenv("NCCL_COMM_ID") != NULL; struct extState* state; NCCLCHECK(ncclCalloc(&state, 1)); state->rank = rank; @@ -283,19 +331,15 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS struct extInfo info = { 0 }; info.rank = rank; info.nranks = nranks; - void *tmpSendComm, *tmpRecvComm; - // Pass the remote address to listen via info - if (idFromEnv) { - memcpy(&info.extHandleListen, netHandle, sizeof(ncclNetHandle_t)); - memcpy(&info.extHandleListenRoot, netHandle, sizeof(ncclNetHandle_t)); - } - // listen will return the local address via info (specify interface type 'findSubnetIf') - state->dev = idFromEnv ? findSubnetIf : 0; - void* extBstrapListenCommRoot; - NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListen, &state->extBstrapListenComm)); - NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListenRoot, &extBstrapListenCommRoot)); + int tmpSendFd, tmpRecvFd; - // stagger connection times to avoid an overload of the root at very high rank counts + int extListenFdRoot; + memcpy(&info.extAddressListen, &bootstrapNetIfAddr, sizeof(union socketAddress)); + memcpy(&info.extAddressListenRoot, &bootstrapNetIfAddr, sizeof(union socketAddress)); + NCCLCHECK(createListenSocket(&state->extListenFd, &info.extAddressListen)); + NCCLCHECK(createListenSocket(&extListenFdRoot, &info.extAddressListenRoot)); + + // stagger connection times to avoid an overload of the root if (nranks > 128) { long msec = rank; struct timespec tv; @@ -306,25 +350,35 @@ ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commS } // send info on my listening socket to root - NCCLCHECK(bootstrapNetConnect(state->dev, netHandle, &tmpSendComm)); - NCCLCHECK(bootstrapNetSend(tmpSendComm, &info, sizeof(info))); - NCCLCHECK(bootstrapNetCloseSend(tmpSendComm)); + union socketAddress* rootAddr = (union socketAddress*)id; + NCCLCHECK(connectAddress(&tmpSendFd, rootAddr)); + NCCLCHECK(bootstrapNetSend(tmpSendFd, &info, sizeof(info))); + close(tmpSendFd); // get info on my "next" rank in the bootstrap ring from root - ncclNetHandle_t extHandleNext; - NCCLCHECK(bootstrapNetAccept(extBstrapListenCommRoot, &tmpRecvComm)); - NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext))); - NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm)); - NCCLCHECK(bootstrapNetCloseListen(extBstrapListenCommRoot)); + union socketAddress extAddressNext; + NCCLCHECK(bootstrapNetAccept(extListenFdRoot, &tmpRecvFd)); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &extAddressNext, sizeof(extAddressNext))); + close(tmpRecvFd); + close(extListenFdRoot); - NCCLCHECK(bootstrapNetConnect(state->dev, &extHandleNext, &state->extBstrapRingSendComm)); + NCCLCHECK(connectAddress(&state->extRingSendFd, &extAddressNext)); // Accept the connect request from the previous rank in the AllGather ring - NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &state->extBstrapRingRecvComm)); + NCCLCHECK(bootstrapNetAccept(state->extListenFd, &state->extRingRecvFd)); // AllGather all listen handlers - NCCLCHECK(ncclCalloc(&state->peerBstrapHandles, nranks)); - memcpy(state->peerBstrapHandles+rank, info.extHandleListen, sizeof(ncclNetHandle_t)); - NCCLCHECK(bootstrapAllGather(state, state->peerBstrapHandles, sizeof(ncclNetHandle_t))); + NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks)); + memcpy(state->peerCommAddresses+rank, &info.extAddressListen, sizeof(union socketAddress)); + NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union socketAddress))); + + // Create the memory allocation service + NCCLCHECK(ncclCalloc(&state->peerAllocAddresses, nranks)); + memcpy(state->peerAllocAddresses+rank, &bootstrapNetIfAddr, sizeof(union socketAddress)); + NCCLCHECK(ncclCalloc(&state->allocState, 1)); + CUDACHECK(cudaGetDevice(&state->allocState->cudaDev)); + NCCLCHECK(createListenSocket(&state->allocState->listenFd, state->peerAllocAddresses+rank)); + pthread_create(&state->allocThread, NULL, ncclRemoteMemAllocationService, state->allocState); + NCCLCHECK(bootstrapAllGather(state, state->peerAllocAddresses, sizeof(union socketAddress))); TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks); @@ -348,9 +402,9 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { size_t sslice = (rank - i + nranks) % nranks; // Send slice to the right - NCCLCHECK(bootstrapNetSend(state->extBstrapRingSendComm, data+sslice*size, size)); + NCCLCHECK(bootstrapNetSend(state->extRingSendFd, data+sslice*size, size)); // Recv slice from the left - NCCLCHECK(bootstrapNetRecv(state->extBstrapRingRecvComm, data+rslice*size, size)); + NCCLCHECK(bootstrapNetRecv(state->extRingRecvFd, data+rslice*size, size)); } TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size); @@ -359,20 +413,20 @@ ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) { ncclResult_t bootstrapSend(void* commState, int peer, void* data, int size) { struct extState* state = (struct extState*)commState; - void* tmpSendComm; - NCCLCHECK(bootstrapNetConnect(state->dev, state->peerBstrapHandles+peer, &tmpSendComm)); - NCCLCHECK(bootstrapNetSend(tmpSendComm, &state->rank, sizeof(int))); - NCCLCHECK(bootstrapNetSend(tmpSendComm, data, size)); - NCCLCHECK(bootstrapNetCloseSend(tmpSendComm)); + int tmpSendFd; + NCCLCHECK(connectAddress(&tmpSendFd, state->peerCommAddresses+peer)); + NCCLCHECK(bootstrapNetSend(tmpSendFd, &state->rank, sizeof(int))); + NCCLCHECK(bootstrapNetSend(tmpSendFd, data, size)); + close(tmpSendFd); return ncclSuccess; } -ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) { +ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int fd) { // New unex struct unexConn* unex; NCCLCHECK(ncclCalloc(&unex, 1)); unex->peer = peer; - unex->comm = comm; + unex->fd = fd; // Enqueue struct unexConn* list = state->unexpectedConnections; @@ -385,7 +439,7 @@ ncclResult_t unexpectedEnqueue(struct extState* state, int peer, void* comm) { return ncclSuccess; } -void* unexpectedDequeue(struct extState* state, int peer) { +int unexpectedDequeue(struct extState* state, int peer) { struct unexConn* elem = state->unexpectedConnections; struct unexConn* prev = NULL; while (elem) { @@ -395,41 +449,41 @@ void* unexpectedDequeue(struct extState* state, int peer) { } else { prev->next = elem->next; } - void* comm = elem->comm; + int fd = elem->fd; free(elem); - return comm; + return fd; } prev = elem; elem = elem->next; } - return NULL; + return -1; } // We can't know who we'll receive from, so we need to receive everything at once ncclResult_t bootstrapRecv(void* commState, int peer, void* data, int size) { struct extState* state = (struct extState*)commState; - void* tmpRecvComm; + int tmpRecvFd; // Search unexpected connections first - if ((tmpRecvComm = unexpectedDequeue(state, peer)) != NULL) { - NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size)); - NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm)); + if ((tmpRecvFd = unexpectedDequeue(state, peer)) != -1) { + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size)); + close(tmpRecvFd); return ncclSuccess; } // Then look for new connections while (1) { - NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &tmpRecvComm)); + NCCLCHECK(bootstrapNetAccept(state->extListenFd, &tmpRecvFd)); int newPeer; - NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &newPeer, sizeof(int))); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, &newPeer, sizeof(int))); if (newPeer == peer) { - NCCLCHECK(bootstrapNetRecv(tmpRecvComm, ((char*)data), size)); - NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm)); + NCCLCHECK(bootstrapNetRecv(tmpRecvFd, ((char*)data), size)); + close(tmpRecvFd); return ncclSuccess; } // Unexpected connection. Save for later. - NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvComm)); + NCCLCHECK(unexpectedEnqueue(state, newPeer, tmpRecvFd)); } } @@ -439,11 +493,17 @@ ncclResult_t bootstrapClose(void* commState) { WARN("Unexpected connections are not empty.\n"); return ncclInternalError; } - NCCLCHECK(bootstrapNetCloseListen(state->extBstrapListenComm)); - NCCLCHECK(bootstrapNetCloseSend(state->extBstrapRingSendComm)); - NCCLCHECK(bootstrapNetCloseRecv(state->extBstrapRingRecvComm)); + close(state->extListenFd); + close(state->extRingSendFd); + close(state->extRingRecvFd); + + state->allocState->stop = 1; + + // Join the allocThread so we catch resource leaks as being hung here + // pthread_join(state->allocThread, nullptr); - free(state->peerBstrapHandles); + free(state->peerCommAddresses); + free(state->peerAllocAddresses); free(state); return ncclSuccess; @@ -451,10 +511,12 @@ ncclResult_t bootstrapClose(void* commState) { ncclResult_t bootstrapAbort(void* commState) { struct extState* state = (struct extState*)commState; - bootstrapNetCloseListen(state->extBstrapListenComm); - bootstrapNetCloseSend(state->extBstrapRingSendComm); - bootstrapNetCloseRecv(state->extBstrapRingRecvComm); - free(state->peerBstrapHandles); + close(state->extListenFd); + close(state->extRingSendFd); + close(state->extRingRecvFd); + state->allocState->stop = 2; + free(state->peerCommAddresses); + free(state->peerAllocAddresses); free(state); return ncclSuccess; } |