diff options
Diffstat (limited to 'src/include/socket.h')
-rw-r--r-- | src/include/socket.h | 65 |
1 files changed, 42 insertions, 23 deletions
diff --git a/src/include/socket.h b/src/include/socket.h index fb5cfc0..96bf5db 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -18,8 +18,9 @@ #define MAX_IFS 16 #define MAX_IF_NAME_SIZE 16 -#define SLEEP_INT 1000 // sleep interval in usec -#define RETRY_TIMES 2e4 // retry times before reporting a timeout (20 sec) +#define SLEEP_INT 1000 // connection retry sleep interval in usec +#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) +#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) /* Common socket address storage structure for IPv4/IPv6 */ union socketAddress { @@ -41,7 +42,7 @@ static inline const char *socketToString(struct sockaddr *saddr, char *buf) { return buf; } -static inline short socketToPort(struct sockaddr *saddr) { +static inline uint16_t socketToPort(struct sockaddr *saddr) { return ntohs(saddr->sa_family == AF_INET ? ((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port); } @@ -65,6 +66,9 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre #endif struct netIf userIfs[MAX_IFS]; bool searchNot = prefixList && prefixList[0] == '^'; + if (searchNot) prefixList++; + bool searchExact = prefixList && prefixList[0] == '='; + if (searchExact) prefixList++; int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS); int found = 0; @@ -91,7 +95,7 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre } // check against user specified interfaces - if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs) ^ searchNot)) { + if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) { continue; } @@ -116,17 +120,17 @@ static int findInterfaces(const char* prefixList, char* names, union socketAddre return found; } -static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) { +static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) { /* Check family first */ int family = local_if.ifa_addr->sa_family; - if (family != remote.sa.sa_family) { + if (family != remote->sa.sa_family) { return false; } if (family == AF_INET) { struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr); struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask); - struct sockaddr_in& remote_addr = remote.sin; + struct sockaddr_in& remote_addr = remote->sin; struct in_addr local_subnet, remote_subnet; local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr; remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr; @@ -134,7 +138,7 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) { } else if (family == AF_INET6) { struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr); struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask); - struct sockaddr_in6& remote_addr = remote.sin6; + struct sockaddr_in6& remote_addr = remote->sin6; struct in6_addr& local_in6 = local_addr->sin6_addr; struct in6_addr& mask_in6 = mask->sin6_addr; struct in6_addr& remote_in6 = remote_addr.sin6_addr; @@ -159,8 +163,11 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) { } } -static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) { - char line[1024], line_a[1024]; +static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) { +#ifdef ENABLE_TRACE + char line[1024]; +#endif + char line_a[1024]; int found = 0; struct ifaddrs *interfaces, *interface; getifaddrs(&interfaces); @@ -184,13 +191,13 @@ static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAdd // Store the interface name strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize); - INFO(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a)); + TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a)); found++; if (found == maxIfs) break; } if (found == 0) { - WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr.sa), line_a)); + WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a)); } freeifaddrs(interfaces); return found; @@ -295,7 +302,7 @@ static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNam // Try to find interface that is in the same subnet as the IP in comm id union socketAddress idAddr; GetSocketAddrFromString(&idAddr, commId); - nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, idAddr, ifNameMaxSize, maxIfs); + nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); } } // Then look for anything else (but not docker or lo) @@ -322,7 +329,11 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr) if (socketToPort(&localAddr->sa)) { // Port is forced by env. Make sure we get the port. int opt = 1; +#if defined(SO_REUSEPORT) SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); +#else + SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); +#endif } // localAddr port should be 0 (Any port) @@ -370,14 +381,18 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) { #endif int ret; - int retries = 0; + int timedout_retries = 0; + int refused_retries = 0; retry: SYSCHECKSYNC(connect(*fd, &remoteAddr->sa, salen), "connect", ret); if (ret == 0) return ncclSuccess; - if (errno == ECONNREFUSED && ++retries < RETRY_TIMES) { - INFO(NCCL_ALL,"Call to connect returned %s, retrying", strerror(errno)); \ - usleep(SLEEP_INT); - goto retry; + if ((errno == ECONNREFUSED || errno == ETIMEDOUT)) { + if ((errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) || + (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES)) { + if (refused_retries % 1000 == 0) INFO(NCCL_ALL,"Call to connect returned %s, retrying", strerror(errno)); + usleep(SLEEP_INT); + goto retry; + } } WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno)); return ncclSystemError; @@ -385,12 +400,12 @@ retry: #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 -static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) { +static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) { int bytes = 0; char* data = (char*)ptr; do { - if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); - if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); + if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); + if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); if (op == NCCL_SOCKET_RECV && bytes == 0) { WARN("Net : Connection closed by remote peer"); return ncclSystemError; @@ -408,9 +423,13 @@ static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* off return ncclSuccess; } +static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) { + return socketProgressOpt(op, fd, ptr, size, offset, 0); +} + static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) { while (*offset < size) - NCCLCHECK(socketProgress(op, fd, ptr, size, offset)); + NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1)); return ncclSuccess; } |