diff options
Diffstat (limited to 'src/transport/shm.cc')
-rw-r--r-- | src/transport/shm.cc | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/src/transport/shm.cc b/src/transport/shm.cc new file mode 100644 index 0000000..60f16c8 --- /dev/null +++ b/src/transport/shm.cc @@ -0,0 +1,175 @@ +/************************************************************************* + * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "comm.h" +#include "shm.h" + +struct shmConnectInfo { + uint64_t pidHash; + int id; + int sendRank; + int recvRank; + int shmSize; +}; + +struct shmSendResources { + int remShmSize; + struct ncclRecvMem* remHostMem; + struct ncclRecvMem* devRemHostMem; + int shmSize; + struct ncclSendMem* hostMem; + struct ncclSendMem* devHostMem; +}; + +struct shmRecvResources { + int remShmSize; + struct ncclSendMem* remHostMem; + struct ncclSendMem* devRemHostMem; + int shmSize; + struct ncclRecvMem* hostMem; + struct ncclRecvMem* devHostMem; +}; + +NCCL_PARAM(ShmDisable, "SHM_DISABLE", 0); + +/* Determine two peers can communicate with SHM */ +ncclResult_t shmCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { + *ret = 0; + + if (ncclParamShmDisable() == 1) return ncclSuccess; + + // Same host? + TRACE(NCCL_INIT|NCCL_SHM, "peer1 hostHash %lx peer2 hostHash %lx", info1->hostHash, info2->hostHash); + if (info1->hostHash != info2->hostHash) return ncclSuccess; + + // Common /dev/shm (between containers) ? + TRACE(NCCL_INIT|NCCL_SHM, "peer1 shmDev %lx peer2 shmDev %lx", info1->shmDev, info2->shmDev); + if (info1->shmDev != info2->shmDev) return ncclSuccess; + + *ret = 1; + + return ncclSuccess; +} + +#define MAX_SHM_NAME_LEN 1024 + +/* Create and return connect structures for this peer to connect to me */ +ncclResult_t shmSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int buffSize, int channelId) { + + struct shmSendResources* resources; + NCCLCHECK(ncclCalloc(&resources, 1)); + send->transportResources = resources; + + struct shmConnectInfo info; + info.id = channelId; + info.pidHash = myInfo->pidHash; + info.sendRank = myInfo->rank; + info.recvRank = peerInfo->rank; + + char shmName[MAX_SHM_NAME_LEN]; + sprintf(shmName, "nccl-shm-send-%lx-%d-%d-%d", info.pidHash, info.id, info.sendRank, info.recvRank); + info.shmSize = resources->shmSize = sizeof(struct ncclSendMem); + TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info.shmSize); + NCCLCHECK(shmOpen(shmName, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + + INFO(NCCL_INIT|NCCL_SHM,"Ring %02d : %d[%lx] -> %d[%lx] via direct shared memory", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId); + static_assert(sizeof(struct shmConnectInfo) <= sizeof(struct ncclConnect), "shm Connect Recv Info is too big"); + memcpy(connectInfo, &info, sizeof(struct shmConnectInfo)); + return ncclSuccess; +} + +ncclResult_t shmRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int buffSize, int channelId) { + struct shmRecvResources* resources; + NCCLCHECK(ncclCalloc(&resources, 1)); + recv->transportResources = resources; + + struct shmConnectInfo info; + info.id = channelId; + info.pidHash = myInfo->pidHash; + info.sendRank = peerInfo->rank; + info.recvRank = myInfo->rank; + + char shmName[MAX_SHM_NAME_LEN]; + sprintf(shmName, "nccl-shm-recv-%lx-%d-%d-%d", info.pidHash, info.id, info.sendRank, info.recvRank); + info.shmSize = resources->shmSize = offsetof(struct ncclRecvMem, buff)+buffSize; + TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info.shmSize); + NCCLCHECK(shmOpen(shmName, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1)); + + static_assert(sizeof(struct shmConnectInfo) <= sizeof(struct ncclConnect), "shm Connect Send Info is too big"); + memcpy(connectInfo, &info, sizeof(struct shmConnectInfo)); + return ncclSuccess; +} + +/* Connect to this peer */ +ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, struct ncclConnector* send) { + // Setup device pointers + struct shmConnectInfo* info = (struct shmConnectInfo*)connectInfo; + struct shmSendResources* resources = (struct shmSendResources*)send->transportResources; + + char shmName[MAX_SHM_NAME_LEN]; + sprintf(shmName, "nccl-shm-recv-%lx-%d-%d-%d", info->pidHash, info->id, info->sendRank, info->recvRank); + resources->remShmSize = info->shmSize; + TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info->shmSize); + NCCLCHECK(shmOpen(shmName, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); + // Remove the file to ensure proper clean-up + NCCLCHECK(shmUnlink(shmName)); + + send->transportResources = resources; + send->conn.buff = resources->devRemHostMem->buff; + send->conn.llBuff = resources->devRemHostMem->llBuff; + send->conn.ll128Buff = resources->devRemHostMem->ll128Buff; + send->conn.tail = &resources->devRemHostMem->tail; + send->conn.opCountRem = &resources->devRemHostMem->opCount; + + send->conn.head = &resources->devHostMem->head; + send->conn.opCountLoc = &resources->devHostMem->opCount; + return ncclSuccess; +} + +ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, struct ncclConnector* recv) { + // Setup device pointers + struct shmRecvResources* resources = (struct shmRecvResources*)recv->transportResources; + struct shmConnectInfo* info = (struct shmConnectInfo*)connectInfo; + + char shmName[MAX_SHM_NAME_LEN]; + sprintf(shmName, "nccl-shm-send-%lx-%d-%d-%d", info->pidHash, info->id, info->sendRank, info->recvRank); + resources->remShmSize = info->shmSize; + TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info->shmSize); + NCCLCHECK(shmOpen(shmName, resources->remShmSize, (void**)&resources->remHostMem, (void**)&resources->devRemHostMem, 0)); + NCCLCHECK(shmUnlink(shmName)); + recv->conn.head = &resources->devRemHostMem->head; + recv->conn.opCountRem = &resources->devRemHostMem->opCount; + + recv->conn.buff = resources->devHostMem->buff; + recv->conn.llBuff = resources->devHostMem->llBuff; + recv->conn.ll128Buff = resources->devHostMem->ll128Buff; + recv->conn.tail = &resources->devHostMem->tail; + recv->conn.opCountLoc = &resources->devHostMem->opCount; + return ncclSuccess; +} + +ncclResult_t shmSendFree(void* transportResources) { + struct shmSendResources* resources = (struct shmSendResources*)transportResources; + NCCLCHECK(shmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); + NCCLCHECK(shmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + free(resources); + return ncclSuccess; +} + +ncclResult_t shmRecvFree(void* transportResources) { + struct shmRecvResources* resources = (struct shmRecvResources*)transportResources; + NCCLCHECK(shmClose(resources->hostMem, resources->devHostMem, resources->shmSize)); + NCCLCHECK(shmClose(resources->remHostMem, resources->devRemHostMem, resources->remShmSize)); + free(resources); + return ncclSuccess; +} + +struct ncclTransport shmTransport = { + "SHM", + shmCanConnect, + { shmSendSetup, shmSendConnect, shmSendFree, NULL }, + { shmRecvSetup, shmRecvConnect, shmRecvFree, NULL } +}; |