diff options
Diffstat (limited to 'src/transport.cc')
-rw-r--r-- | src/transport.cc | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/src/transport.cc b/src/transport.cc index 4059849..cc8d5d1 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -100,6 +100,7 @@ static ncclResult_t SaveProxy(int peer, struct ncclProxyArgs* args) { struct ncclPeer* peerComm = args->channel->peers+peer; struct ncclConnector* connector = type == proxyRecv ? &peerComm->recv : &peerComm->send; + if (connector->transportComm == NULL) return ncclInternalError; if (connector->transportComm->proxy == NULL) return ncclSuccess; struct ncclProxyArgs* op; @@ -130,6 +131,18 @@ ncclResult_t transportSaveProxies(struct ncclProxyArgs* args, int pattern, int r for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxySend>(tree->down[i], args)); NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args)); } + if (pattern == ncclPatternCollTreeUp) { + // CollTree up + struct ncclTree* tree = &args->channel->collTreeUp; + NCCLCHECK(SaveProxy<proxyRecv>(tree->down[0], args)); + NCCLCHECK(SaveProxy<proxySend>(tree->up, args)); + } + if (pattern == ncclPatternCollTreeDown) { + // CollTree down + struct ncclTree* tree = &args->channel->collTreeDn; + NCCLCHECK(SaveProxy<proxySend>(tree->down[0], args)); + NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args)); + } return ncclSuccess; } |