diff options
Diffstat (limited to 'src/include/socket.h')
-rw-r--r-- | src/include/socket.h | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/src/include/socket.h b/src/include/socket.h index 96bf5db..e903b04 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -21,6 +21,7 @@ #define SLEEP_INT 1000 // connection retry sleep interval in usec #define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) #define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) +#define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) /* Common socket address storage structure for IPv4/IPv6 */ union socketAddress { @@ -53,6 +54,8 @@ static inline int envSocketFamily(void) { if (env == NULL) return family; + INFO(NCCL_ENV, "NCCL_SOCKET_FAMILY set by environment to %s", env); + if (strcmp(env, "AF_INET") == 0) family = AF_INET; // IPv4 else if (strcmp(env, "AF_INET6") == 0) @@ -62,7 +65,7 @@ static inline int envSocketFamily(void) { static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) { #ifdef ENABLE_TRACE - char line[1024]; + char line[SOCKET_NAME_MAXLEN+1]; #endif struct netIf userIfs[MAX_IFS]; bool searchNot = prefixList && prefixList[0] == '^'; @@ -165,9 +168,9 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) { static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) { #ifdef ENABLE_TRACE - char line[1024]; + char line[SOCKET_NAME_MAXLEN+1]; #endif - char line_a[1024]; + char line_a[SOCKET_NAME_MAXLEN+1]; int found = 0; struct ifaddrs *interfaces, *interface; getifaddrs(&interfaces); @@ -283,13 +286,16 @@ static ncclResult_t GetSocketAddrFromString(union socketAddress* ua, const char* } static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNameMaxSize, int maxIfs) { + static int shownIfName = 0; int nIfs = 0; // Allow user to force the INET socket family selection int sock_family = envSocketFamily(); // User specified interface char* env = getenv("NCCL_SOCKET_IFNAME"); if (env && strlen(env) > 1) { + INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env); // Specified by user : find or fail + if (shownIfName++ == 0) INFO(NCCL_NET, "NCCL_SOCKET_IFNAME set to %s", env); nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); } else { // Try to automatically pick the right one @@ -299,7 +305,8 @@ static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNam if (nIfs == 0) { char* commId = getenv("NCCL_COMM_ID"); if (commId && strlen(commId) > 1) { - // Try to find interface that is in the same subnet as the IP in comm id + INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); + // Try to find interface that is in the same subnet as the IP in comm id union socketAddress idAddr; GetSocketAddrFromString(&idAddr, commId); nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); @@ -344,7 +351,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr) SYSCHECK(getsockname(sockfd, &localAddr->sa, &size), "getsockname"); #ifdef ENABLE_TRACE - char line[1024]; + char line[SOCKET_NAME_MAXLEN+1]; TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line)); #endif @@ -359,6 +366,10 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr) static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) { /* IPv4/IPv6 support */ int family = remoteAddr->sa.sa_family; + if (family != AF_INET && family != AF_INET6) { + WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)\n", family, AF_INET, AF_INET6); + return ncclInternalError; + } int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); /* Connect to a hostname / port */ @@ -375,10 +386,8 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) { SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt"); SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/ - char line[1024]; -#ifdef ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN+1]; TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line)); -#endif int ret; int timedout_retries = 0; @@ -439,7 +448,7 @@ static ncclResult_t socketSend(int fd, void* ptr, int size) { return ncclSuccess; } -static ncclResult_t socketReceive(int fd, void* ptr, int size) { +static ncclResult_t socketRecv(int fd, void* ptr, int size) { int offset = 0; NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset)); return ncclSuccess; |