Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/socket.h')
-rw-r--r--src/include/socket.h65
1 files changed, 42 insertions, 23 deletions
diff --git a/src/include/socket.h b/src/include/socket.h
index fb5cfc0..96bf5db 100644
--- a/src/include/socket.h
+++ b/src/include/socket.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -18,8 +18,9 @@
#define MAX_IFS 16
#define MAX_IF_NAME_SIZE 16
-#define SLEEP_INT 1000 // sleep interval in usec
-#define RETRY_TIMES 2e4 // retry times before reporting a timeout (20 sec)
+#define SLEEP_INT 1000 // connection retry sleep interval in usec
+#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
+#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
/* Common socket address storage structure for IPv4/IPv6 */
union socketAddress {
@@ -41,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
return buf;
}
-static inline short socketToPort(struct sockaddr *saddr) {
+static inline uint16_t socketToPort(struct sockaddr *saddr) {
return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port);
}
@@ -65,6 +66,9 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
#endif
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
+ if (searchNot) prefixList++;
+ bool searchExact = prefixList && prefixList[0] == '=';
+ if (searchExact) prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
@@ -91,7 +95,7 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
}
// check against user specified interfaces
- if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs) ^ searchNot)) {
+ if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
@@ -116,17 +120,17 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre
return found;
}
-static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
+static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) {
/* Check family first */
int family = local_if.ifa_addr->sa_family;
- if (family != remote.sa.sa_family) {
+ if (family != remote->sa.sa_family) {
return false;
}
if (family == AF_INET) {
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
- struct sockaddr_in& remote_addr = remote.sin;
+ struct sockaddr_in& remote_addr = remote->sin;
struct in_addr local_subnet, remote_subnet;
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
@@ -134,7 +138,7 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
} else if (family == AF_INET6) {
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
- struct sockaddr_in6& remote_addr = remote.sin6;
+ struct sockaddr_in6& remote_addr = remote->sin6;
struct in6_addr& local_in6 = local_addr->sin6_addr;
struct in6_addr& mask_in6 = mask->sin6_addr;
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
@@ -159,8 +163,11 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
}
}
-static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
- char line[1024], line_a[1024];
+static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
+#ifdef ENABLE_TRACE
+ char line[1024];
+#endif
+ char line_a[1024];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
@@ -184,13 +191,13 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd
// Store the interface name
strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
- INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
+ TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a));
found++;
if (found == maxIfs) break;
}
if (found == 0) {
- WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr.sa), line_a));
+ WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a));
}
freeifaddrs(interfaces);
return found;
@@ -295,7 +302,7 @@ static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNam
// Try to find interface that is in the same subnet as the IP in comm id
union socketAddress idAddr;
GetSocketAddrFromString(&idAddr, commId);
- nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, idAddr, ifNameMaxSize, maxIfs);
+ nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
// Then look for anything else (but not docker or lo)
@@ -322,7 +329,11 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
if (socketToPort(&localAddr->sa)) {
// Port is forced by env. Make sure we get the port.
int opt = 1;
+#if defined(SO_REUSEPORT)
SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
+#else
+ SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
+#endif
}
// localAddr port should be 0 (Any port)
@@ -370,14 +381,18 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
#endif
int ret;
- int retries = 0;
+ int timedout_retries = 0;
+ int refused_retries = 0;
retry:
SYSCHECKSYNC(connect(*fd, &remoteAddr->sa, salen), "connect", ret);
if (ret == 0) return ncclSuccess;
- if (errno == ECONNREFUSED && ++retries < RETRY_TIMES) {
- INFO(NCCL_ALL,"Call to connect returned %s, retrying", strerror(errno)); \
- usleep(SLEEP_INT);
- goto retry;
+ if ((errno == ECONNREFUSED || errno == ETIMEDOUT)) {
+ if ((errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) ||
+ (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES)) {
+ if (refused_retries % 1000 == 0) INFO(NCCL_ALL,"Call to connect returned %s, retrying", strerror(errno));
+ usleep(SLEEP_INT);
+ goto retry;
+ }
}
WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno));
return ncclSystemError;
@@ -385,12 +400,12 @@ retry:
#define NCCL_SOCKET_SEND 0
#define NCCL_SOCKET_RECV 1
-static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
+static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
int bytes = 0;
char* data = (char*)ptr;
do {
- if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
- if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT);
+ if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
+ if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT);
if (op == NCCL_SOCKET_RECV && bytes == 0) {
WARN("Net : Connection closed by remote peer");
return ncclSystemError;
@@ -408,9 +423,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off
return ncclSuccess;
}
+static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
+ return socketProgressOpt(op, fd, ptr, size, offset, 0);
+}
+
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
while (*offset < size)
- NCCLCHECK(socketProgress(op, fd, ptr, size, offset));
+ NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
return ncclSuccess;
}