diff options
author | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2020-10-14 02:58:05 +0300 |
---|---|---|
committer | Sylvain Jeaugey <sjeaugey@nvidia.com> | 2020-10-14 02:58:05 +0300 |
commit | 0e14394c5ffcdae5517ef5b3a8c734e2ccb215b0 (patch) | |
tree | 4e1639b47312d85a2328434dce4f054a2bca1bb5 | |
parent | c6dbdb00849027b4e2c277653cbef53729f7213d (diff) |
Fix affinity move
-rw-r--r-- | src/init.cc | 32 |
1 file changed, 17 insertions, 15 deletions
diff --git a/src/init.cc b/src/init.cc index 1cc8f7f..585db4b 100644 --- a/src/init.cc +++ b/src/init.cc @@ -799,23 +799,25 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECK(ncclTopoComputeP2pChannels(comm)); // Compute intra ranks (using AllGather1 data) - int intraRank0 = -1, intraRank = -1, intraRanks = 0; - for (int i = 0; i < nranks; i++) { - if ((allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) && - (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash)) { - if (intraRanks == 0) intraRank0 = i; - if (i == rank) intraRank = intraRanks; - intraRanks++; + do { + int intraRank0 = -1, intraRank = -1, intraRanks = 0; + for (int i = 0; i < nranks; i++) { + if ((allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) && + (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash)) { + if (intraRanks == 0) intraRank0 = i; + if (i == rank) intraRank = intraRanks; + intraRanks++; + } } - } - TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d", + TRACE(NCCL_INIT,"hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d", rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0); - if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) { - WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d", - rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0); - return ncclInternalError; - } - NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, allGather1Data[intraRank0].comm)); + if (intraRank == -1 || intraRank0 == -1 || allGather1Data[intraRank0].comm == NULL) { + WARN("Failed to determine intra ranks hostHash[%d] %lx intraRank %d intraRanks %d intraRank0 %d", + rank, allGather1Data[rank].peerInfo.hostHash, intraRank, intraRanks, intraRank0); + return ncclInternalError; + } + 
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, allGather1Data[intraRank0].comm)); + } while(0); // Done with AllGather1 data free(allGather1Data); |