Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/socket.h')
-rw-r--r-- src/include/socket.h 29
1 files changed, 19 insertions, 10 deletions
diff --git a/src/include/socket.h b/src/include/socket.h
index 96bf5db..e903b04 100644
--- a/src/include/socket.h
+++ b/src/include/socket.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -21,6 +21,7 @@
#define SLEEP_INT 1000 // connection retry sleep interval in usec
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
+#define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV)
/* Common socket address storage structure for IPv4/IPv6 */
union socketAddress {
@@ -53,6 +54,8 @@ static inline int envSocketFamily(void) {
if (env == NULL)
return family;
+ INFO(NCCL_ENV, "NCCL_SOCKET_FAMILY set by environment to %s", env);
+
if (strcmp(env, "AF_INET") == 0)
family = AF_INET; // IPv4
else if (strcmp(env, "AF_INET6") == 0)
@@ -62,7 +65,7 @@ static inline int envSocketFamily(void) {
static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
#endif
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
@@ -165,9 +168,9 @@ static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) {
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
#endif
- char line_a[1024];
+ char line_a[SOCKET_NAME_MAXLEN+1];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
@@ -283,13 +286,16 @@ static ncclResult_t GetSocketAddrFromString(union socketAddress* ua, const char*
}
static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNameMaxSize, int maxIfs) {
+ static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = envSocketFamily();
// User specified interface
char* env = getenv("NCCL_SOCKET_IFNAME");
if (env && strlen(env) > 1) {
+ INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
+ if (shownIfName++ == 0) INFO(NCCL_NET, "NCCL_SOCKET_IFNAME set to %s", env);
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
@@ -299,7 +305,8 @@ static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNam
if (nIfs == 0) {
char* commId = getenv("NCCL_COMM_ID");
if (commId && strlen(commId) > 1) {
- // Try to find interface that is in the same subnet as the IP in comm id
+ INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId);
+ // Try to find interface that is in the same subnet as the IP in comm id
union socketAddress idAddr;
GetSocketAddrFromString(&idAddr, commId);
nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
@@ -344,7 +351,7 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
SYSCHECK(getsockname(sockfd, &localAddr->sa, &size), "getsockname");
#ifdef ENABLE_TRACE
- char line[1024];
+ char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
#endif
@@ -359,6 +366,10 @@ static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr)
static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
/* IPv4/IPv6 support */
int family = remoteAddr->sa.sa_family;
+ if (family != AF_INET && family != AF_INET6) {
+ WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)\n", family, AF_INET, AF_INET6);
+ return ncclInternalError;
+ }
int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
/* Connect to a hostname / port */
@@ -375,10 +386,8 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
- char line[1024];
-#ifdef ENABLE_TRACE
+ char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
-#endif
int ret;
int timedout_retries = 0;
@@ -439,7 +448,7 @@ static ncclResult_t socketSend(int fd, void* ptr, int size) {
return ncclSuccess;
}
-static ncclResult_t socketReceive(int fd, void* ptr, int size) {
+static ncclResult_t socketRecv(int fd, void* ptr, int size) {
int offset = 0;
NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset));
return ncclSuccess;