diff options
author | Ke Wen <kwen@nvidia.com> | 2019-06-25 23:22:47 +0300 |
---|---|---|
committer | Ke Wen <kwen@nvidia.com> | 2019-06-25 23:22:47 +0300 |
commit | 7c72dee660e4d055b81721dd6b03e4e1c0a983cf (patch) | |
tree | 7e5c478c25f9877f721efd14a908dd6b0dc6a3b4 | |
parent | 0ceaec9cee96ae7658aa45686853286651f36384 (diff) |
2.4.8-1
Fix #209: improve socket transport performance
Split transfers over multiple sockets
Launch multiple threads to drive sockets
Detect AWS NICs and set nsockets/nthreads accordingly
-rw-r--r-- | makefiles/version.mk | 2 | ||||
-rw-r--r-- | src/bootstrap.cc | 152 | ||||
-rw-r--r-- | src/include/bootstrap.h | 1 | ||||
-rw-r--r-- | src/include/net.h | 6 | ||||
-rw-r--r-- | src/include/socket.h | 21 | ||||
-rw-r--r-- | src/init.cc | 5 | ||||
-rw-r--r-- | src/transport/net_socket.cc | 334 |
7 files changed, 425 insertions, 96 deletions
diff --git a/makefiles/version.mk b/makefiles/version.mk index 8341f33..bab58ec 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 NCCL_MINOR := 4 -NCCL_PATCH := 7 +NCCL_PATCH := 8 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/src/bootstrap.cc b/src/bootstrap.cc index 9df38e4..d7c2ac6 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -9,37 +9,145 @@ #include "utils.h" #include "bootstrap.h" #include "net.h" +#include "socket.h" #include <unistd.h> #include <sys/types.h> // Always use sockets for bootstrap -ncclNet_t* ncclBootstrapNet = &ncclNetSocket; +struct bootstrapNetHandle { + union socketAddress connectAddr; +}; + +struct bootstrapNetComm { + int fd; +}; + +/* Init functions */ +static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS]; +static union socketAddress bootstrapNetIfAddrs[MAX_IFS]; +static int bootstrapNetIfs = -1; +pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; + +ncclResult_t bootstrapNetInit() { + if (bootstrapNetIfs == -1) { + pthread_mutex_lock(&bootstrapNetLock); + if (bootstrapNetIfs == -1) { + bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS); + if (bootstrapNetIfs <= 0) { + WARN("Bootstrap : no socket interface found"); + return ncclInternalError; + } else { + char line[1024]; + char addrline[1024]; + line[0] = '\0'; + for (int i=0; i<bootstrapNetIfs; i++) { + snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE, + socketToString(&bootstrapNetIfAddrs[i].sa, addrline)); + } + line[1023] = '\0'; + INFO(NCCL_INIT, "Bootstrap : Using%s", line); + } + } + pthread_mutex_unlock(&bootstrapNetLock); + } + return ncclSuccess; +} + +static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) { + NCCLCHECK(ncclCalloc(comm, 1)); + (*comm)->fd = -1; + return ncclSuccess; +} + +static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) { + if (dev >= bootstrapNetIfs) return ncclInternalError; + memcpy(addr, bootstrapNetIfAddrs+dev, sizeof(*addr)); + return ncclSuccess; +} + +/* Socket Interface Selection type */ +enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; + +static ncclResult_t bootstrapNetListen(int dev, void* opaqueHandle, void** listenComm) { + struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle; + static_assert(sizeof(struct bootstrapNetHandle) < NCCL_NET_HANDLE_MAXSIZE, "bootstrapNetHandle size too large"); + // if dev >= 0, listen based on dev + if (dev >= 0) { + NCCLCHECK(bootstrapNetGetSocketAddr(dev, &(handle->connectAddr))); + } else if (dev == findSubnetIf) { + // handle stores a remote address + // need to find a local addr that is in the same network as the remote addr + union socketAddress localAddr; + char ifName[MAX_IF_NAME_SIZE]; + if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + WARN("NET/Socket : No usable listening interface found"); + return ncclSystemError; + } + // pass the local address back + memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr)); + } // Otherwise, handle stores a local address + struct bootstrapNetComm* comm; + NCCLCHECK(bootstrapNetNewComm(&comm)); + NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr)); + *listenComm = comm; + return ncclSuccess; +} + +static ncclResult_t bootstrapNetConnect(int dev, void* opaqueHandle, void** sendComm) { + struct bootstrapNetComm* comm; + NCCLCHECK(bootstrapNetNewComm(&comm)); + struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle; + NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr)); + *sendComm = comm; + return ncclSuccess; +} + +static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { + struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm; + struct bootstrapNetComm* rComm; + NCCLCHECK(bootstrapNetNewComm(&rComm)); + struct sockaddr_in sockaddr; + socklen_t socklen = sizeof(struct sockaddr_in); + SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd); + *recvComm = rComm; + return ncclSuccess; +} -static ncclResult_t bootstrapNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclBootstrapNet->listen(dev, handle, listenComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclBootstrapNet->connect(dev, handle, sendComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclBootstrapNet->accept(listenComm, recvComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclBootstrapNet->test(request, done, size)); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(ncclBootstrapNet->closeSend(sendComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(ncclBootstrapNet->closeRecv(recvComm)); return ncclSuccess; } -static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(ncclBootstrapNet->closeListen(listenComm)); return ncclSuccess; } +static ncclResult_t bootstrapNetClose(void* opaqueComm) { + struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm; + if (comm) { + close(comm->fd); + free(comm); + } + return ncclSuccess; +} -// Additional sync functions based on async + test for bootstrap, using host ptrs. +static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; } +static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; } +static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; } + +// Additional sync functions static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) { - void* request, *mhandle; - NCCLCHECK(ncclBootstrapNet->regMr(sendComm, data, size, NCCL_PTR_HOST, &mhandle)); - NCCLCHECK(ncclBootstrapNet->isend(sendComm, data, size, mhandle, &request)); - NCCLCHECK(ncclBootstrapNet->deregMr(sendComm, mhandle)); - int done = 0; - while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL)); + struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm; + NCCLCHECK(socketSend(comm->fd, &size, sizeof(int))); + NCCLCHECK(socketSend(comm->fd, data, size)); return ncclSuccess; } static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) { - void* request, *mhandle; - NCCLCHECK(ncclBootstrapNet->regMr(recvComm, data, size, NCCL_PTR_HOST, &mhandle)); - NCCLCHECK(ncclBootstrapNet->irecv(recvComm, data, size, mhandle, &request)); - NCCLCHECK(ncclBootstrapNet->deregMr(recvComm, mhandle)); - int done = 0; - while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL)); + struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm; + int recvSize; + NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int))); + if (recvSize > size) { + WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size); + return ncclInternalError; + } + NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size))); + return ncclSuccess; +} + +ncclResult_t bootstrapNetCreateHandle(void* opaqueHandle, const char* str) { + struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle; + NCCLCHECK(GetSocketAddrFromString(&handle->connectAddr, str)); return ncclSuccess; } @@ -148,7 +256,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) { char* env = getenv("NCCL_COMM_ID"); if (env) { - if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) { + if (bootstrapNetCreateHandle(&id->extHandleRoot, env) != 0) { WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>"); return ncclInvalidArgument; } diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index dd7de2c..dacbc7c 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -9,6 +9,7 @@ #include "nccl.h" +ncclResult_t bootstrapNetInit(); ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv); ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out); ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState); diff --git a/src/include/net.h b/src/include/net.h index da3ecea..950b5e5 100644 --- a/src/include/net.h +++ b/src/include/net.h @@ -13,11 +13,6 @@ extern ncclNet_t* ncclNet; typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE]; -/* Socket Interface Selection type */ -typedef enum { findSubnetIf = -1, - dontCareIf = -2 -} ncclSocketIfSl_t; - // Translation to external API static const char* ncclNetName() { return ncclNet->name; } static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; } @@ -36,7 +31,6 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; } static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; } -extern ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str); extern ncclNet_t ncclNetIb; extern ncclNet_t ncclNetSocket; diff --git a/src/include/socket.h b/src/include/socket.h index 739c0c4..8197a65 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -42,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) { return buf; } -static inline short socketToPort(struct sockaddr *saddr) { +static inline uint16_t socketToPort(struct sockaddr *saddr) { return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port); } @@ -161,7 +161,10 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) { } static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) { - char line[1024], line_a[1024]; +#ifdef ENABLE_TRACE + char line[1024]; +#endif + char line_a[1024]; int found = 0; struct ifaddrs *interfaces, *interface; getifaddrs(&interfaces); @@ -185,7 +188,7 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd // Store the interface name strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize); - INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a)); + TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a)); found++; if (found == maxIfs) break; } @@ -390,12 +393,12 @@ retry: #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 -static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) { +static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) { int bytes = 0; char* data = (char*)ptr; do { - if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); - if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); + if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); + if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); if (op == NCCL_SOCKET_RECV && bytes == 0) { WARN("Net : Connection closed by remote peer"); return ncclSystemError; @@ -413,9 +416,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off return ncclSuccess; } +static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) { + return socketProgressOpt(op, fd, ptr, size, offset, 0); +} + static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) { while (*offset < size) - NCCLCHECK(socketProgress(op, fd, ptr, size, offset)); + NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1)); return ncclSuccess; } diff --git a/src/init.cc b/src/init.cc index 80af287..42499c0 100644 --- a/src/init.cc +++ b/src/init.cc @@ -124,14 +124,15 @@ cleanup: } ncclResult_t initNet() { - // Always initialize sockets as we use it for bootstrap - NCCLCHECK(initNet(&ncclNetSocket)); + // Always initialize bootstrap network + NCCLCHECK(bootstrapNetInit()); NCCLCHECK(initNetPlugin(&ncclNet)); if (ncclNet != NULL) return ncclSuccess; if (initNet(&ncclNetIb) == ncclSuccess) { ncclNet = &ncclNetIb; } else { + NCCLCHECK(initNet(&ncclNetSocket)); ncclNet = &ncclNetSocket; } return ncclSuccess; diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 9958936..ab5e8ec 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -8,6 +8,7 @@ #include "core.h" #include "socket.h" #include "net.h" +#include "param.h" #include <assert.h> #include <pthread.h> @@ -15,6 +16,7 @@ #include <stdlib.h> #include <poll.h> #include <limits.h> +#include <fcntl.h> /* Init functions */ static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS]; @@ -68,7 +70,7 @@ ncclResult_t ncclSocketPciPath(int dev, char** path) { return ncclSuccess; } -static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) { +ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) { if (dev >= ncclNetIfs) return ncclInternalError; memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr)); return ncclSuccess; @@ -76,105 +78,281 @@ static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) { /* Communication functions */ +#define MAX_SOCKETS 64 +#define MAX_THREADS 16 +#define MAX_REQUESTS 128 +#define MAX_QUEUE_LEN MAX_REQUESTS +#define MIN_CHUNKSIZE (64*1024) + +NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2); +NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2); + struct ncclSocketHandle { union socketAddress connectAddr; + int nSocks; + int nThreads; }; -struct ncclSocketRequest { +struct ncclSocketTask { int op; void* data; int size; int fd; int offset; int used; + ncclResult_t result; }; -struct ncclSocketReqs { - struct ncclSocketRequest* requests; +struct ncclSocketRequest { + int op; + void* data; + int size; + int ctrlFd; + int used; + struct ncclSocketComm* comm; + struct ncclSocketTask* tasks[MAX_SOCKETS]; + int nSubs; }; -struct ncclSocketComm { +struct ncclSocketTaskQueue { + int next; + struct ncclSocketTask* tasks; +}; + +enum threadState {start, stop}; + +struct ncclSocketThreadResources { + struct ncclSocketTaskQueue threadTaskQueue; + enum threadState state; + struct ncclSocketComm* comm; + pthread_mutex_t threadLock; + pthread_cond_t threadCond; +}; + +struct ncclSocketListenComm { int fd; - struct ncclSocketReqs reqs; + int nSocks; + int nThreads; }; -ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) { +struct ncclSocketComm { + int ctrlFd; + int fds[MAX_SOCKETS]; + int nSocks; + int nThreads; + int nextFd; + struct ncclSocketRequest requests[MAX_REQUESTS]; + pthread_t helperThread[MAX_THREADS]; + struct ncclSocketThreadResources threadResources[MAX_THREADS]; +}; + +void* persistentSocketThread(void *args_) { + struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_; + struct ncclSocketComm* comm = resource->comm; + volatile enum threadState* state = &resource->state; + struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue; + int nSocksPerThread = comm->nSocks / comm->nThreads; + while (1) { + int idle = 1; + int mark = myQueue->next; // mark newest task seen + for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) { + int repeat; + do { + repeat = 0; + for (int j=0; j<nSocksPerThread; j++) { + struct ncclSocketTask* r = myQueue->tasks+i+j; + if (r != NULL && r->used == 1 && r->offset < r->size) { + r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset); + if (r->result != ncclSuccess) { + WARN("NET/Socket : socket progress error"); + return NULL; + } + idle = 0; + if (r->offset < r->size) repeat = 1; + } + } + } while (repeat); + } + if (idle) { + pthread_mutex_lock(&resource->threadLock); + while (mark == myQueue->next && *state != stop) { // no new tasks, wait + pthread_cond_wait(&resource->threadCond, &resource->threadLock); + } + pthread_mutex_unlock(&resource->threadLock); + } + if (*state == stop) return NULL; + } +} + +ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) { + int nSocksPerThread = ncclParamSocketNsocksPerThread(); + int nThreads = ncclParamSocketNthreads(); + if (nThreads > MAX_THREADS) { + WARN("NET/Socket : NCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS); + nThreads = MAX_THREADS; + } + if (nThreads == -2 || nSocksPerThread == -2) { + // Auto-detection + int autoNt=1, autoNs=1; + char vendorPath[PATH_MAX]; + snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetIfNames+dev*MAX_IF_NAME_SIZE); + char* rPath = realpath(vendorPath, NULL); + int fd = open(rPath, O_RDONLY); + free(rPath); + if (fd == -1) { + // Could not find device vendor. This is handled silently so + // we don't want to print an INFO error. + TRACE(NCCL_NET, "Open of %s failed : %s\n", vendorPath, strerror(errno)); + goto end; + } + char vendor[7]; + strncpy(vendor, "0x0000", 7); + int len; + SYSCHECKVAL(read(fd, vendor, 6), "read", len); + SYSCHECK(close(fd), "close"); + if (strcmp(vendor, "0x1d0f") == 0) { // AWS + autoNt = 2; + autoNs = 8; + } +end: + if (nThreads == -2) nThreads = autoNt; + if (nSocksPerThread == -2) nSocksPerThread = autoNs; + } + int nSocks = nSocksPerThread * nThreads; + if (nSocks > MAX_SOCKETS) { + nSocksPerThread = MAX_SOCKETS/nThreads; + WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting NCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread); + nSocks = nSocksPerThread * nThreads; + } + *ns = nSocks; + *nt = nThreads; + INFO(NCCL_INIT, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread); + return ncclSuccess; +} + +ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) { NCCLCHECK(ncclCalloc(comm, 1)); (*comm)->fd = -1; return ncclSuccess; } -ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str) { - struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - NCCLCHECK(GetSocketAddrFromString(&(handle->connectAddr), str)); +ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) { + NCCLCHECK(ncclCalloc(comm, 1)); + (*comm)->ctrlFd = -1; + for (int i=0; i < MAX_SOCKETS; i++) { + (*comm)->fds[i] = -1; + } + (*comm)->nextFd = 0; return ncclSuccess; } ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) { + if (dev < 0) { // data transfer socket is based on specified dev + return ncclInternalError; + } struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large"); - // if dev >= 0, listen based on dev - if (dev >= 0) { - NCCLCHECK(GetSocketAddr(dev, &(handle->connectAddr))); - } else if (dev == findSubnetIf) { - // handle stores a remote address - // need to find a local addr that is in the same network as the remote addr - union socketAddress localAddr; - char ifName[MAX_IF_NAME_SIZE]; - if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) { - WARN("NET/Socket : No usable listening interface found"); - return ncclSystemError; - } - // pass the local address back - memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr)); - } // Otherwise, handle stores a local address - struct ncclSocketComm* comm; - NCCLCHECK(ncclSocketNewComm(&comm)); + struct ncclSocketListenComm* comm; + NCCLCHECK(ncclSocketNewListenComm(&comm)); + NCCLCHECK(GetSocketAddr(dev, &handle->connectAddr)); NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr)); + NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads)); + handle->nSocks = comm->nSocks; + handle->nThreads = comm->nThreads; *listenComm = comm; return ncclSuccess; } ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) { + if (dev < 0) { // data transfer socket is based on specified dev + return ncclInternalError; + } struct ncclSocketComm* comm; NCCLCHECK(ncclSocketNewComm(&comm)); struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle; - NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr)); + comm->nSocks = handle->nSocks; + comm->nThreads = handle->nThreads; + for (int i=0; i<comm->nSocks+1; i++) { + int tmpFd, offset=0; + NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr)); + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset)); + if (i == comm->nSocks) comm->ctrlFd = tmpFd; + else comm->fds[i] = tmpFd; + } *sendComm = comm; return ncclSuccess; } ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) { - struct ncclSocketComm* lComm = (struct ncclSocketComm*)listenComm; + struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm; struct ncclSocketComm* rComm; NCCLCHECK(ncclSocketNewComm(&rComm)); - struct sockaddr_in sockaddr; - socklen_t socklen = sizeof(struct sockaddr_in); - SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd); + rComm->nSocks = lComm->nSocks; + rComm->nThreads = lComm->nThreads; + for (int i=0; i<rComm->nSocks+1; i++) { + int tmpFd, sendSockIdx, offset=0; + struct sockaddr_in sockaddr; + socklen_t socklen = sizeof(struct sockaddr_in); + SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd); + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset)); + if (sendSockIdx == rComm->nSocks) rComm->ctrlFd = tmpFd; + else rComm->fds[sendSockIdx] = tmpFd; + } *recvComm = rComm; return ncclSuccess; } -#define MAX_REQUESTS 128 - -ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) { - if (reqs->requests == NULL) { - NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS)); - } +ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) { for (int i=0; i<MAX_REQUESTS; i++) { - struct ncclSocketRequest* r = reqs->requests+i; + struct ncclSocketRequest* r = comm->requests+i; if (r->used == 0) { r->op = op; r->data = data; r->size = size; - r->fd = fd; - r->offset = -1; + r->ctrlFd = comm->ctrlFd; r->used = 1; + r->comm = comm; + r->nSubs = 0; *req = r; return ncclSuccess; } } - WARN("Socket : unable to allocate requests"); + WARN("NET/Socket : unable to allocate requests"); + return ncclInternalError; +} + +ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) { + int tid = comm->nextFd % comm->nThreads; + struct ncclSocketThreadResources* res = comm->threadResources+tid; + struct ncclSocketTaskQueue* queue = &res->threadTaskQueue; + // create helper threads and prepare per-thread task queue + if (queue->tasks == NULL) { + NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN)); + queue->next = 0; + res->comm = comm; + pthread_mutex_init(&res->threadLock, NULL); + pthread_cond_init(&res->threadCond, NULL); + pthread_create(comm->helperThread+tid, NULL, persistentSocketThread, res); + } + struct ncclSocketTask* r = queue->tasks+queue->next; + if (r->used == 0) { + r->op = op; + r->data = data; + r->size = size; + r->fd = comm->fds[comm->nextFd]; + r->offset = 0; + r->result = ncclSuccess; + comm->nextFd = (comm->nextFd + 1) % comm->nSocks; + r->used = 1; + *req = r; + pthread_mutex_lock(&res->threadLock); + queue->next = (queue->next+1)%MAX_QUEUE_LEN; + res->state = start; + pthread_cond_signal(&res->threadCond); + pthread_mutex_unlock(&res->threadLock); + return ncclSuccess; + } + WARN("NET/Socket : unable to allocate subtasks"); return ncclInternalError; } @@ -185,15 +363,15 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { WARN("NET/Socket : test called with NULL request"); return ncclInternalError; } - if (r->offset == -1) { /* try to send/recv size */ + if (r->used == 1) { /* try to send/recv size */ int data = r->size; int offset = 0; - NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset)); + NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset)); if (offset == 0) return ncclSuccess; /* Not ready -- retry later */ // Not sure we could ever receive less than 4 bytes, but just in case ... - if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset)); + if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset)); // Check size is less or equal to the size provided by the user if (r->op == NCCL_SOCKET_RECV && data > r->size) { @@ -201,15 +379,33 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) { return ncclInternalError; } r->size = data; - r->offset = 0; - } - if (r->offset < r->size) { - NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset)); + r->used = 2; // done exchanging size + // divide into subtasks + int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks)); + int chunkOffset = 0, i = 0; + while (chunkOffset < r->size) { + int chunkSize = std::min(taskSize, r->size-chunkOffset); + NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++)); + chunkOffset += chunkSize; + } + r->nSubs = i; } - if (r->offset == r->size) { - if (size) *size = r->size; - *done = 1; - r->used = 0; + if (r->used == 2) { // already exchanged size + int nCompleted = 0; + for (int i=0; i<r->nSubs; i++) { + struct ncclSocketTask* sub = r->tasks[i]; + if (sub->result != ncclSuccess) return sub->result; + if (sub->offset == sub->size) nCompleted++; + } + if (nCompleted == r->nSubs) { + if (size) *size = r->size; + *done = 1; + r->used = 0; + for (int i=0; i<r->nSubs; i++) { + struct ncclSocketTask* sub = r->tasks[i]; + sub->used = 0; + } + } } return ncclSuccess; } @@ -221,13 +417,13 @@ ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, void* mhandle, void** request) { struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm; - NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request)); + NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request)); return ncclSuccess; } ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) { struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm; - NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request)); + NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data, size, (struct ncclSocketRequest**)request)); return ncclSuccess; } @@ -236,11 +432,33 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle return ncclInternalError; } +ncclResult_t ncclSocketCloseListen(void* opaqueComm) { + struct ncclSocketListenComm* comm = (struct ncclSocketListenComm*)opaqueComm; + if (comm) { + if (comm->fd != -1) close(comm->fd); + free(comm); + } + return ncclSuccess; +} + ncclResult_t ncclSocketClose(void* opaqueComm) { struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm; if (comm) { - free(comm->reqs.requests); - close(comm->fd); + for (int i=0; i<comm->nThreads; i++) { + struct ncclSocketThreadResources* res = comm->threadResources+i; + if (comm->helperThread[i]) { + pthread_mutex_lock(&res->threadLock); + res->state = stop; + pthread_cond_signal(&res->threadCond); + pthread_mutex_unlock(&res->threadLock); + pthread_join(comm->helperThread[i], NULL); + } + free(res->threadTaskQueue.tasks); + } + if (comm->ctrlFd != -1) close(comm->ctrlFd); + for (int i=0; i<comm->nSocks; i++) { + if (comm->fds[i] != -1) close(comm->fds[i]); + } free(comm); } return ncclSuccess; @@ -263,5 +481,5 @@ ncclNet_t ncclNetSocket = { ncclSocketTest, ncclSocketClose, ncclSocketClose, - ncclSocketClose + ncclSocketCloseListen }; |