diff options
Diffstat (limited to 'src/transport.cc')
-rw-r--r-- | src/transport.cc | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/transport.cc b/src/transport.cc index 1436a5b..4059849 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -4,7 +4,8 @@ * See LICENSE.txt for license information ************************************************************************/ -#include "core.h" +#include "comm.h" +#include "info.h" extern struct ncclTransport p2pTransport; extern struct ncclTransport shmTransport; @@ -119,13 +120,13 @@ ncclResult_t transportSaveProxies(struct ncclProxyArgs* args, int pattern, int r } if (pattern == ncclPatternTreeUp || pattern == ncclPatternTreeUpDown) { // Tree up - struct ncclTree* tree = &args->channel->tree; + struct ncclTree* tree = &args->channel->treeUp; for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxyRecv>(tree->down[i], args)); NCCLCHECK(SaveProxy<proxySend>(tree->up, args)); } if (pattern == ncclPatternTreeDown || pattern == ncclPatternTreeUpDown) { // Tree down - struct ncclTree* tree = &args->channel->tree; + struct ncclTree* tree = &args->channel->treeDn; for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxySend>(tree->down[i], args)); NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args)); } @@ -157,7 +158,9 @@ void* persistentThread(void *comm_) { } } while (op == NULL); op->idle = 0; - if (op->state != ncclProxyOpNone) ret = op->progress(op); + // opCount >= lastOpCount are part of an ongoing GroupStart/GroupEnd that hasn't started + // yet and might be cancelled before they even start. Hold on on those. + if (op->state != ncclProxyOpNone && op->opCount < comm->lastOpCount) ret = op->progress(op); if (ret != ncclSuccess) { comm->fatalError = ret; INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret); |