Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Ke Wen <kwen@nvidia.com> 2019-06-25 23:22:47 +0300
committer Ke Wen <kwen@nvidia.com> 2019-06-25 23:22:47 +0300
commit 7c72dee660e4d055b81721dd6b03e4e1c0a983cf (patch)
tree 7e5c478c25f9877f721efd14a908dd6b0dc6a3b4
parent 0ceaec9cee96ae7658aa45686853286651f36384 (diff)
2.4.8-1
Fix #209: improve socket transport performance. Split transfers over multiple sockets. Launch multiple threads to drive sockets. Detect AWS NICs and set nsockets/nthreads accordingly.
-rw-r--r-- makefiles/version.mk 2
-rw-r--r-- src/bootstrap.cc 152
-rw-r--r-- src/include/bootstrap.h 1
-rw-r--r-- src/include/net.h 6
-rw-r--r-- src/include/socket.h 21
-rw-r--r-- src/init.cc 5
-rw-r--r-- src/transport/net_socket.cc 334
7 files changed, 425 insertions, 96 deletions
diff --git a/makefiles/version.mk b/makefiles/version.mk
index 8341f33..bab58ec 100644
--- a/makefiles/version.mk
+++ b/makefiles/version.mk
@@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
NCCL_MINOR := 4
-NCCL_PATCH := 7
+NCCL_PATCH := 8
NCCL_SUFFIX :=
PKG_REVISION := 1
diff --git a/src/bootstrap.cc b/src/bootstrap.cc
index 9df38e4..d7c2ac6 100644
--- a/src/bootstrap.cc
+++ b/src/bootstrap.cc
@@ -9,37 +9,145 @@
#include "utils.h"
#include "bootstrap.h"
#include "net.h"
+#include "socket.h"
#include <unistd.h>
#include <sys/types.h>
// Always use sockets for bootstrap
-ncclNet_t* ncclBootstrapNet = &ncclNetSocket;
+// Opaque handle used by bootstrapNetListen / bootstrapNetConnect: it holds the
+// socket address to listen on (listen side) or to connect to (connect side).
+struct bootstrapNetHandle {
+ union socketAddress connectAddr;
+};
+
+// Bootstrap communicator: a single socket file descriptor.
+// bootstrapNetNewComm initializes fd to -1 until a socket is attached.
+struct bootstrapNetComm {
+ int fd;
+};
+
+/* Init functions */
+// Interface table filled once by bootstrapNetInit():
+//  - bootstrapNetIfNames : MAX_IFS names, MAX_IF_NAME_SIZE bytes each
+//  - bootstrapNetIfAddrs : the matching socket addresses
+//  - bootstrapNetIfs     : number of entries, -1 while not yet initialized
+//  - bootstrapNetLock    : serializes the one-time initialization
+static char bootstrapNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
+static union socketAddress bootstrapNetIfAddrs[MAX_IFS];
+static int bootstrapNetIfs = -1;
+pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
+
+/* Discover the socket interfaces usable for bootstrap and cache them in
+ * bootstrapNetIfNames / bootstrapNetIfAddrs.
+ *
+ * Discovery runs only while bootstrapNetIfs is still -1; the value is checked
+ * again after taking bootstrapNetLock so concurrent callers do not repeat it.
+ *
+ * Returns ncclSuccess, or ncclInternalError when findInterfaces() reports no
+ * interface.
+ */
+ncclResult_t bootstrapNetInit() {
+  ncclResult_t ret = ncclSuccess;
+  if (bootstrapNetIfs == -1) {
+    pthread_mutex_lock(&bootstrapNetLock);
+    if (bootstrapNetIfs == -1) {
+      bootstrapNetIfs = findInterfaces(bootstrapNetIfNames, bootstrapNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
+      if (bootstrapNetIfs <= 0) {
+        WARN("Bootstrap : no socket interface found");
+        // Do not return from here: the lock must also be released on the
+        // error path, otherwise the next caller blocks forever.
+        ret = ncclInternalError;
+      } else {
+        // Build a one-line summary " [idx]name:addr ..." for the log.
+        char line[1024];
+        char addrline[1024];
+        line[0] = '\0';
+        for (int i=0; i<bootstrapNetIfs; i++) {
+          snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, bootstrapNetIfNames+i*MAX_IF_NAME_SIZE,
+              socketToString(&bootstrapNetIfAddrs[i].sa, addrline));
+        }
+        line[1023] = '\0';
+        INFO(NCCL_INIT, "Bootstrap : Using%s", line);
+      }
+    }
+    pthread_mutex_unlock(&bootstrapNetLock);
+  }
+  return ret;
+}
+
+// Allocate one bootstrapNetComm with ncclCalloc and mark its descriptor as
+// unset (-1). The object is released with free() in bootstrapNetClose.
+static ncclResult_t bootstrapNetNewComm(struct bootstrapNetComm** comm) {
+ NCCLCHECK(ncclCalloc(comm, 1));
+ (*comm)->fd = -1;
+ return ncclSuccess;
+}
+
+// Copy the address of bootstrap interface `dev` into *addr.
+// Returns ncclInternalError when dev is not below the number of interfaces
+// recorded in bootstrapNetIfs.
+static ncclResult_t bootstrapNetGetSocketAddr(int dev, union socketAddress* addr) {
+  if (bootstrapNetIfs <= dev) {
+    return ncclInternalError;
+  }
+  *addr = bootstrapNetIfAddrs[dev];
+  return ncclSuccess;
+}
+
+/* Socket Interface Selection type */
+// Negative `dev` values understood by bootstrapNetListen:
+//  findSubnetIf : the handle holds a remote address; pick a local interface in the same subnet
+//  dontCareIf   : the handle already holds the local address to listen on
+enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };
+
+/* Open a listening bootstrap socket.
+ *
+ *  dev >= 0            : listen on the address of bootstrap interface `dev`;
+ *                        that address is written into the handle.
+ *  dev == findSubnetIf : on entry the handle holds a remote address; a local
+ *                        interface in the same subnet is selected and its
+ *                        address is written back into the handle.
+ *  any other value     : the handle already holds the local address.
+ *
+ * On success *listenComm receives the new communicator (release it with
+ * bootstrapNetCloseListen) and the handle holds the listening address.
+ */
+static ncclResult_t bootstrapNetListen(int dev, void* opaqueHandle, void** listenComm) {
+  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
+  static_assert(sizeof(struct bootstrapNetHandle) < NCCL_NET_HANDLE_MAXSIZE, "bootstrapNetHandle size too large");
+  // if dev >= 0, listen based on dev
+  if (dev >= 0) {
+    NCCLCHECK(bootstrapNetGetSocketAddr(dev, &(handle->connectAddr)));
+  } else if (dev == findSubnetIf) {
+    // handle stores a remote address
+    // need to find a local addr that is in the same network as the remote addr
+    union socketAddress localAddr;
+    char ifName[MAX_IF_NAME_SIZE];
+    if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
+      WARN("NET/Socket : No usable listening interface found");
+      return ncclSystemError;
+    }
+    // pass the local address back
+    memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
+  } // Otherwise, handle stores a local address
+  // Create the socket before allocating the communicator so that a socket
+  // failure cannot leak a half-built communicator.
+  int fd = -1;
+  NCCLCHECK(createListenSocket(&fd, &handle->connectAddr));
+  struct bootstrapNetComm* comm;
+  ncclResult_t res = bootstrapNetNewComm(&comm);
+  if (res != ncclSuccess) {
+    close(fd);
+    return res;
+  }
+  comm->fd = fd;
+  *listenComm = comm;
+  return ncclSuccess;
+}
+
+/* Connect to the address stored in the handle.
+ * `dev` is not used. On success *sendComm receives the new communicator
+ * (release it with bootstrapNetCloseSend).
+ */
+static ncclResult_t bootstrapNetConnect(int dev, void* opaqueHandle, void** sendComm) {
+  struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
+  // Connect first, allocate second: a failed connection then has nothing to
+  // clean up, and a failed allocation only needs to close the socket.
+  int fd = -1;
+  NCCLCHECK(connectAddress(&fd, &handle->connectAddr));
+  struct bootstrapNetComm* comm;
+  ncclResult_t res = bootstrapNetNewComm(&comm);
+  if (res != ncclSuccess) {
+    close(fd);
+    return res;
+  }
+  comm->fd = fd;
+  *sendComm = comm;
+  return ncclSuccess;
+}
+
+/* Accept one incoming connection on a listen communicator.
+ * On success *recvComm receives a new communicator wrapping the accepted
+ * socket (release it with bootstrapNetCloseRecv).
+ */
+static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) {
+  struct bootstrapNetComm* lComm = (struct bootstrapNetComm*)listenComm;
+  // The peer address is never used, so none is requested. Accepting before
+  // allocating means a failed accept() cannot leak a communicator.
+  int fd = -1;
+  SYSCHECKVAL(accept(lComm->fd, NULL, NULL), "accept", fd);
+  struct bootstrapNetComm* rComm;
+  ncclResult_t res = bootstrapNetNewComm(&rComm);
+  if (res != ncclSuccess) {
+    close(fd);
+    return res;
+  }
+  rComm->fd = fd;
+  *recvComm = rComm;
+  return ncclSuccess;
+}
-static ncclResult_t bootstrapNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclBootstrapNet->listen(dev, handle, listenComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclBootstrapNet->connect(dev, handle, sendComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclBootstrapNet->accept(listenComm, recvComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetTest(void* request, int* done, int* size) { NCCLCHECK(ncclBootstrapNet->test(request, done, size)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(ncclBootstrapNet->closeSend(sendComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(ncclBootstrapNet->closeRecv(recvComm)); return ncclSuccess; }
-static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(ncclBootstrapNet->closeListen(listenComm)); return ncclSuccess; }
+/* Close the socket of a bootstrap communicator (if one was attached) and free
+ * the communicator. A NULL communicator is accepted. Always returns
+ * ncclSuccess.
+ */
+static ncclResult_t bootstrapNetClose(void* opaqueComm) {
+  struct bootstrapNetComm* comm = (struct bootstrapNetComm*)opaqueComm;
+  if (comm) {
+    // fd stays -1 when no socket was ever attached; skip close() in that case,
+    // as ncclSocketCloseListen does.
+    if (comm->fd != -1) close(comm->fd);
+    free(comm);
+  }
+  return ncclSuccess;
+}
-// Additional sync functions based on async + test for bootstrap, using host ptrs.
+// Close entry points for send, receive and listen communicators; all three
+// forward to bootstrapNetClose.
+static ncclResult_t bootstrapNetCloseSend(void* sendComm) { NCCLCHECK(bootstrapNetClose(sendComm)); return ncclSuccess; }
+static ncclResult_t bootstrapNetCloseRecv(void* recvComm) { NCCLCHECK(bootstrapNetClose(recvComm)); return ncclSuccess; }
+static ncclResult_t bootstrapNetCloseListen(void* listenComm) { NCCLCHECK(bootstrapNetClose(listenComm)); return ncclSuccess; }
+
+// Additional sync functions
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
- void* request, *mhandle;
- NCCLCHECK(ncclBootstrapNet->regMr(sendComm, data, size, NCCL_PTR_HOST, &mhandle));
- NCCLCHECK(ncclBootstrapNet->isend(sendComm, data, size, mhandle, &request));
- NCCLCHECK(ncclBootstrapNet->deregMr(sendComm, mhandle));
- int done = 0;
- while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
+ struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
+ NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
+ NCCLCHECK(socketSend(comm->fd, data, size));
return ncclSuccess;
}
static ncclResult_t bootstrapNetRecv(void* recvComm, void* data, int size) {
- void* request, *mhandle;
- NCCLCHECK(ncclBootstrapNet->regMr(recvComm, data, size, NCCL_PTR_HOST, &mhandle));
- NCCLCHECK(ncclBootstrapNet->irecv(recvComm, data, size, mhandle, &request));
- NCCLCHECK(ncclBootstrapNet->deregMr(recvComm, mhandle));
- int done = 0;
- while (!done) NCCLCHECK(bootstrapNetTest(request, &done, NULL));
+ struct bootstrapNetComm* comm = (struct bootstrapNetComm*)recvComm;
+ int recvSize;
+ NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int)));
+ if (recvSize > size) {
+ WARN("Message truncated : received %d bytes instead of %d\n", recvSize, size);
+ return ncclInternalError;
+ }
+ NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size)));
+ return ncclSuccess;
+}
+
+ncclResult_t bootstrapNetCreateHandle(void* opaqueHandle, const char* str) {
+ struct bootstrapNetHandle* handle = (struct bootstrapNetHandle*) opaqueHandle;
+ NCCLCHECK(GetSocketAddrFromString(&handle->connectAddr, str));
return ncclSuccess;
}
@@ -148,7 +256,7 @@ ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out) {
char* env = getenv("NCCL_COMM_ID");
if (env) {
- if (ncclSocketCreateHandle(&id->extHandleRoot, env) != 0) {
+ if (bootstrapNetCreateHandle(&id->extHandleRoot, env) != 0) {
WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
return ncclInvalidArgument;
}
diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h
index dd7de2c..dacbc7c 100644
--- a/src/include/bootstrap.h
+++ b/src/include/bootstrap.h
@@ -9,6 +9,7 @@
#include "nccl.h"
+ncclResult_t bootstrapNetInit();
ncclResult_t bootstrapCreateRoot(ncclUniqueId* commId, bool idFromEnv);
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* out);
ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commState);
diff --git a/src/include/net.h b/src/include/net.h
index da3ecea..950b5e5 100644
--- a/src/include/net.h
+++ b/src/include/net.h
@@ -13,11 +13,6 @@
extern ncclNet_t* ncclNet;
typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
-/* Socket Interface Selection type */
-typedef enum { findSubnetIf = -1,
- dontCareIf = -2
-} ncclSocketIfSl_t;
-
// Translation to external API
static const char* ncclNetName() { return ncclNet->name; }
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
@@ -36,7 +31,6 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
-extern ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str);
extern ncclNet_t ncclNetIb;
extern ncclNet_t ncclNetSocket;
diff --git a/src/include/socket.h b/src/include/socket.h
index 739c0c4..8197a65 100644
--- a/src/include/socket.h
+++ b/src/include/socket.h
@@ -42,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
return buf;
}
-static inline short socketToPort(struct sockaddr *saddr) {
+static inline uint16_t socketToPort(struct sockaddr *saddr) {
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
}
@@ -161,7 +161,10 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
}
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
- char line[1024], line_a[1024];
+#ifdef ENABLE_TRACE
+ char line[1024];
+#endif
+ char line_a[1024];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
@@ -185,7 +188,7 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
// Store the interface name
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
- INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
+ TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
found++;
if (found == maxIfs) break;
}
@@ -390,12 +393,12 @@ retry:
#define NCCL_SOCKET_SEND 0
#define NCCL_SOCKET_RECV 1
-static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
+static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
int bytes = 0;
char* data = (char*)ptr;
do {
- if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
- if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
+ if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
+ if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
if (op == NCCL_SOCKET_RECV && bytes == 0) {
WARN("Net : Connection closed by remote peer");
return ncclSystemError;
@@ -413,9 +416,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off
return ncclSuccess;
}
+// Non-blocking progress: forwards to socketProgressOpt with block = 0, so the
+// underlying recv()/send() is issued with MSG_DONTWAIT.
+static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
+ return socketProgressOpt(op, fd, ptr, size, offset, 0);
+}
+
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
while (*offset < size)
- NCCLCHECK(socketProgress(op, fd, ptr, size, offset));
+ NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
return ncclSuccess;
}
diff --git a/src/init.cc b/src/init.cc
index 80af287..42499c0 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -124,14 +124,15 @@ cleanup:
}
ncclResult_t initNet() {
- // Always initialize sockets as we use it for bootstrap
- NCCLCHECK(initNet(&ncclNetSocket));
+ // Always initialize bootstrap network
+ NCCLCHECK(bootstrapNetInit());
NCCLCHECK(initNetPlugin(&ncclNet));
if (ncclNet != NULL) return ncclSuccess;
if (initNet(&ncclNetIb) == ncclSuccess) {
ncclNet = &ncclNetIb;
} else {
+ NCCLCHECK(initNet(&ncclNetSocket));
ncclNet = &ncclNetSocket;
}
return ncclSuccess;
diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc
index 9958936..ab5e8ec 100644
--- a/src/transport/net_socket.cc
+++ b/src/transport/net_socket.cc
@@ -8,6 +8,7 @@
#include "core.h"
#include "socket.h"
#include "net.h"
+#include "param.h"
#include <assert.h>
#include <pthread.h>
@@ -15,6 +16,7 @@
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
+#include <fcntl.h>
/* Init functions */
static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
@@ -68,7 +70,7 @@ ncclResult_t ncclSocketPciPath(int dev, char** path) {
return ncclSuccess;
}
-static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
+ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
if (dev >= ncclNetIfs) return ncclInternalError;
memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr));
return ncclSuccess;
@@ -76,105 +78,281 @@ static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
/* Communication functions */
+// Compile-time limits of the socket transport:
+//  MAX_SOCKETS   : data sockets per communicator (size of fds[] and of a request's tasks[])
+//  MAX_THREADS   : helper threads per communicator
+//  MAX_REQUESTS  : outstanding requests per communicator
+//  MAX_QUEUE_LEN : slots in each helper thread's task queue
+//  MIN_CHUNKSIZE : smallest piece a request is split into (bytes)
+#define MAX_SOCKETS 64
+#define MAX_THREADS 16
+#define MAX_REQUESTS 128
+#define MAX_QUEUE_LEN MAX_REQUESTS
+#define MIN_CHUNKSIZE (64*1024)
+
+// Tunables; the default -2 means "auto-detect" in ncclSocketGetNsockNthread.
+NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
+NCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2);
+
+
struct ncclSocketHandle {
union socketAddress connectAddr;
+ int nSocks;
+ int nThreads;
};
-struct ncclSocketRequest {
+struct ncclSocketTask {
int op;
void* data;
int size;
int fd;
int offset;
int used;
+ ncclResult_t result;
};
-struct ncclSocketReqs {
- struct ncclSocketRequest* requests;
+struct ncclSocketRequest {
+ int op;
+ void* data;
+ int size;
+ int ctrlFd;
+ int used;
+ struct ncclSocketComm* comm;
+ struct ncclSocketTask* tasks[MAX_SOCKETS];
+ int nSubs;
};
-struct ncclSocketComm {
+// Per-thread ring of tasks: `tasks` is an array of MAX_QUEUE_LEN slots
+// (allocated in ncclSocketGetTask) and `next` is the slot handed out next.
+struct ncclSocketTaskQueue {
+ int next;
+ struct ncclSocketTask* tasks;
+};
+
+// Helper thread run state: set to `start` when a task is queued and to `stop`
+// by ncclSocketClose to make the thread exit.
+enum threadState {start, stop};
+
+// Everything one helper thread needs: its task queue, its run state, a
+// back-pointer to the owning communicator, and the mutex/condition pair used
+// to put the thread to sleep when idle and wake it on new work or shutdown.
+struct ncclSocketThreadResources {
+ struct ncclSocketTaskQueue threadTaskQueue;
+ enum threadState state;
+ struct ncclSocketComm* comm;
+ pthread_mutex_t threadLock;
+ pthread_cond_t threadCond;
+};
+
+struct ncclSocketListenComm {
int fd;
- struct ncclSocketReqs reqs;
+ int nSocks;
+ int nThreads;
};
-ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
+// Send/receive communicator of the socket transport.
+//  ctrlFd          : control socket, used to exchange message sizes
+//  fds / nSocks    : data sockets and how many of them are in use
+//  nThreads        : number of helper threads driving the data sockets
+//  nextFd          : index of the data socket the next task is assigned to
+//  requests        : fixed pool of user-visible requests
+//  helperThread    : helper thread ids, created lazily in ncclSocketGetTask
+//  threadResources : per-thread state, indexed like helperThread
+struct ncclSocketComm {
+ int ctrlFd;
+ int fds[MAX_SOCKETS];
+ int nSocks;
+ int nThreads;
+ int nextFd;
+ struct ncclSocketRequest requests[MAX_REQUESTS];
+ pthread_t helperThread[MAX_THREADS];
+ struct ncclSocketThreadResources threadResources[MAX_THREADS];
+};
+
+
+/* Body of a helper thread.
+ *
+ * args_ is the thread's ncclSocketThreadResources. The thread repeatedly scans
+ * its task queue in groups of nSocksPerThread slots and calls socketProgress()
+ * on every active task (used == 1 and not yet complete) until the whole group
+ * is done. When a full scan found nothing to do, it sleeps on threadCond until
+ * a new task is queued or the state becomes `stop`.
+ *
+ * The thread exits (returning NULL) when the state is `stop`, or when
+ * socketProgress() fails; in the latter case the error is left in the task's
+ * `result` field.
+ */
+void* persistentSocketThread(void *args_) {
+  struct ncclSocketThreadResources* resource = (struct ncclSocketThreadResources*)args_;
+  struct ncclSocketComm* comm = resource->comm;
+  volatile enum threadState* state = &resource->state;
+  struct ncclSocketTaskQueue* myQueue = &resource->threadTaskQueue;
+  // Group size for the scan below. It must be at least 1, otherwise the
+  // outer loop would never advance.
+  int nSocksPerThread = comm->nThreads > 0 ? comm->nSocks / comm->nThreads : 0;
+  if (nSocksPerThread < 1) nSocksPerThread = 1;
+  while (1) {
+    int idle = 1;
+    int mark = myQueue->next; // mark newest task seen
+    for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) {
+      int repeat;
+      do {
+        repeat = 0;
+        // MAX_QUEUE_LEN need not be a multiple of nSocksPerThread: clamp the
+        // last group so it never reads past the end of the queue.
+        for (int j=0; j<nSocksPerThread && i+j<MAX_QUEUE_LEN; j++) {
+          struct ncclSocketTask* r = myQueue->tasks+i+j;
+          if (r->used == 1 && r->offset < r->size) {
+            r->result = socketProgress(r->op, r->fd, r->data, r->size, &r->offset);
+            if (r->result != ncclSuccess) {
+              WARN("NET/Socket : socket progress error");
+              return NULL;
+            }
+            idle = 0;
+            if (r->offset < r->size) repeat = 1;
+          }
+        }
+      } while (repeat);
+    }
+    if (idle) {
+      pthread_mutex_lock(&resource->threadLock);
+      while (mark == myQueue->next && *state != stop) { // no new tasks, wait
+        pthread_cond_wait(&resource->threadCond, &resource->threadLock);
+      }
+      pthread_mutex_unlock(&resource->threadLock);
+    }
+    if (*state == stop) return NULL;
+  }
+}
+
+/* Decide how many sockets and helper threads to use for interface `dev`.
+ *
+ * Values come from the NSOCKS_PERTHREAD / SOCKET_NTHREADS parameters. A value
+ * of -2 means "auto": the PCI vendor id of the interface is read from
+ * /sys/class/net/<if>/device/vendor, and vendor 0x1d0f (AWS) selects
+ * 2 threads x 8 sockets; anything else (including an unreadable vendor file)
+ * selects 1 x 1.
+ *
+ * The result is clamped to MAX_THREADS threads, MAX_SOCKETS sockets in total,
+ * and at least one thread and one socket per thread, because both counts are
+ * used as divisors elsewhere in this file.
+ *
+ * Outputs: *ns = total number of sockets, *nt = number of threads.
+ * Returns ncclSystemError when the vendor file exists but cannot be read.
+ */
+ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) {
+  int nSocksPerThread = ncclParamSocketNsocksPerThread();
+  int nThreads = ncclParamSocketNthreads();
+  if (nThreads > MAX_THREADS) {
+    WARN("NET/Socket : NCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS);
+    nThreads = MAX_THREADS;
+  }
+  if (nThreads == -2 || nSocksPerThread == -2) {
+    // Auto-detection
+    int autoNt=1, autoNs=1;
+    char vendorPath[PATH_MAX];
+    snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetIfNames+dev*MAX_IF_NAME_SIZE);
+    char* rPath = realpath(vendorPath, NULL);
+    // realpath() returns NULL when the path does not resolve; never hand
+    // that to open().
+    int fd = (rPath != NULL) ? open(rPath, O_RDONLY) : -1;
+    free(rPath);
+    if (fd == -1) {
+      // Could not find device vendor. This is handled silently so
+      // we don't want to print an INFO error.
+      TRACE(NCCL_NET, "Open of %s failed : %s\n", vendorPath, strerror(errno));
+    } else {
+      // Pre-filled so the buffer stays NUL-terminated after reading 6 bytes.
+      char vendor[7] = "0x0000";
+      ssize_t len;
+      do {
+        len = read(fd, vendor, 6);
+      } while (len == -1 && errno == EINTR);
+      if (len == -1) {
+        int readErrno = errno;
+        close(fd); // do not leak the descriptor on the error path
+        WARN("NET/Socket : read of %s failed : %s", vendorPath, strerror(readErrno));
+        return ncclSystemError;
+      }
+      SYSCHECK(close(fd), "close");
+      if (strcmp(vendor, "0x1d0f") == 0) { // AWS
+        autoNt = 2;
+        autoNs = 8;
+      }
+    }
+    if (nThreads == -2) nThreads = autoNt;
+    if (nSocksPerThread == -2) nSocksPerThread = autoNs;
+  }
+  // Both values are later used as divisors (socket/thread round-robin), so
+  // zero or negative settings cannot be honored.
+  if (nThreads < 1) {
+    WARN("NET/Socket : NCCL_SOCKET_NTHREADS must be at least 1, setting to 1");
+    nThreads = 1;
+  }
+  if (nSocksPerThread < 1) {
+    WARN("NET/Socket : NCCL_NSOCKS_PERTHREAD must be at least 1, setting to 1");
+    nSocksPerThread = 1;
+  }
+  int nSocks = nSocksPerThread * nThreads;
+  if (nSocks > MAX_SOCKETS) {
+    nSocksPerThread = MAX_SOCKETS/nThreads;
+    WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting NCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread);
+    nSocks = nSocksPerThread * nThreads;
+  }
+  *ns = nSocks;
+  *nt = nThreads;
+  INFO(NCCL_INIT, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread);
+  return ncclSuccess;
+}
+
+ncclResult_t ncclSocketNewListenComm(struct ncclSocketListenComm** comm) {
NCCLCHECK(ncclCalloc(comm, 1));
(*comm)->fd = -1;
return ncclSuccess;
}
-ncclResult_t ncclSocketCreateHandle(void* opaqueHandle, const char* str) {
- struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
- NCCLCHECK(GetSocketAddrFromString(&(handle->connectAddr), str));
+ncclResult_t ncclSocketNewComm(struct ncclSocketComm** comm) {
+ NCCLCHECK(ncclCalloc(comm, 1));
+ (*comm)->ctrlFd = -1;
+ for (int i=0; i < MAX_SOCKETS; i++) {
+ (*comm)->fds[i] = -1;
+ }
+ (*comm)->nextFd = 0;
return ncclSuccess;
}
ncclResult_t ncclSocketListen(int dev, void* opaqueHandle, void** listenComm) {
+ if (dev < 0) { // data transfer socket is based on specified dev
+ return ncclInternalError;
+ }
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclSocketHandle size too large");
- // if dev >= 0, listen based on dev
- if (dev >= 0) {
- NCCLCHECK(GetSocketAddr(dev, &(handle->connectAddr)));
- } else if (dev == findSubnetIf) {
- // handle stores a remote address
- // need to find a local addr that is in the same network as the remote addr
- union socketAddress localAddr;
- char ifName[MAX_IF_NAME_SIZE];
- if (findInterfaceMatchSubnet(ifName, &localAddr, handle->connectAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
- WARN("NET/Socket : No usable listening interface found");
- return ncclSystemError;
- }
- // pass the local address back
- memcpy(&handle->connectAddr, &localAddr, sizeof(handle->connectAddr));
- } // Otherwise, handle stores a local address
- struct ncclSocketComm* comm;
- NCCLCHECK(ncclSocketNewComm(&comm));
+ struct ncclSocketListenComm* comm;
+ NCCLCHECK(ncclSocketNewListenComm(&comm));
+ NCCLCHECK(GetSocketAddr(dev, &handle->connectAddr));
NCCLCHECK(createListenSocket(&comm->fd, &handle->connectAddr));
+ NCCLCHECK(ncclSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads));
+ handle->nSocks = comm->nSocks;
+ handle->nThreads = comm->nThreads;
*listenComm = comm;
return ncclSuccess;
}
ncclResult_t ncclSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
+ if (dev < 0) { // data transfer socket is based on specified dev
+ return ncclInternalError;
+ }
struct ncclSocketComm* comm;
NCCLCHECK(ncclSocketNewComm(&comm));
struct ncclSocketHandle* handle = (struct ncclSocketHandle*) opaqueHandle;
- NCCLCHECK(connectAddress(&comm->fd, &handle->connectAddr));
+ comm->nSocks = handle->nSocks;
+ comm->nThreads = handle->nThreads;
+ for (int i=0; i<comm->nSocks+1; i++) {
+ int tmpFd, offset=0;
+ NCCLCHECK(connectAddress(&tmpFd, &handle->connectAddr));
+ NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset));
+ if (i == comm->nSocks) comm->ctrlFd = tmpFd;
+ else comm->fds[i] = tmpFd;
+ }
*sendComm = comm;
return ncclSuccess;
}
ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) {
-struct ncclSocketComm* lComm = (struct ncclSocketComm*)listenComm;
+ struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)listenComm;
struct ncclSocketComm* rComm;
NCCLCHECK(ncclSocketNewComm(&rComm));
-struct sockaddr_in sockaddr;
-socklen_t socklen = sizeof(struct sockaddr_in);
-SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", rComm->fd);
+ rComm->nSocks = lComm->nSocks;
+ rComm->nThreads = lComm->nThreads;
+ // The peer opens nSocks data connections plus one control connection and
+ // sends, on each of them, the index that connection should be stored under
+ // (index == nSocks designates the control connection).
+ for (int i=0; i<rComm->nSocks+1; i++) {
+  int tmpFd, sendSockIdx, offset=0;
+  struct sockaddr_in sockaddr;
+  socklen_t socklen = sizeof(struct sockaddr_in);
+  SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), "accept", tmpFd);
+  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), &offset));
+  // sendSockIdx comes from the network: validate it before using it as an
+  // array index, and reject an index that was already claimed.
+  int* slot = NULL;
+  if (sendSockIdx == rComm->nSocks) slot = &rComm->ctrlFd;
+  else if (sendSockIdx >= 0 && sendSockIdx < rComm->nSocks) slot = rComm->fds+sendSockIdx;
+  if (slot == NULL || *slot != -1) {
+   WARN("NET/Socket : unexpected socket index %d received from peer", sendSockIdx);
+   close(tmpFd);
+   return ncclInternalError;
+  }
+  *slot = tmpFd;
+ }
*recvComm = rComm;
return ncclSuccess;
}
-#define MAX_REQUESTS 128
-
-ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) {
- if (reqs->requests == NULL) {
- NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS));
- }
+ncclResult_t ncclSocketGetRequest(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketRequest** req) {
for (int i=0; i<MAX_REQUESTS; i++) {
- struct ncclSocketRequest* r = reqs->requests+i;
+ struct ncclSocketRequest* r = comm->requests+i;
if (r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
- r->fd = fd;
- r->offset = -1;
+ r->ctrlFd = comm->ctrlFd;
r->used = 1;
+ r->comm = comm;
+ r->nSubs = 0;
*req = r;
return ncclSuccess;
}
}
- WARN("Socket : unable to allocate requests");
+ WARN("NET/Socket : unable to allocate requests");
+ return ncclInternalError;
+}
+
+ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data, int size, struct ncclSocketTask** req) {
+ int tid = comm->nextFd % comm->nThreads;
+ struct ncclSocketThreadResources* res = comm->threadResources+tid;
+ struct ncclSocketTaskQueue* queue = &res->threadTaskQueue;
+ // create helper threads and prepare per-thread task queue
+ if (queue->tasks == NULL) {
+ NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN));
+ queue->next = 0;
+ res->comm = comm;
+ pthread_mutex_init(&res->threadLock, NULL);
+ pthread_cond_init(&res->threadCond, NULL);
+ pthread_create(comm->helperThread+tid, NULL, persistentSocketThread, res);
+ }
+ struct ncclSocketTask* r = queue->tasks+queue->next;
+ if (r->used == 0) {
+ r->op = op;
+ r->data = data;
+ r->size = size;
+ r->fd = comm->fds[comm->nextFd];
+ r->offset = 0;
+ r->result = ncclSuccess;
+ comm->nextFd = (comm->nextFd + 1) % comm->nSocks;
+ r->used = 1;
+ *req = r;
+ pthread_mutex_lock(&res->threadLock);
+ queue->next = (queue->next+1)%MAX_QUEUE_LEN;
+ res->state = start;
+ pthread_cond_signal(&res->threadCond);
+ pthread_mutex_unlock(&res->threadLock);
+ return ncclSuccess;
+ }
+ WARN("NET/Socket : unable to allocate subtasks");
return ncclInternalError;
}
@@ -185,15 +363,15 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
WARN("NET/Socket : test called with NULL request");
return ncclInternalError;
}
- if (r->offset == -1) { /* try to send/recv size */
+ if (r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
- NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset));
+ NCCLCHECK(socketProgress(r->op, r->ctrlFd, &data, sizeof(int), &offset));
if (offset == 0) return ncclSuccess; /* Not ready -- retry later */
// Not sure we could ever receive less than 4 bytes, but just in case ...
- if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset));
+ if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->ctrlFd, &data, sizeof(int), &offset));
// Check size is less or equal to the size provided by the user
if (r->op == NCCL_SOCKET_RECV && data > r->size) {
@@ -201,15 +379,33 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
return ncclInternalError;
}
r->size = data;
- r->offset = 0;
- }
- if (r->offset < r->size) {
- NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset));
+ r->used = 2; // done exchanging size
+ // divide into subtasks
+ int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
+ int chunkOffset = 0, i = 0;
+ while (chunkOffset < r->size) {
+ int chunkSize = std::min(taskSize, r->size-chunkOffset);
+ NCCLCHECK(ncclSocketGetTask(r->comm, r->op, (char*)(r->data)+chunkOffset, chunkSize, r->tasks+i++));
+ chunkOffset += chunkSize;
+ }
+ r->nSubs = i;
}
- if (r->offset == r->size) {
- if (size) *size = r->size;
- *done = 1;
- r->used = 0;
+ if (r->used == 2) { // already exchanged size
+ int nCompleted = 0;
+ for (int i=0; i<r->nSubs; i++) {
+ struct ncclSocketTask* sub = r->tasks[i];
+ if (sub->result != ncclSuccess) return sub->result;
+ if (sub->offset == sub->size) nCompleted++;
+ }
+ if (nCompleted == r->nSubs) {
+ if (size) *size = r->size;
+ *done = 1;
+ r->used = 0;
+ for (int i=0; i<r->nSubs; i++) {
+ struct ncclSocketTask* sub = r->tasks[i];
+ sub->used = 0;
+ }
+ }
}
return ncclSuccess;
}
@@ -221,13 +417,13 @@ ncclResult_t ncclSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess;
ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, void* mhandle, void** request) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm;
- NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request));
+ NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclSocketRequest**)request));
return ncclSuccess;
}
ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle, void** request) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm;
- NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request));
+ NCCLCHECK(ncclSocketGetRequest(comm, NCCL_SOCKET_RECV, data, size, (struct ncclSocketRequest**)request));
return ncclSuccess;
}
@@ -236,11 +432,33 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle
return ncclInternalError;
}
+// Release a listen communicator: close its listening descriptor when one was
+// opened, then free the object. A NULL communicator is accepted.
+ncclResult_t ncclSocketCloseListen(void* opaqueComm) {
+  struct ncclSocketListenComm* lComm = (struct ncclSocketListenComm*)opaqueComm;
+  if (lComm == NULL) return ncclSuccess;
+  if (lComm->fd != -1) {
+    close(lComm->fd);
+  }
+  free(lComm);
+  return ncclSuccess;
+}
+
+
ncclResult_t ncclSocketClose(void* opaqueComm) {
struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm;
if (comm) {
- free(comm->reqs.requests);
- close(comm->fd);
+ for (int i=0; i<comm->nThreads; i++) {
+ struct ncclSocketThreadResources* res = comm->threadResources+i;
+ if (comm->helperThread[i]) {
+ pthread_mutex_lock(&res->threadLock);
+ res->state = stop;
+ pthread_cond_signal(&res->threadCond);
+ pthread_mutex_unlock(&res->threadLock);
+ pthread_join(comm->helperThread[i], NULL);
+ }
+ free(res->threadTaskQueue.tasks);
+ }
+ if (comm->ctrlFd != -1) close(comm->ctrlFd);
+ for (int i=0; i<comm->nSocks; i++) {
+ if (comm->fds[i] != -1) close(comm->fds[i]);
+ }
free(comm);
}
return ncclSuccess;
@@ -263,5 +481,5 @@ ncclNet_t ncclNetSocket = {
ncclSocketTest,
ncclSocketClose,
ncclSocketClose,
- ncclSocketClose
+ ncclSocketCloseListen
};