diff options
Diffstat (limited to 'src/transport/shm.cc')
-rw-r--r-- | src/transport/shm.cc | 111 |
1 files changed, 19 insertions, 92 deletions
diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 2ec5f23..60f16c8 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -4,13 +4,8 @@ * See LICENSE.txt for license information ************************************************************************/ -#include "core.h" -#include "utils.h" -#include "transport.h" -#include "param.h" +#include "comm.h" #include "shm.h" -#include <unistd.h> -#include <cuda_runtime.h> struct shmConnectInfo { uint64_t pidHash; @@ -40,98 +35,29 @@ struct shmRecvResources { NCCL_PARAM(ShmDisable, "SHM_DISABLE", 0); -/* Determine if we can communicate with the peer */ -ncclResult_t shmCanConnect(ncclTvalue_t* ret, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo) { - *ret = ((ncclParamShmDisable() == 1) || (myInfo->hostHash != peerInfo->hostHash)) ? 0 : 1; - return ncclSuccess; -} +/* Determine two peers can communicate with SHM */ +ncclResult_t shmCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { + *ret = 0; -static inline int groupFirst(int nranks, int* groups, int group, int rankToAvoid) { - for (int rank = 0; rank<nranks; rank++) { - if ((groups[rank] == group) && (rank != rankToAvoid)) return rank; - } - return -1; -} + if (ncclParamShmDisable() == 1) return ncclSuccess; -static inline int groupLast(int nranks, int* groups, int group, int rankToAvoid) { - for (int rank = nranks-1; rank>=0; rank--) { - if ((groups[rank] == group) && (rank != rankToAvoid)) return rank; - } - return -1; -} + // Same host? + TRACE(NCCL_INIT|NCCL_SHM, "peer1 hostHash %lx peer2 hostHash %lx", info1->hostHash, info2->hostHash); + if (info1->hostHash != info2->hostHash) return ncclSuccess; + + // Common /dev/shm (between containers) ? + TRACE(NCCL_INIT|NCCL_SHM, "peer1 shmDev %lx peer2 shmDev %lx", info1->shmDev, info2->shmDev); + if (info1->shmDev != info2->shmDev) return ncclSuccess; + + *ret = 1; -#define MAXGROUPS 16 - -ncclResult_t shmGetRings(int nranks, int* groups, int* subgroups, ncclTvalue_t* values, int* nringsRet, int* prev, int* next, int minScore, int* nthreads) { - if (*nringsRet == MAXCHANNELS) *nringsRet = 1; - int nGroups = groups[nranks-1] + 1; - int starts[MAXGROUPS]; - int ends[MAXGROUPS]; - for (int ring = 0; ring<*nringsRet; ring++) { - int startGroup = -1, endGroup = -1; - for (int group = 0; group<nGroups; group++) { - int start = -1; - int end = -1; - int nranksInGroup = 0; - for (int rank=0; rank<nranks; rank++) { - if (groups[rank] != group) continue; - nranksInGroup++; - if (prev[ring*nranks+rank] != -1) { - if (start != -1) { - WARN("Multiple starts found in group"); - } - start = rank; - startGroup = group; - } - if (next[ring*nranks+rank] != -1) { - if (end != -1) { - WARN("Multiple ends found in group"); - } - end = rank; - endGroup = group; - } - } - if (nranksInGroup == 1) { - start = end = groupFirst(nranks, groups, group, -1); - } else { - if (start == -1) - start = groupFirst(nranks, groups, group, end); - if (end == -1) - end = groupLast(nranks, groups, group, start); - } - if (start == -1 || end == -1) { - *nringsRet = ring; - return ncclSuccess; - } - starts[group] = start; - ends[group] = end; - } - if (endGroup == -1 || startGroup == -1) { - startGroup = 0; - endGroup = nGroups-1; - // Close the loop - next[ring*nranks+ends[endGroup]] = starts[startGroup]; - prev[ring*nranks+starts[startGroup]] = ends[endGroup]; - } - int group = startGroup; - for (int i=0; i<nGroups-2; i++) { - int nextGroup = (group+1)%nGroups; - if (nextGroup == endGroup) nextGroup = (nextGroup+1)%nGroups; - next[ring*nranks+ends[group]] = starts[nextGroup]; - prev[ring*nranks+starts[nextGroup]] = ends[group]; - group = nextGroup; - } - // Connect with the last - next[ring*nranks+ends[group]] = starts[endGroup]; - prev[ring*nranks+starts[endGroup]] = ends[group]; - } return ncclSuccess; } #define MAX_SHM_NAME_LEN 1024 /* Create and return connect structures for this peer to connect to me */ -ncclResult_t shmSendSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int buffSize, int channelId) { +ncclResult_t shmSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int buffSize, int channelId) { struct shmSendResources* resources; NCCLCHECK(ncclCalloc(&resources, 1)); @@ -149,13 +75,13 @@ ncclResult_t shmSendSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peer TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info.shmSize); NCCLCHECK(shmOpen(shmName, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); - INFO(NCCL_INIT|NCCL_SHM,"Ring %02d : %d[%d] -> %d[%d] via direct shared memory", channelId, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev); + INFO(NCCL_INIT|NCCL_SHM,"Ring %02d : %d[%lx] -> %d[%lx] via direct shared memory", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId); static_assert(sizeof(struct shmConnectInfo) <= sizeof(struct ncclConnect), "shm Connect Recv Info is too big"); memcpy(connectInfo, &info, sizeof(struct shmConnectInfo)); return ncclSuccess; } -ncclResult_t shmRecvSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int buffSize, int channelId) { +ncclResult_t shmRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int buffSize, int channelId) { struct shmRecvResources* resources; NCCLCHECK(ncclCalloc(&resources, 1)); recv->transportResources = resources; @@ -194,6 +120,7 @@ ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, struct ncclConnecto send->transportResources = resources; send->conn.buff = resources->devRemHostMem->buff; send->conn.llBuff = resources->devRemHostMem->llBuff; + send->conn.ll128Buff = resources->devRemHostMem->ll128Buff; send->conn.tail = &resources->devRemHostMem->tail; send->conn.opCountRem = &resources->devRemHostMem->opCount; @@ -218,6 +145,7 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, struct ncclConnecto recv->conn.buff = resources->devHostMem->buff; recv->conn.llBuff = resources->devHostMem->llBuff; + recv->conn.ll128Buff = resources->devHostMem->ll128Buff; recv->conn.tail = &resources->devHostMem->tail; recv->conn.opCountLoc = &resources->devHostMem->opCount; return ncclSuccess; @@ -242,7 +170,6 @@ ncclResult_t shmRecvFree(void* transportResources) { struct ncclTransport shmTransport = { "SHM", shmCanConnect, - shmGetRings, { shmSendSetup, shmSendConnect, shmSendFree, NULL }, { shmRecvSetup, shmRecvConnect, shmRecvFree, NULL } }; |