Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/transport.cc')
-rw-r--r-- src/transport.cc 11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/transport.cc b/src/transport.cc
index 1436a5b..4059849 100644
--- a/src/transport.cc
+++ b/src/transport.cc
@@ -4,7 +4,8 @@
* See LICENSE.txt for license information
************************************************************************/
-#include "core.h"
+#include "comm.h"
+#include "info.h"
extern struct ncclTransport p2pTransport;
extern struct ncclTransport shmTransport;
@@ -119,13 +120,13 @@ ncclResult_t transportSaveProxies(struct ncclProxyArgs* args, int pattern, int r
}
if (pattern == ncclPatternTreeUp || pattern == ncclPatternTreeUpDown) {
// Tree up
- struct ncclTree* tree = &args->channel->tree;
+ struct ncclTree* tree = &args->channel->treeUp;
for (int i=0; i<NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxyRecv>(tree->down[i], args));
NCCLCHECK(SaveProxy<proxySend>(tree->up, args));
}
if (pattern == ncclPatternTreeDown || pattern == ncclPatternTreeUpDown) {
// Tree down
- struct ncclTree* tree = &args->channel->tree;
+ struct ncclTree* tree = &args->channel->treeDn;
for (int i=0; i< NCCL_MAX_TREE_ARITY; i++) NCCLCHECK(SaveProxy<proxySend>(tree->down[i], args));
NCCLCHECK(SaveProxy<proxyRecv>(tree->up, args));
}
@@ -157,7 +158,9 @@ void* persistentThread(void *comm_) {
}
} while (op == NULL);
op->idle = 0;
- if (op->state != ncclProxyOpNone) ret = op->progress(op);
+ // opCount >= lastOpCount are part of an ongoing GroupStart/GroupEnd that hasn't started
+ // yet and might be cancelled before they even start. Hold off on those.
+ if (op->state != ncclProxyOpNone && op->opCount < comm->lastOpCount) ret = op->progress(op);
if (ret != ncclSuccess) {
comm->fatalError = ret;
INFO(NCCL_ALL,"%s:%d -> %d [Proxy Thread]", __FILE__, __LINE__, ret);