Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/socket.h')
-rw-r--r--src/include/socket.h401
1 files changed, 401 insertions, 0 deletions
diff --git a/src/include/socket.h b/src/include/socket.h
new file mode 100644
index 0000000..3321e4d
--- /dev/null
+++ b/src/include/socket.h
@@ -0,0 +1,401 @@
+/*************************************************************************
+ * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef NCCL_SOCKET_H_
+#define NCCL_SOCKET_H_
+
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <netinet/tcp.h>
+#include <unistd.h>
+#include <netdb.h>
+#include <ifaddrs.h>
+#include <net/if.h>
+#include "utils.h"
+
+#define MAX_IF_NAME_SIZE 16
+#define SLEEP_INT 1000 // sleep interval in usec
+#define RETRY_TIMES 2e4 // retry times before reporting a timeout (20 sec)
+
+/* Common socket address storage structure for IPv4/IPv6 */
+union socketAddress {
+ struct sockaddr sa;
+ struct sockaddr_in sin;
+ struct sockaddr_in6 sin6;
+};
+
+/* Format a string representation of a (struct sockaddr *) socket address using getnameinfo()
+ *
+ * Output: "IPv4/IPv6 address<port>"
+ */
+static inline const char *socketToString(struct sockaddr *saddr, char *buf) {
+ if (buf == NULL || saddr == NULL) return NULL;
+ if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; }
+ char host[NI_MAXHOST], service[NI_MAXSERV];
+ (void) getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST, service, NI_MAXSERV, NI_NUMERICHOST|NI_NUMERICSERV);
+ sprintf(buf, "%s<%s>", host, service);
+ return buf;
+}
+
+/* Allow the user to force the IPv4/IPv6 interface selection */
+static inline int envSocketFamily(void) {
+ int family = -1; // Family selection is not forced, will use first one found
+ char* env = getenv("NCCL_SOCKET_FAMILY");
+ if (env == NULL)
+ return family;
+
+ if (strcmp(env, "AF_INET") == 0)
+ family = AF_INET; // IPv4
+ else if (strcmp(env, "AF_INET6") == 0)
+ family = AF_INET6; // IPv6
+ return family;
+}
+
+static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) {
+ char line[1024];
+ struct netIf userIfs[maxIfs];
+ bool searchNot = prefixList && prefixList[0] == '^';
+ int nUserIfs = parseStringList(prefixList, userIfs, maxIfs);
+
+ int found = 0;
+ struct ifaddrs *interfaces, *interface;
+ getifaddrs(&interfaces);
+ for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
+ if (interface->ifa_addr == NULL) continue;
+
+ /* We only support IPv4 & IPv6 */
+ int family = interface->ifa_addr->sa_family;
+ if (family != AF_INET && family != AF_INET6)
+ continue;
+
+ TRACE(INIT|NET,"Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
+
+ /* Allow the caller to force the socket family type */
+ if (sock_family != -1 && family != sock_family)
+ continue;
+
+ /* We also need to skip IPv6 loopback interfaces */
+ if (family == AF_INET6) {
+ struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
+ if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue;
+ }
+
+ // check against user specified interfaces
+ if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs) ^ searchNot)) {
+ continue;
+ }
+
+ // Check that this interface has not already been saved
+ // getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
+ bool duplicate = false;
+ for (int i = 0; i < found; i++) {
+ if (strcmp(interface->ifa_name, names+i*maxIfNameSize) == 0) { duplicate = true; break; }
+ }
+
+ if (!duplicate) {
+ // Store the interface name
+ strncpy(names+found*maxIfNameSize, interface->ifa_name, maxIfNameSize);
+ // Store the IP address
+ int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
+ memcpy(addrs+found, interface->ifa_addr, salen);
+ INFO(INIT|NET,"NET : Using interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
+ found++;
+ }
+ }
+
+ freeifaddrs(interfaces);
+ return found;
+}
+
+static bool matchSubnet(struct ifaddrs local_if, union socketAddress remote) {
+ /* Check family first */
+ int family = local_if.ifa_addr->sa_family;
+ if (family != remote.sa.sa_family) {
+ return false;
+ }
+
+ if (family == AF_INET) {
+ struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
+ struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
+ struct sockaddr_in& remote_addr = remote.sin;
+ struct in_addr local_subnet, remote_subnet;
+ local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
+ remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
+ return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
+ } else if (family == AF_INET6) {
+ struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
+ struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
+ struct sockaddr_in6& remote_addr = remote.sin6;
+ struct in6_addr& local_in6 = local_addr->sin6_addr;
+ struct in6_addr& mask_in6 = mask->sin6_addr;
+ struct in6_addr& remote_in6 = remote_addr.sin6_addr;
+ bool same = true;
+ int len = 16; //IPv6 address is 16 unsigned char
+ for (int c = 0; c < len; c++) { //Network byte order is big-endian
+ char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
+ char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
+ if (c1 ^ c2) {
+ same = false;
+ break;
+ }
+ }
+ // At last, we need to compare scope id
+ // Two Link-type addresses can have the same subnet address even though they are not in the same scope
+ // For Global type, this field is 0, so a comparison wouldn't matter
+ same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
+ return same;
+ } else {
+ WARN("Net : Unsupported address family type");
+ return false;
+ }
+}
+
+static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress remoteAddr, int ifNameMaxSize, int maxIfs) {
+ char line[1024], line_a[1024];
+ int found = 0;
+ struct ifaddrs *interfaces, *interface;
+ getifaddrs(&interfaces);
+ for (interface = interfaces; interface && !found; interface = interface->ifa_next) {
+ if (interface->ifa_addr == NULL) continue;
+
+ /* We only support IPv4 & IPv6 */
+ int family = interface->ifa_addr->sa_family;
+ if (family != AF_INET && family != AF_INET6)
+ continue;
+
+ // check against user specified interfaces
+ if (!matchSubnet(*interface, remoteAddr)) {
+ continue;
+ }
+
+ // Store the local IP address
+ int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
+ memcpy(localAddrs+found, interface->ifa_addr, salen);
+
+ // Store the interface name
+ strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
+
+ INFO(INIT|NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr.sa), line_a));
+ found++;
+ if (found == maxIfs) break;
+ }
+
+ if (found == 0) {
+ WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr.sa), line_a));
+ }
+ freeifaddrs(interfaces);
+ return found;
+}
+
+static ncclResult_t GetSocketAddrFromString(union socketAddress* ua, const char* ip_port_pair) {
+ if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
+ WARN("Net : string is null");
+ return ncclInvalidArgument;
+ }
+
+ bool ipv6 = ip_port_pair[0] == '[';
+ /* Construct the sockaddress structure */
+ if (!ipv6) {
+ struct netIf ni;
+ // parse <ip_or_hostname>:<port> string, expect one pair
+ if (parseStringList(ip_port_pair, &ni, 1) != 1) {
+ WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
+ return ncclInvalidArgument;
+ }
+
+ struct addrinfo hints, *p;
+ int rv;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+
+ if ( (rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
+ WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
+ return ncclInvalidArgument;
+ }
+
+ // use the first
+ if (p->ai_family == AF_INET) {
+ struct sockaddr_in& sin = ua->sin;
+ memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
+ sin.sin_family = AF_INET; // IPv4
+ //inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
+ sin.sin_port = htons(ni.port); // port
+ } else if (p->ai_family == AF_INET6) {
+ struct sockaddr_in6& sin6 = ua->sin6;
+ memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
+ sin6.sin6_family = AF_INET6; // IPv6
+ sin6.sin6_port = htons(ni.port); // port
+ sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
+ sin6.sin6_scope_id = 0; // should be global scope, set to 0
+ } else {
+ WARN("Net : unsupported IP family");
+ return ncclInvalidArgument;
+ }
+
+ freeaddrinfo(p); // all done with this structure
+
+ } else {
+ int i, j = -1, len = strlen(ip_port_pair);
+ for (i = 1; i < len; i++) {
+ if (ip_port_pair[i] == '%') j = i;
+ if (ip_port_pair[i] == ']') break;
+ }
+ if (i == len) {
+ WARN("Net : No valid [IPv6]:port pair found");
+ return ncclInvalidArgument;
+ }
+ bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
+
+ char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
+ memset(ip_str, '\0', sizeof(ip_str));
+ memset(port_str, '\0', sizeof(port_str));
+ memset(if_name, '\0', sizeof(if_name));
+ strncpy(ip_str, ip_port_pair+1, global_scope ? i-1 : j-1);
+ strncpy(port_str, ip_port_pair+i+2, len-i-1);
+ int port = atoi(port_str);
+ if (!global_scope) strncpy(if_name, ip_port_pair+j+1, i-j-1); // If not global scope, we need the intf name
+
+ struct sockaddr_in6& sin6 = ua->sin6;
+ sin6.sin6_family = AF_INET6; // IPv6
+ inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
+ sin6.sin6_port = htons(port); // port
+ sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
+ sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
+ }
+ return ncclSuccess;
+}
+
+static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNameMaxSize, int maxIfs) {
+ int nIfs = 0;
+ // Allow user to force the INET socket family selection
+ int sock_family = envSocketFamily();
+ // User specified interface
+ char* env = getenv("NCCL_SOCKET_IFNAME");
+ if (env && strlen(env) > 1) {
+ // Specified by user : find or fail
+ nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
+ } else {
+ // Try to automatically pick the right one
+ // Start with IB
+ nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
+ // else see if we can get some hint from COMM ID
+ if (nIfs == 0) {
+ char* commId = getenv("NCCL_COMM_ID");
+ if (commId && strlen(commId) > 1) {
+ // Try to find interface that is in the same subnet as the IP in comm id
+ union socketAddress idAddr;
+ GetSocketAddrFromString(&idAddr, commId);
+ nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, idAddr, ifNameMaxSize, maxIfs);
+ }
+ }
+ // Then look for anything else (but not docker or lo)
+ if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
+ // Finally look for docker, then lo.
+ if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
+ if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
+ }
+ return nIfs;
+}
+
+static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr) {
+ /* IPv4/IPv6 support */
+ int family = localAddr->sa.sa_family;
+ int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
+
+ /* Create socket and bind it to a port */
+ int sockfd = socket(family, SOCK_STREAM, 0);
+ if (sockfd == -1) {
+ WARN("Net : Socket creation failed : %s", strerror(errno));
+ return ncclSystemError;
+ }
+
+ int opt = 1;
+ SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
+
+ // localAddr port should be 0 (Any port)
+ SYSCHECK(bind(sockfd, &localAddr->sa, salen), "bind");
+
+ /* Get the assigned Port */
+ socklen_t size = salen;
+ SYSCHECK(getsockname(sockfd, &localAddr->sa, &size), "getsockname");
+
+#ifdef ENABLE_TRACE
+ char line[1024];
+ TRACE(INIT|NET,"Listening on socket %s", socketToString(&localAddr->sa, line));
+#endif
+
+ /* Put the socket in listen mode */
+ SYSCHECK(listen(sockfd, 128), "listen");
+ *fd = sockfd;
+ return ncclSuccess;
+}
+
+static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
+ /* IPv4/IPv6 support */
+ int family = remoteAddr->sa.sa_family;
+ int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
+
+ /* Connect to a hostname / port */
+ *fd = socket(family, SOCK_STREAM, 0);
+ if (*fd == -1) {
+ WARN("Net : Socket creation failed : %s", strerror(errno));
+ return ncclSystemError;
+ }
+
+ const int one = 1;
+ SYSCHECK(setsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
+
+ /* const int bufsize = 128*1024;
+ SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
+ SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/
+
+#ifdef ENABLE_TRACE
+ char line[1024];
+ TRACE(INIT|NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line));
+#endif
+
+ SYSCHECKNTIMES(connect(*fd, &remoteAddr->sa, salen), "connect", RETRY_TIMES, SLEEP_INT, ECONNREFUSED);
+ return ncclSuccess;
+}
+
+static ncclResult_t socketReceive(int fd, void* ptr, int size) {
+ char* data = (char*)ptr;
+ int offset = 0;
+ while (offset < size) {
+ int recvsize;
+ SYSCHECKVAL(recv(fd, data, size-offset, 0), "recv", recvsize);
+ if (recvsize == 0) {
+ WARN("Net : Connection closed by remote peer");
+ return ncclSystemError;
+ }
+ if (recvsize == -1) {
+ INFO(NET,"Recv : got retcode %d, retrying", errno);
+ continue;
+ }
+ data += recvsize;
+ offset += recvsize;
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t socketSend(int fd, void* ptr, int size) {
+ char* data = (char*)ptr;
+ int offset = 0;
+ while (offset < size) {
+ int sendsize;
+ SYSCHECKVAL(write(fd, data, size-offset), "write", sendsize);
+ if (sendsize == -1) {
+ INFO(NET,"Send : got retcode %d, retrying", errno);
+ continue;
+ }
+ data += sendsize;
+ offset += sendsize;
+ }
+ return ncclSuccess;
+}
+
+#endif