Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/transport/net_socket.cc')
-rw-r--r--src/transport/net_socket.cc102
1 files changed, 70 insertions, 32 deletions
diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc
index 1b1fc4f..272d8cd 100644
--- a/src/transport/net_socket.cc
+++ b/src/transport/net_socket.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -19,28 +19,48 @@
#include <fcntl.h>
/* Init functions */
-static char ncclNetIfNames[MAX_IF_NAME_SIZE*MAX_IFS];
-static union socketAddress ncclNetIfAddrs[MAX_IFS];
static int ncclNetIfs = -1;
+struct ncclSocketDev {
+ union socketAddress addr;
+ char devName[MAX_IF_NAME_SIZE];
+ char* pciPath;
+};
+static struct ncclSocketDev ncclSocketDevs[MAX_IFS];
+
pthread_mutex_t ncclSocketLock = PTHREAD_MUTEX_INITIALIZER;
+static ncclResult_t ncclSocketGetPciPath(char* devName, char** pciPath) {
+ char devicePath[PATH_MAX];
+ snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName);
+ // May return NULL if the file doesn't exist.
+ *pciPath = realpath(devicePath, NULL);
+ return ncclSuccess;
+}
+
ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
if (ncclNetIfs == -1) {
pthread_mutex_lock(&ncclSocketLock);
if (ncclNetIfs == -1) {
- ncclNetIfs = findInterfaces(ncclNetIfNames, ncclNetIfAddrs, MAX_IF_NAME_SIZE, MAX_IFS);
+ char names[MAX_IF_NAME_SIZE*MAX_IFS];
+ union socketAddress addrs[MAX_IFS];
+ ncclNetIfs = findInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS);
if (ncclNetIfs <= 0) {
WARN("NET/Socket : no interface found");
return ncclInternalError;
} else {
- char line[1024];
- char addrline[1024];
+ #define MAX_LINE_LEN (2047)
+ char line[MAX_LINE_LEN+1];
+ char addrline[SOCKET_NAME_MAXLEN+1];
line[0] = '\0';
+ addrline[SOCKET_NAME_MAXLEN] = '\0';
for (int i=0; i<ncclNetIfs; i++) {
- snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%s", i, ncclNetIfNames+i*MAX_IF_NAME_SIZE,
- socketToString(&ncclNetIfAddrs[i].sa, addrline));
+ strcpy(ncclSocketDevs[i].devName, names+i*MAX_IF_NAME_SIZE);
+ memcpy(&ncclSocketDevs[i].addr, addrs+i, sizeof(union socketAddress));
+ NCCLCHECK(ncclSocketGetPciPath(ncclSocketDevs[i].devName, &ncclSocketDevs[i].pciPath));
+ snprintf(line+strlen(line), MAX_LINE_LEN-strlen(line), " [%d]%s:%s", i, names+i*MAX_IF_NAME_SIZE,
+ socketToString(&addrs[i].sa, addrline));
}
- line[1023] = '\0';
+ line[MAX_LINE_LEN] = '\0';
INFO(NCCL_INIT|NCCL_NET,"NET/Socket : Using%s", line);
}
}
@@ -49,30 +69,44 @@ ncclResult_t ncclSocketInit(ncclDebugLogger_t logFunction) {
return ncclSuccess;
}
-ncclResult_t ncclSocketPtrSupport(int dev, int* supportedTypes) {
- *supportedTypes = NCCL_PTR_HOST;
- return ncclSuccess;
-}
-
ncclResult_t ncclSocketDevices(int* ndev) {
*ndev = ncclNetIfs;
return ncclSuccess;
}
-ncclResult_t ncclSocketPciPath(int dev, char** path) {
- char devicepath[PATH_MAX];
- snprintf(devicepath, PATH_MAX, "/sys/class/net/%s/device", ncclNetIfNames+dev*MAX_IF_NAME_SIZE);
- *path = realpath(devicepath, NULL);
- if (*path == NULL) {
- INFO(NCCL_NET|NCCL_INIT, "Could not find real path of %s", devicepath);
- return ncclSystemError;
+static ncclResult_t ncclSocketGetSpeed(char* devName, int* speed) {
+ *speed = 0;
+ char speedPath[PATH_MAX];
+ sprintf(speedPath, "/sys/class/net/%s/speed", devName);
+ int fd = open(speedPath, O_RDONLY);
+ if (fd != -1) {
+ char speedStr[] = " ";
+ if (read(fd, speedStr, sizeof(speedStr)-1) > 0) {
+ *speed = strtol(speedStr, NULL, 0);
+ }
+ close(fd);
+ }
+ if (*speed <= 0) {
+ INFO(NCCL_NET, "Could not get speed from %s. Defaulting to 10 Gbps.", speedPath);
+ *speed = 10000;
}
return ncclSuccess;
}
+ncclResult_t ncclSocketGetProperties(int dev, ncclNetProperties_t* props) {
+ props->name = ncclSocketDevs[dev].devName;
+ props->pciPath = ncclSocketDevs[dev].pciPath;
+ props->guid = dev;
+ props->ptrSupport = NCCL_PTR_HOST;
+ NCCLCHECK(ncclSocketGetSpeed(props->name, &props->speed));
+ props->port = 0;
+ props->maxComms = 65536;
+ return ncclSuccess;
+}
+
ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
if (dev >= ncclNetIfs) return ncclInternalError;
- memcpy(addr, ncclNetIfAddrs+dev, sizeof(*addr));
+ memcpy(addr, &ncclSocketDevs[dev].addr, sizeof(*addr));
return ncclSuccess;
}
@@ -80,8 +114,7 @@ ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) {
#define MAX_SOCKETS 64
#define MAX_THREADS 16
-#define MAX_REQUESTS 128
-#define MAX_QUEUE_LEN MAX_REQUESTS
+#define MAX_REQUESTS NCCL_NET_MAX_REQUESTS
#define MIN_CHUNKSIZE (64*1024)
NCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
@@ -117,6 +150,7 @@ struct ncclSocketRequest {
struct ncclSocketTaskQueue {
int next;
+ int len;
struct ncclSocketTask* tasks;
};
@@ -156,7 +190,7 @@ void* persistentSocketThread(void *args_) {
while (1) {
int idle = 1;
int mark = myQueue->next; // mark newest task seen
- for (int i=0; i<MAX_QUEUE_LEN; i+=nSocksPerThread) {
+ for (int i=0; i<myQueue->len; i+=nSocksPerThread) {
int repeat;
do {
repeat = 0;
@@ -196,7 +230,7 @@ ncclResult_t ncclSocketGetNsockNthread(int dev, int* ns, int* nt) {
// Auto-detection
int autoNt=0, autoNs=1; // By default, we only use the main thread and do not spawn extra threads
char vendorPath[PATH_MAX];
- snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclNetIfNames+dev*MAX_IF_NAME_SIZE);
+ snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", ncclSocketDevs[dev].devName);
char* rPath = realpath(vendorPath, NULL);
int fd = open(rPath, O_RDONLY);
free(rPath);
@@ -331,7 +365,11 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
struct ncclSocketTaskQueue* queue = &res->threadTaskQueue;
// create helper threads and prepare per-thread task queue
if (queue->tasks == NULL) {
- NCCLCHECK(ncclCalloc(&queue->tasks, MAX_QUEUE_LEN));
+ // each request can be divided up to nSocks tasks, and
+ // these tasks are distributed to nThreads threads,
+ // we need to make sure each thread queue has enough slots for MAX_REQUESTS
+ queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
+ NCCLCHECK(ncclCalloc(&queue->tasks, queue->len));
queue->next = 0;
res->comm = comm;
pthread_mutex_init(&res->threadLock, NULL);
@@ -350,7 +388,7 @@ ncclResult_t ncclSocketGetTask(struct ncclSocketComm* comm, int op, void* data,
r->used = 1;
*req = r;
pthread_mutex_lock(&res->threadLock);
- queue->next = (queue->next+1)%MAX_QUEUE_LEN;
+ queue->next = (queue->next+1)%queue->len;
res->state = start;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
@@ -388,6 +426,7 @@ ncclResult_t ncclSocketTest(void* request, int* done, int* size) {
// divide into subtasks
int chunkOffset = 0, i = 0;
if (r->comm->nSocks > 0) {
+ // each request can be divided up to nSocks tasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
while (chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size-chunkOffset);
@@ -445,7 +484,7 @@ ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, void* mhandle
return ncclSuccess;
}
-ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size, void* mhandle) {
+ncclResult_t ncclSocketIflush(void* recvComm, void* data, int size, void* mhandle, void** request) {
// We don't support CUDA pointers, so we don't need a flush operation
return ncclInternalError;
}
@@ -486,8 +525,7 @@ ncclNet_t ncclNetSocket = {
"Socket",
ncclSocketInit,
ncclSocketDevices,
- ncclSocketPciPath,
- ncclSocketPtrSupport,
+ ncclSocketGetProperties,
ncclSocketListen,
ncclSocketConnect,
ncclSocketAccept,
@@ -495,7 +533,7 @@ ncclNet_t ncclNetSocket = {
ncclSocketDeregMr,
ncclSocketIsend,
ncclSocketIrecv,
- ncclSocketFlush,
+ ncclSocketIflush,
ncclSocketTest,
ncclSocketClose,
ncclSocketClose,