1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
/*************************************************************************
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "core.h"
#include "net.h"
#include "param.h"
/* Yields r-1 when r is greater than root, and r otherwise.
 * An integer named `root` must be visible where the macro is expanded.
 * The argument is expanded more than once, so it must not have side effects. */
#define RANK_TO_INDEX(r) ((r) > root ? (r)-1 : (r))
/* Btree which alternates leaves and nodes.
 * Assumes root is 0, which conveniently builds a tree on powers of two,
 * (because we have pow2-1 ranks) which lets us manipulate bits.
 * Find first non-zero bit, then :
 * Find the parent :
 *   xx01[0] -> xx10[0] (1,5,9 below) or xx00[0] if xx10[0] is out of bounds (13 below)
 *   xx11[0] -> xx10[0] (3,7,11 below)
 * Find the children :
 *   xx10[0] -> xx01[0] (2,4,6,8,10,12) or -1 (1,3,5,7,9,11,13)
 *   xx10[0] -> xx11[0] (2,4,6,8,10) or xx101[0] (12) or xx1001[0] ... or -1 (1,3,5,7,9,11,13)
 *
 * Illustration :
 *   0---------------8
 *            ______/ \______
 *           4               12
 *         /   \            /  \
 *       2       6       10     \
 *      / \     / \     /  \     \
 *     1   3   5   7   9   11    13
 */
/* Compute the binary-tree neighbours of `rank` in a tree of `nranks` ranks.
 *
 *   nranks : number of ranks in the tree
 *   rank   : rank whose neighbours are requested
 *   u      : out, parent rank; -1 when rank is 0
 *   d0     : out, first child, or -1 when there is none
 *   d1     : out, second child, or -1 when there is none
 *
 * Always returns ncclSuccess. Inputs are not validated.
 */
ncclResult_t ncclGetBtree(int nranks, int rank, int* u, int* d0, int* d1) {
  // Walk up from bit 0 until we hit a bit that is set in rank, or until the
  // candidate bit reaches nranks (which is what happens for rank 0).
  int lsb = 1;
  while (lsb < nranks && !(lsb & rank)) {
    lsb <<= 1;
  }

  if (rank == 0) {
    // Rank 0 has no parent and a single child: half of the first power of
    // two that is >= nranks. With a single rank there is no child at all.
    *u = -1;
    if (nranks > 1) {
      *d0 = lsb >> 1;
    } else {
      *d0 = -1;
    }
    *d1 = -1;
    return ncclSuccess;
  }

  // Parent: clear the lowest set bit and set the one above it. If that lands
  // outside the tree, fall back to only clearing the lowest set bit.
  int parent = (rank ^ lsb) | (lsb << 1);
  if (parent >= nranks) {
    parent = rank ^ lsb;
  }
  *u = parent;

  // Children sit half a stride below and above rank; a zero stride means
  // rank is a leaf.
  int stride = lsb >> 1;
  int left = -1;
  int right = -1;
  if (stride != 0) {
    left = rank - stride;   // always within bounds
    right = rank + stride;
  }

  // Halve the stride until the upper child falls inside the tree; once the
  // stride reaches zero there is no upper child.
  while (right >= nranks) {
    stride >>= 1;
    right = (stride != 0) ? rank + stride : -1;
  }

  *d0 = left;
  *d1 = right;
  return ncclSuccess;
}
/* Build a double binary tree. Take the previous tree for the first tree.
 * For the second tree, we use a mirror tree (if nranks is odd)
 *
 *                 8---------0---------5
 *          ______/ \______      _____/ \______
 *         4               12   1               9
 *       /   \            /      \            /   \
 *     2       6       10          3       7       10
 *    / \     / \     /  \        / \     / \     /  \
 *   1   3   5   7   9   11      2   4   6   8  11   12
 *
 * or shift it by one rank (if nranks is even)
 *
 *                 8---------0--------------9
 *          ______/ \                ______/ \
 *         4         \              5         \
 *       /   \        \           /   \        \
 *     2       6       10       3       7       11
 *    / \     / \     /  \     / \     / \     /  \
 *   1   3   5   7   9   11   2   4   6   8  10    1
 */
/* Compute the neighbours of `rank` in two binary trees over `nranks` ranks.
 *
 *   s0, d0_0, d0_1 : out, parent and two children in the first tree
 *   s1, d1_0, d1_1 : out, parent and two children in the second tree
 *
 * The first tree is the one produced by ncclGetBtree. The second tree is the
 * same tree with ranks mirrored (odd nranks) or shifted by one (even nranks).
 * A value of -1 (no neighbour) is passed through unchanged.
 * Always returns ncclSuccess.
 */
ncclResult_t ncclGetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* s1, int* d1_0, int* d1_1) {
  // Tree 0: the plain btree.
  ncclGetBtree(nranks, rank, s0, d0_0, d0_1);

  // Tree 1: look up the btree at a transformed rank, then map the resulting
  // neighbours back with the matching inverse transform.
  int parent, child0, child1;
  if (nranks % 2 != 0) {
    // Odd number of ranks: mirror, r -> nranks-1-r (its own inverse).
    ncclGetBtree(nranks, nranks-1-rank, &parent, &child0, &child1);
    if (parent != -1) parent = nranks-1-parent;
    if (child0 != -1) child0 = nranks-1-child0;
    if (child1 != -1) child1 = nranks-1-child1;
  } else {
    // Even number of ranks: shift down by one on the way in, up by one on
    // the way out, wrapping modulo nranks.
    ncclGetBtree(nranks, (rank-1+nranks) % nranks, &parent, &child0, &child1);
    if (parent != -1) parent = (parent+1) % nranks;
    if (child0 != -1) child0 = (child0+1) % nranks;
    if (child1 != -1) child1 = (child1+1) % nranks;
  }

  *s1 = parent;
  *d1_0 = child0;
  *d1_1 = child1;
  return ncclSuccess;
}
|