Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/transport/shm.cc')
-rw-r--r--src/transport/shm.cc111
1 files changed, 19 insertions, 92 deletions
diff --git a/src/transport/shm.cc b/src/transport/shm.cc
index 2ec5f23..60f16c8 100644
--- a/src/transport/shm.cc
+++ b/src/transport/shm.cc
@@ -4,13 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
-#include "utils.h"
-#include "transport.h"
-#include "param.h"
+#include "comm.h"
#include "shm.h"
-#include <unistd.h>
-#include <cuda_runtime.h>
struct shmConnectInfo {
uint64_t pidHash;
@@ -40,98 +35,29 @@ struct shmRecvResources {
NCCL_PARAM(ShmDisable, "SHM_DISABLE", 0);
-/* Determine if we can communicate with the peer */
-ncclResult_t shmCanConnect(ncclTvalue_t* ret, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo) {
- *ret = ((ncclParamShmDisable() == 1) || (myInfo->hostHash != peerInfo->hostHash)) ? 0 : 1;
- return ncclSuccess;
-}
+/* Determine two peers can communicate with SHM */
+ncclResult_t shmCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) {
+ *ret = 0;
-static inline int groupFirst(int nranks, int* groups, int group, int rankToAvoid) {
- for (int rank = 0; rank<nranks; rank++) {
- if ((groups[rank] == group) && (rank != rankToAvoid)) return rank;
- }
- return -1;
-}
+ if (ncclParamShmDisable() == 1) return ncclSuccess;
-static inline int groupLast(int nranks, int* groups, int group, int rankToAvoid) {
- for (int rank = nranks-1; rank>=0; rank--) {
- if ((groups[rank] == group) && (rank != rankToAvoid)) return rank;
- }
- return -1;
-}
+ // Same host?
+ TRACE(NCCL_INIT|NCCL_SHM, "peer1 hostHash %lx peer2 hostHash %lx", info1->hostHash, info2->hostHash);
+ if (info1->hostHash != info2->hostHash) return ncclSuccess;
+
+ // Common /dev/shm (between containers) ?
+ TRACE(NCCL_INIT|NCCL_SHM, "peer1 shmDev %lx peer2 shmDev %lx", info1->shmDev, info2->shmDev);
+ if (info1->shmDev != info2->shmDev) return ncclSuccess;
+
+ *ret = 1;
-#define MAXGROUPS 16
-
-ncclResult_t shmGetRings(int nranks, int* groups, int* subgroups, ncclTvalue_t* values, int* nringsRet, int* prev, int* next, int minScore, int* nthreads) {
- if (*nringsRet == MAXCHANNELS) *nringsRet = 1;
- int nGroups = groups[nranks-1] + 1;
- int starts[MAXGROUPS];
- int ends[MAXGROUPS];
- for (int ring = 0; ring<*nringsRet; ring++) {
- int startGroup = -1, endGroup = -1;
- for (int group = 0; group<nGroups; group++) {
- int start = -1;
- int end = -1;
- int nranksInGroup = 0;
- for (int rank=0; rank<nranks; rank++) {
- if (groups[rank] != group) continue;
- nranksInGroup++;
- if (prev[ring*nranks+rank] != -1) {
- if (start != -1) {
- WARN("Multiple starts found in group");
- }
- start = rank;
- startGroup = group;
- }
- if (next[ring*nranks+rank] != -1) {
- if (end != -1) {
- WARN("Multiple ends found in group");
- }
- end = rank;
- endGroup = group;
- }
- }
- if (nranksInGroup == 1) {
- start = end = groupFirst(nranks, groups, group, -1);
- } else {
- if (start == -1)
- start = groupFirst(nranks, groups, group, end);
- if (end == -1)
- end = groupLast(nranks, groups, group, start);
- }
- if (start == -1 || end == -1) {
- *nringsRet = ring;
- return ncclSuccess;
- }
- starts[group] = start;
- ends[group] = end;
- }
- if (endGroup == -1 || startGroup == -1) {
- startGroup = 0;
- endGroup = nGroups-1;
- // Close the loop
- next[ring*nranks+ends[endGroup]] = starts[startGroup];
- prev[ring*nranks+starts[startGroup]] = ends[endGroup];
- }
- int group = startGroup;
- for (int i=0; i<nGroups-2; i++) {
- int nextGroup = (group+1)%nGroups;
- if (nextGroup == endGroup) nextGroup = (nextGroup+1)%nGroups;
- next[ring*nranks+ends[group]] = starts[nextGroup];
- prev[ring*nranks+starts[nextGroup]] = ends[group];
- group = nextGroup;
- }
- // Connect with the last
- next[ring*nranks+ends[group]] = starts[endGroup];
- prev[ring*nranks+starts[endGroup]] = ends[group];
- }
return ncclSuccess;
}
#define MAX_SHM_NAME_LEN 1024
/* Create and return connect structures for this peer to connect to me */
-ncclResult_t shmSendSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int buffSize, int channelId) {
+ncclResult_t shmSendSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int buffSize, int channelId) {
struct shmSendResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
@@ -149,13 +75,13 @@ ncclResult_t shmSendSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peer
TRACE(NCCL_SHM,"Open shmName %s shmSize %d", shmName, info.shmSize);
NCCLCHECK(shmOpen(shmName, resources->shmSize, (void**)&resources->hostMem, (void**)&resources->devHostMem, 1));
- INFO(NCCL_INIT|NCCL_SHM,"Ring %02d : %d[%d] -> %d[%d] via direct shared memory", channelId, myInfo->rank, myInfo->cudaDev, peerInfo->rank, peerInfo->cudaDev);
+ INFO(NCCL_INIT|NCCL_SHM,"Ring %02d : %d[%lx] -> %d[%lx] via direct shared memory", channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId);
static_assert(sizeof(struct shmConnectInfo) <= sizeof(struct ncclConnect), "shm Connect Recv Info is too big");
memcpy(connectInfo, &info, sizeof(struct shmConnectInfo));
return ncclSuccess;
}
-ncclResult_t shmRecvSetup(struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int buffSize, int channelId) {
+ncclResult_t shmRecvSetup(struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int buffSize, int channelId) {
struct shmRecvResources* resources;
NCCLCHECK(ncclCalloc(&resources, 1));
recv->transportResources = resources;
@@ -194,6 +120,7 @@ ncclResult_t shmSendConnect(struct ncclConnect* connectInfo, struct ncclConnecto
send->transportResources = resources;
send->conn.buff = resources->devRemHostMem->buff;
send->conn.llBuff = resources->devRemHostMem->llBuff;
+ send->conn.ll128Buff = resources->devRemHostMem->ll128Buff;
send->conn.tail = &resources->devRemHostMem->tail;
send->conn.opCountRem = &resources->devRemHostMem->opCount;
@@ -218,6 +145,7 @@ ncclResult_t shmRecvConnect(struct ncclConnect* connectInfo, struct ncclConnecto
recv->conn.buff = resources->devHostMem->buff;
recv->conn.llBuff = resources->devHostMem->llBuff;
+ recv->conn.ll128Buff = resources->devHostMem->ll128Buff;
recv->conn.tail = &resources->devHostMem->tail;
recv->conn.opCountLoc = &resources->devHostMem->opCount;
return ncclSuccess;
@@ -242,7 +170,6 @@ ncclResult_t shmRecvFree(void* transportResources) {
struct ncclTransport shmTransport = {
"SHM",
shmCanConnect,
- shmGetRings,
{ shmSendSetup, shmSendConnect, shmSendFree, NULL },
{ shmRecvSetup, shmRecvConnect, shmRecvFree, NULL }
};