Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/rings.cc')
-rw-r--r--src/graph/rings.cc57
1 files changed, 57 insertions, 0 deletions
diff --git a/src/graph/rings.cc b/src/graph/rings.cc
new file mode 100644
index 0000000..5aacbb5
--- /dev/null
+++ b/src/graph/rings.cc
@@ -0,0 +1,57 @@
+/*************************************************************************
+ * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "core.h"
+
+#define MAXWIDTH 20
+#define PREFIXLEN 15
+#define STRLENGTH (PREFIXLEN+5*MAXWIDTH)
+void dumpLine(int* values, int nranks, const char* prefix) {
+ int prefixlen = strlen(prefix);
+ char line[STRLENGTH+1];
+ line[STRLENGTH] = '\0';
+ memset(line, ' ', STRLENGTH);
+ strncpy(line, prefix, PREFIXLEN);
+ for (int i=0; i<nranks && i<MAXWIDTH; i++) sprintf(line+prefixlen+4*i, " %3d", values[i]);
+ INFO(NCCL_INIT,"%s", line);
+}
+
+ncclResult_t ncclBuildRings(int nrings, int* rings, int rank, int nranks, int* prev, int* next) {
+ for (int r=0; r<nrings; r++) {
+ char prefix[30];
+ /*sprintf(prefix, "[%d] Channel %d Prev : ", rank, r);
+ dumpLine(prev+r*nranks, nranks, prefix);
+ sprintf(prefix, "[%d] Channel %d Next : ", rank, r);
+ dumpLine(next+r*nranks, nranks, prefix);*/
+
+ int current = rank;
+ for (int i=0; i<nranks; i++) {
+ rings[r*nranks+i] = current;
+ current = next[r*nranks+current];
+ }
+ sprintf(prefix, "Channel %02d/%02d : ", r, nrings);
+ if (rank == 0) dumpLine(rings+r*nranks, nranks, prefix);
+ if (current != rank) {
+ WARN("Error : ring %d does not loop back to start (%d != %d)", r, current, rank);
+ return ncclInternalError;
+ }
+ // Check that all ranks are there
+ for (int i=0; i<nranks; i++) {
+ int found = 0;
+ for (int j=0; j<nranks; j++) {
+ if (rings[r*nranks+j] == i) {
+ found = 1;
+ break;
+ }
+ }
+ if (found == 0) {
+ WARN("Error : ring %d does not contain rank %d", r, i);
+ return ncclInternalError;
+ }
+ }
+ }
+ return ncclSuccess;
+}