Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/transport.cc')
-rw-r--r--src/transport.cc13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/transport.cc b/src/transport.cc
index 4059849..cc8d5d1 100644
--- a/src/transport.cc
+++ b/src/transport.cc
@@ -100,6 +100,7 @@ static ncclResult_t SaveProxy(int peer, struct ncclProxyArgs* args) {
struct ncclPeer* peerComm = args->channel->peers+peer;
struct ncclConnector* connector = type == proxyRecv ? &peerComm->recv : &peerComm->send;
+ if (connector->transportComm == NULL) return ncclInternalError;
if (connector->transportComm->proxy == NULL) return ncclSuccess;
struct ncclProxyArgs* op;
@@ -130,6 +131,18 @@ ncclResult_t transportSaveProxies(struct ncclProxyArgs* args, int pattern, int r
for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxySend>(tree->down[i], args));
NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args));
}
+ if (pattern == ncclPatternCollTreeUp) {
+ // CollTree up
+ struct ncclTree* tree = &args->channel->collTreeUp;
+ NCCLCHECK(SaveProxy<proxyRecv>(tree->down[0], args));
+ NCCLCHECK(SaveProxy<proxySend>(tree->up, args));
+ }
+ if (pattern == ncclPatternCollTreeDown) {
+ // CollTree down
+ struct ncclTree* tree = &args->channel->collTreeDn;
+ NCCLCHECK(SaveProxy<proxySend>(tree->down[0], args));
+ NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args));
+ }
return ncclSuccess;
}