diff options
Diffstat (limited to 'src/transport')
-rw-r--r-- | src/transport/net.cu | 20 | ||||
-rw-r--r-- | src/transport/net_ib.cu | 51 | ||||
-rw-r--r-- | src/transport/net_socket.cu | 76 |
3 files changed, 92 insertions, 55 deletions
diff --git a/src/transport/net.cu b/src/transport/net.cu index 8a7e3b8..165187c 100644 --- a/src/transport/net.cu +++ b/src/transport/net.cu @@ -416,17 +416,21 @@ ncclResult_t netSendProxy(struct ncclProxyArgs* args) { while (f1[0] != flag || f2[0] != flag); } NCCLCHECK(ncclNetIsend(resources->netSendComm, lines, size, ptrType, requests+slot)); - sizesFifo[slot] = size; - tail++; - idle = 0; + if (requests[slot] != NULL) { + sizesFifo[slot] = size; + tail++; + idle = 0; + } } } } else while (tail < *prevTail) { // Send through network int slot = tail%args->substeps; NCCLCHECK(ncclNetIsend(resources->netSendComm, localBuff+slot*sliceSize, sizesFifo[slot], ptrType, requests+slot)); - tail++; - idle = 0; + if (requests[slot] != NULL) { + tail++; + idle = 0; + } } if (head < tail) { int done; @@ -502,8 +506,10 @@ ncclResult_t netRecvProxy(struct ncclProxyArgs* args) { if ((tail < head + args->substeps) && (tail < *nextHead + args->substeps) && (tail < end)) { int slot = tail%args->substeps; NCCLCHECK(ncclNetIrecv(resources->netRecvComm, localBuff+slot*sliceSize, sliceSize, ptrType, requests+slot)); - tail++; - idle = 0; + if (requests[slot] != NULL) { + tail++; + idle = 0; + } } if (tail > head) { int done; diff --git a/src/transport/net_ib.cu b/src/transport/net_ib.cu index fb8bd7b..18e158d 100644 --- a/src/transport/net_ib.cu +++ b/src/transport/net_ib.cu @@ -551,25 +551,31 @@ ncclResult_t ncclIbGetRequest(struct ncclIbRequest* reqs, struct ncclIbRequest** } ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) { - if (comm->ready == 0) { - struct ncclIbQpInfo remQpInfo; - struct ibv_qp* qp = comm->qp; - NCCLCHECK(socketReceive(comm->fd, &remQpInfo, sizeof(remQpInfo))); - NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); - int go = 1; - NCCLCHECK(socketSend(comm->fd, &go, sizeof(go))); - comm->ready = 1; - } + struct ncclIbQpInfo remQpInfo; + struct ibv_qp* qp = comm->qp; + + // Do not block on this receive, return if not ready. + int bytes = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes)); + if (bytes == 0) return ncclSuccess; // Try again later + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &remQpInfo, sizeof(remQpInfo), &bytes)); + + NCCLCHECK(ncclIbRtrQp(qp, &remQpInfo)); + NCCLCHECK(ncclIbRtsQp(qp)); + comm->ready = 1; + + // Block until this is done. It *should* not block indefinitely. + NCCLCHECK(socketSend(comm->fd, &comm->ready, sizeof(int))); + return ncclSuccess; } ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) { - if (comm->ready == 0) { - int go; - NCCLCHECK(socketReceive(comm->fd, &go, sizeof(go))); - comm->ready = 1; - } + // Do not block on this receive, return if not ready. + int bytes = 0; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes)); + if (bytes == 0) return ncclSuccess; // Try again later + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, comm->fd, &comm->ready, sizeof(int), &bytes)); return ncclSuccess; } @@ -625,7 +631,13 @@ ncclResult_t ncclIbGetMr(struct ncclIbVerbs* verbs, void* data, int size, struct ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int type, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + + // Wait for the receiver to have posted the corresponding receive + volatile struct ncclIbSendFifo* slot = comm->fifo + (comm->fifoHead%MAX_REQUESTS); + volatile uint32_t * readyPtr = &slot->ready; + if (*readyPtr == 0) { *request = NULL; return ncclSuccess; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(comm->reqs, &req)); @@ -650,10 +662,6 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int type, void** wr.opcode = IBV_WR_SEND; wr.send_flags = IBV_SEND_SIGNALED; - // Wait for receiver to have posted the recv - volatile struct ncclIbSendFifo* slot = comm->fifo + (comm->fifoHead%MAX_REQUESTS); - volatile uint32_t * readyPtr = &slot->ready; - while (*readyPtr == 0) sched_yield(); #if USE_RDMA_WRITE __sync_synchronize(); // order the readyPtr load against rkey load below // Sanity checks to catch user collective call count/size mismatches @@ -714,7 +722,8 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, uint32_t rkey, uint64_t ncclResult_t ncclIbIrecv(void* recvComm, void* data, int size, int type, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) { *request = NULL; return ncclSuccess; } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(comm->reqs, &req)); diff --git a/src/transport/net_socket.cu b/src/transport/net_socket.cu index a8ae866..1efee15 100644 --- a/src/transport/net_socket.cu +++ b/src/transport/net_socket.cu @@ -72,8 +72,12 @@ struct ncclSocketHandle { }; struct ncclSocketRequest { - int used; + int op; + void* data; int size; + int fd; + int offset; + int used; }; struct ncclSocketReqs { @@ -144,15 +148,19 @@ ncclResult_t ncclSocketAccept(void* listenComm, void** recvComm) { #define MAX_REQUESTS 128 -ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, struct ncclSocketRequest** req) { +ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, int op, void* data, int size, int fd, struct ncclSocketRequest** req) { if (reqs->requests == NULL) { NCCLCHECK(ncclCalloc(&reqs->requests, MAX_REQUESTS)); } for (int i=0; i<MAX_REQUESTS; i++) { struct ncclSocketRequest* r = reqs->requests+i; if (r->used == 0) { + r->op = op; + r->data = data; + r->size = size; + r->fd = fd; + r->offset = -1; r->used = 1; - r->size = -1; *req = r; return ncclSuccess; } @@ -161,29 +169,53 @@ ncclResult_t ncclSocketGetRequest(struct ncclSocketReqs* reqs, struct ncclSocket return ncclInternalError; } +ncclResult_t ncclSocketTest(void* request, int* done, int* size) { + *done = 0; + struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; + if (r == NULL) { + WARN("NET/Socket : test called with NULL request"); + return ncclInternalError; + } + if (r->offset == -1) { /* try to send/recv size */ + int data = r->size; + int offset = 0; + NCCLCHECK(socketProgress(r->op, r->fd, &data, sizeof(int), &offset)); + + if (offset == 0) return ncclSuccess; /* Not ready -- retry later */ + + // Not sure we could ever receive less than 4 bytes, but just in case ... + if (offset < sizeof(int)) NCCLCHECK(socketWait(r->op, r->fd, &data, sizeof(int), &offset)); + + // Check size is less or equal to the size provided by the user + if (r->op == NCCL_SOCKET_RECV && data > r->size) { + WARN("NET/Socket : message truncated : receiving %d bytes instead of %d", data, r->size); + return ncclInternalError; + } + r->size = data; + r->offset = 0; + } + if (r->offset < r->size) { + NCCLCHECK(socketProgress(r->op, r->fd, r->data, r->size, &r->offset)); + } + if (r->offset == r->size) { + if (size) *size = r->size; + *done = 1; + r->used = 0; + } + return ncclSuccess; +} + ncclResult_t ncclSocketIsend(void* sendComm, void* data, int size, int type, void** request) { if (type != NCCL_PTR_HOST) return ncclInternalError; struct ncclSocketComm* comm = (struct ncclSocketComm*)sendComm; - *request = NULL; - NCCLCHECK(socketSend(comm->fd, &size, sizeof(int))); - NCCLCHECK(socketSend(comm->fd, data, size)); + NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_SEND, data, size, comm->fd, (struct ncclSocketRequest**)request)); return ncclSuccess; } ncclResult_t ncclSocketIrecv(void* recvComm, void* data, int size, int type, void** request) { if (type != NCCL_PTR_HOST) return ncclInternalError; struct ncclSocketComm* comm = (struct ncclSocketComm*)recvComm; - int recvSize; - NCCLCHECK(socketReceive(comm->fd, &recvSize, sizeof(int))); - if (recvSize > size) { - WARN("Message truncated : received %d bytes instead of %d", recvSize, size); - return ncclInternalError; - } - NCCLCHECK(socketReceive(comm->fd, data, std::min(recvSize, size))); - struct ncclSocketRequest* recvReq = NULL; - NCCLCHECK(ncclSocketGetRequest(&comm->reqs, &recvReq)); - recvReq->size = recvSize; - *request = recvReq; + NCCLCHECK(ncclSocketGetRequest(&comm->reqs, NCCL_SOCKET_RECV, data, size, comm->fd, (struct ncclSocketRequest**)request)); return ncclSuccess; } @@ -192,16 +224,6 @@ ncclResult_t ncclSocketFlush(void* recvComm, void* data, int size) { return ncclInternalError; } -ncclResult_t ncclSocketTest(void* request, int* done, int* size) { - *done = 1; - struct ncclSocketRequest *r = (struct ncclSocketRequest*)request; - if (r) { - if (size) *size = r->size; - r->used = 0; - } - return ncclSuccess; -} - ncclResult_t ncclSocketClose(void* opaqueComm) { struct ncclSocketComm* comm = (struct ncclSocketComm*)opaqueComm; if (comm) { |