diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/include/nccl_net.h | 2 | ||||
-rw-r--r-- | src/include/socket.h | 53 | ||||
-rw-r--r-- | src/transport/net.cu | 20 | ||||
-rw-r--r-- | src/transport/net_ib.cu | 51 | ||||
-rw-r--r-- | src/transport/net_socket.cu | 76 |
5 files changed, 124 insertions, 78 deletions
diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index 7dbbc37..5d3ec7c 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -41,8 +41,10 @@ typedef struct { // Finalize connection establishment after remote peer has called connectHandle ncclResult_t (*accept)(void* listenComm, void** recvComm); // Asynchronous send to a peer. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + // May return request == NULL if the call cannot be performed (or would block) ncclResult_t (*isend)(void* sendComm, void* data, int size, int type, void** request); // Asynchronous recv from a peer. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + // May return request == NULL if the call cannot be performed (or would block) ncclResult_t (*irecv)(void* recvComm, void* data, int size, int type, void** request); // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is // visible to the GPU diff --git a/src/include/socket.h b/src/include/socket.h index 9d2b2c8..81e1651 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -370,39 +370,46 @@ static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) { return ncclSuccess; } -static ncclResult_t socketReceive(int fd, void* ptr, int size) { +#define NCCL_SOCKET_SEND 0 +#define NCCL_SOCKET_RECV 1 +static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) { + int bytes = 0; char* data = (char*)ptr; - int offset = 0; - while (offset < size) { - int recvsize; - SYSCHECKVAL(recv(fd, data, size-offset, 0), "recv", recvsize); - if (recvsize == 0) { + do { + if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); + if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), MSG_DONTWAIT); + if (op == NCCL_SOCKET_RECV && bytes == 0) { WARN("Net : Connection closed by remote peer"); return ncclSystemError; } - if (recvsize == -1) { - INFO(NCCL_NET,"Recv : got retcode %d, retrying", errno); - continue; + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + WARN("Call to recv failed : %s", strerror(errno)); + return ncclSystemError; + } else { + bytes = 0; + } } - data += recvsize; - offset += recvsize; - } + (*offset) += bytes; + } while (bytes > 0 && (*offset) < size); + return ncclSuccess; +} + +static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) { + while (*offset < size) + NCCLCHECK(socketProgress(op, fd, ptr, size, offset)); return ncclSuccess; } static ncclResult_t socketSend(int fd, void* ptr, int size) { - char* data = (char*)ptr; int offset = 0; - while (offset < size) { - int sendsize; - SYSCHECKVAL(write(fd, data, size-offset), "write", sendsize); - if (sendsize == -1) { - INFO(NCCL_NET,"Send : got retcode %d, retrying", errno); - continue; - } - data += sendsize; - offset += sendsize; - } + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &offset)); + return ncclSuccess; +} + +static ncclResult_t socketReceive(int fd, void* ptr, int size) { + int offset = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset)); return ncclSuccess; } diff --git a/src/transport/net.cu b/src/transport/net.cu index 8a7e3b8..165187c 100644 --- a/src/transport/net.cu +++ b/src/transport/net.cu @@ -416,17 +416,21 @@ ncclResult_t netSendProxy(struct ncclProxyArgs* args) { while (f1[0] != flag || f2[0] != flag); } NCCLCHECK(ncclNetIsend(resources->netSendComm, lines, size, ptrType, requests+slot)); - sizesFifo[slot] = size; - tail++; - idle = 0; + if (requests[slot] != NULL) { + sizesFifo[slot] = size; + tail++; + idle = 0; + } } } } else while (tail < *prevTail) { // Send through network int slot = tail%args->substeps; NCCLCHECK(ncclNetIsend(resources->netSendComm, localBuff+slot*sliceSize, sizesFifo[slot], ptrType, requests+slot)); - tail++; - idle = 0; + if (requests[slot] != NULL) { + tail++; + idle = 0; + } } if (head < tail) { int done; @@ -502,8 +506,10 @@ ncclResult_t netRecvProxy(struct ncclProxyArgs* args) { if ((tail < head + args->substeps) && (tail < *nextHead + args->substeps) && (tail < end)) { int slot = tail%args->substeps; NCCLCHECK(ncclNetIrecv(resources->netRecvComm, localBuff+slot*sliceSize, sliceSize, ptrType, requests+slot)); - tail++; - idle = 0; + if (requests[slot] != NULL) { + tail++; + idle = 0; + } } if (tail > head) { int done; diff --git a/src/transport/net_ib.cu b/src/transport/net_ib.cu index fb8bd7b..18e158d 100644 --- a/src/transport/net_ib.cu +++ b/src/transport/net_ib.cu @@ -551,25 +551,31 @@ ncclResult_t ncclIbGetRequest(struct ncclIbRequest* reqs, struct ncclIbRequest** } ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) { - if (comm->ready == 0) { - struct ncclIbQpInfo remQpInfo; - struct ibv_qp* qp = comm->qp; - NCCLCHECK(socketReceive(comm->fd, &remQpInfo, sizeof(remQpInfo))); - NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); - int go = 1; - NCCLCHECK(socketSend(comm->fd, &go, sizeof(go))); - comm->ready = 1; - } + struct ncclIbQpInfo remQpInfo; + struct ibv_qp* qp = comm->qp; + + // Do not block on this receive, return if not ready. + int bytes = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes)); + if (bytes == 0) return ncclSuccess; // Try again later + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes)); + + NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo)); + NCCLCHECK(ncclIbRtsQp(qp)); + comm->ready = 1; + + // Block until this is done. It *should* not block indefinitely. + NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int))); + return ncclSuccess; } ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) { - if (comm->ready == 0) { - int go; - NCCLCHECK(socketReceive(comm->fd, &go, sizeof(go))); - comm->ready = 1; - } + // Do not block on this receive, return if not ready. + int bytes = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes)); + if (bytes == 0) return ncclSuccess; // Try again later + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes)); return ncclSuccess; } @@ -625,7 +631,13 @@ ncclResult_t ncclIbGetMr(struct ncclIbVerbs* verbs, void* data, int size, struct ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int type, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + + // Wait for the receiver to have posted the corresponding receive + volatile struct ncclIbSendFifo* slot = comm->fifo + (comm->fifoHead%MAX_REQUESTS); + volatile uint32_t * readyPtr = &slot->ready; + if (*readyPtr == 0) { *request = NULL; return ncclSuccess; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(comm->reqs, &req)); @@ -650,10 +662,6 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int type, void** wr.opcode = IBV_WR_SEND; wr.send_flags = IBV_SEND_SIGNALED; - // Wait for receiver to have posted the recv - volatile struct ncclIbSendFifo* slot = comm->fifo + (comm->fifoHead%MAX_REQUESTS); - volatile uint32_t * readyPtr = &slot->ready; - while (*readyPtr == 0) sched_yield(); #if USE_RDMA_WRITE __sync_synchronize(); // order the readyPtr load against rkey load below // Sanity checks to catch user collective call count/size mismatches @@ -714,7 +722,8 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, int type, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) { *request = NULL; return ncclSuccess; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(comm->reqs, &req)); diff --git a/src/transport/net_socket.cu b/src/transport/net_socket.cu index a8ae866..1efee15 100644 --- a/src/transport/net_socket.cu +++ b/src/transport/net_socket.cu @@ -72,8 +72,12 @@ struct ncclSocketHandle { }; struct ncclSocketRequest { - int used; + int op; + void* data; int size; + int fd; + int offset; + int used; }; struct ncclSocketReqs { @@ -144,15 +148,19 @@ ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) { #define MAX_REQUESTS 128 -ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, struct ncclSocketRequest** req) { +ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) { if (reqs->requests == NULL) { NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS)); } for (int i=0; i<MAX_REQUESTS; i++) { struct ncclSocketRequest* r = reqs->requests+i; if (r->used == 0) { + r->op = op; + r->data = data; + r->size = size; + r->fd = fd; + r->offset = -1; r->used = 1; - r->size = -1; *req = r; return ncclSuccess; } @@ -161,29 +169,53 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, struct ncclSocket return ncclInternalError; } +ncclResult_t ncclSocketTest(void* request, int* done, int* size) { + *done = 0; + struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; + if (r == NULL) { + WARN("NET/Socket : test called with NULL request"); + return ncclInternalError; + } + if (r->offset == -1) { /* try to send/recv size */ + int data = r->size; + int offset = 0; + NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset)); + + if (offset == 0) return ncclSuccess; /* Not ready -- retry later */ + + // Not sure we could ever receive less than 4 bytes, but just in case ... + if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset)); + + // Check size is less or equal to the size provided by the user + if (r->op == NCCL_SOCKET_RECV && data > r->size) { + WARN("NET/Socket : message truncated : receiving %d bytes instead of %d", data, r->size); + return ncclInternalError; + } + r->size = data; + r->offset = 0; + } + if (r->offset < r->size) { + NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset)); + } + if (r->offset == r->size) { + if (size) *size = r->size; + *done = 1; + r->used = 0; + } + return ncclSuccess; +} + ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, int type, void** request) { if (type != NCCL_PTR_HOST) return ncclInternalError; struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm; - *request = NULL; - NCCLCHECK(socketSend(comm->fd, &size, sizeof(int))); - NCCLCHECK(socketSend(comm->fd, data, size)); + NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request)); return ncclSuccess; } ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, int type, void** request) { if (type != NCCL_PTR_HOST) return ncclInternalError; struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm; - int recvSize; - NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int))); - if (recvSize > size) { - WARN("Message truncated : received %d bytes instead of %d", recvSize, size); - return ncclInternalError; - } - NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size))); - struct ncclSocketRequest* recvReq = NULL; - NCCLCHECK(ncclSocketGetRequest(&comm->reqs, &recvReq)); - recvReq->size = recvSize; - *request = recvReq; + NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request)); return ncclSuccess; } @@ -192,16 +224,6 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size) { return ncclInternalError; } -ncclResult_t ncclSocketTest(void* request, int* done, int* size) { - *done = 1; - struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; - if (r) { - if (size) *size = r->size; - r->used = 0; - } - return ncclSuccess; -} - ncclResult_t ncclSocketClose(void* opaqueComm) { struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm; if (comm) { |