diff options
Diffstat (limited to 'src/graph/connect.cc')
-rw-r--r-- | src/graph/connect.cc | 53 |
1 files changed, 51 insertions, 2 deletions
diff --git a/src/graph/connect.cc b/src/graph/connect.cc index af481d2..dd9f9f0 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -14,7 +14,7 @@ /******************************************************************/ ncclResult_t ncclTopoPreset(struct ncclComm* comm, - struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, + struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph, struct ncclTopoRanks* topoRanks) { int rank = comm->rank; int localRanks = comm->localRanks; @@ -27,9 +27,14 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->treeUp.down[i] = -1; channel->treeDn.up = -1; for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->treeDn.down[i] = -1; + channel->collTreeUp.up = -1; + for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->collTreeUp.down[i] = -1; + channel->collTreeDn.up = -1; + for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) channel->collTreeDn.down[i] = -1; int* ringIntra = ringGraph->intra+c*localRanks; int* treeIntra = treeGraph->intra+c*localRanks; + int* collNetIntra = collNetGraph->intra+c*localRanks; for (int i=0; i<localRanks; i++) { if (ringIntra[i] == rank) { @@ -57,6 +62,16 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, channel->treeUp.down[0] = sym ? channel->treeDn.down[0] : channel->treeDn.up ; channel->treeUp.up = sym ? channel->treeDn.up : channel->treeDn.down[0]; } + if (collNetIntra[i] == rank) { + int prev = (i-1+localRanks)%localRanks, next = (i+1)%localRanks; + + // CollTrees are always symmetric, i.e. + // up/down go in reverse directions + channel->collTreeDn.up = collNetIntra[prev]; + channel->collTreeDn.down[0] = collNetIntra[next]; + channel->collTreeUp.down[0] = channel->collTreeDn.down[0]; + channel->collTreeUp.up = channel->collTreeDn.up; + } } topoRanks->ringPrev[c] = channel->ring.prev; topoRanks->ringNext[c] = channel->ring.next; @@ -174,6 +189,40 @@ static ncclResult_t connectTrees(struct ncclComm* comm, int* treeUpRecv, int* tr return ncclSuccess; } +ncclResult_t ncclTopoConnectCollNet(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, int rank) { + int nranks = comm->nRanks; + int depth = nranks/comm->nNodes; + int sendIndex = collNetGraph->pattern == NCCL_TOPO_PATTERN_TREE ? 0 : 1; // send GPU index depends on topo pattern + int sendEndIndex = (sendIndex+comm->localRanks-1)%comm->localRanks; + for (int c=0; c<comm->nChannels/2; c++) { + struct ncclChannel* channel = comm->channels+c; + // Set root of collTree to id nranks + if (rank == collNetGraph->intra[sendIndex+c*comm->localRanks]) { // is master + channel->collTreeUp.up = channel->collTreeDn.up = nranks; + } + if (rank == collNetGraph->intra[sendEndIndex+c*comm->localRanks]) { // is bottom of intra-node chain + channel->collTreeUp.down[0] = channel->collTreeDn.down[0] = -1; + } + channel->collTreeUp.depth = channel->collTreeDn.depth = depth; + INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", c, rank, channel->collTreeUp.up, channel->collTreeUp.down[0]); + } + int recvIndex = 0; // recv GPU index is always 0 + int recvEndIndex = (recvIndex+comm->localRanks-1)%comm->localRanks; + for (int c=0; c<comm->nChannels/2; c++) { + struct ncclChannel* channel = comm->channels+comm->nChannels/2+c; + // Set root of collTree to id nranks + if (rank == collNetGraph->intra[recvIndex+c*comm->localRanks]) { // is master + channel->collTreeUp.up = channel->collTreeDn.up = nranks; + } + if (rank == collNetGraph->intra[recvEndIndex+c*comm->localRanks]) { // is bottom of intra-node chain + channel->collTreeUp.down[0] = channel->collTreeDn.down[0] = -1; + } + channel->collTreeUp.depth = channel->collTreeDn.depth = depth; + INFO(NCCL_GRAPH, "CollNet Channel %d rank %d up %d down %d", comm->nChannels/2+c, rank, channel->collTreeDn.up, channel->collTreeDn.down[0]); + } + return ncclSuccess; +} + // Legacy naming NCCL_PARAM(MinNrings, "MIN_NRINGS", -2); NCCL_PARAM(MaxNrings, "MAX_NRINGS", -2); |