Welcome to mirror list, hosted at ThFree Co, Russian Federation.

primitives.h « device « collectives « src - github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 8df152e7e118a262ee65687035d031c2affb405a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
/*************************************************************************
 * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef NCCL_PRIMITIVES_H_
#define NCCL_PRIMITIVES_H_

#include <type_traits>
#include "reduce_kernel.h" // for reduction funcs


/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
 *
 * In order to reduce the repetition of template arguments, the operations
 * are bundled as static methods of the Primitives class.
 *
 * Each primitive operation copies/reduces a contiguous buffer and syncs
 * an optional set of flags against a sub-step counter. The sync value is
 * based on the step parameter. Sync flags must be of type WaitFlag or
 * PostFlag. The primitive routines wait for all WaitFlag args to attain
 * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
 * corresponding substep by previous step) before executing the transfer.
 * After each substep is transferred, all PostFlag arguments get updated to
 * the value SUBSTEPS*step+substep+1.
 *
 * Internally, both comparisons use SUBSTEPS*step+substep+1. Each flag's
 * `shift` produces the offsets above: a WaitFlag adds its shift to the counter
 * before comparing, and a PostFlag subtracts its shift before storing.
 */


// Read-side handle on a 64-bit progress counter. wait() busy-polls the
// counter until (counter + shift) reaches the requested value.
class WaitFlag {
  volatile uint64_t * const flag;
  const int shift;
 public:
  __device__ __forceinline__
  WaitFlag(volatile uint64_t * const flag, const int shift)
      : flag(flag), shift(shift) { }

  // Spins until *flag + shift >= val. There is no back-off and no timeout.
  // Because `flag` points to volatile memory, *flag is re-loaded on every
  // iteration.
  __device__ __forceinline__
  void wait(uint64_t val) {
    for (;;) {
      const uint64_t current = *flag + shift;
      if (current >= val) return;
    }
  }
};


// Write-side handle on a 64-bit progress counter, optionally paired with a
// FIFO that records one size per step.
class PostFlag {
  volatile uint64_t * const flag;
  const int shift;
  volatile int * const fifo;
  const int fifo_size;
 public:
  __device__ __forceinline__
  PostFlag(volatile uint64_t* const flag, const int shift, volatile int* const fifo, const int fifo_size)
      : flag(flag), shift(shift), fifo(fifo), fifo_size(fifo_size) { }

  // Publishes val to the counter. The value is stored with `shift`
  // subtracted, which mirrors the shift that WaitFlag::wait adds back.
  __device__ __forceinline__
  void post(uint64_t val) {
    const uint64_t stored = val - shift;
    *flag = stored;
  }

  // Stores `size` into slot step % fifo_size of the size FIFO. Does nothing
  // when no FIFO was supplied (fifo == NULL). fifo_size must be non-zero
  // whenever fifo is non-NULL.
  __device__ __forceinline__
  void postSize(uint64_t step, int size) {
    if (fifo == NULL) return;
    fifo[step % fifo_size] = size;
  }
};


// AnyAre<T>(args...) reports whether at least one argument's deduced
// (by-value) type is exactly T, compared with std::is_same.
// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
// Argument values are never read; only their types matter, so the result is
// fixed for each instantiation.

// Empty pack: nothing can match.
template<typename T> __device__ __forceinline__
bool AnyAre() { return false; }

// Non-empty pack: test the head, otherwise recurse on the tail.
template<typename T, typename FIRST_T, typename... TAIL_Ts>
__device__ __forceinline__
bool AnyAre(FIRST_T, TAIL_Ts... tail) {
  if (std::is_same<T, FIRST_T>::value) return true;
  return AnyAre<T>(tail...);
}


// WaitOnFlags(val, flags...) calls wait(val) on every WaitFlag in the pack,
// in argument order, and skips PostFlag arguments.

// Terminal case: no flags left to inspect.
__device__ __forceinline__
void WaitOnFlags(uint64_t) { }

// Head is a WaitFlag: block on it, then handle the rest.
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(uint64_t val, WaitFlag flag, TAIL_Ts... rest) {
  flag.wait(val);
  WaitOnFlags(val, rest...);
}

// Head is a PostFlag: nothing to wait for, move on.
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(uint64_t val, PostFlag, TAIL_Ts... rest) {
  WaitOnFlags(val, rest...);
}


// PostToFlags(val, flags...) calls post(val) on every PostFlag in the pack,
// in argument order, and skips WaitFlag arguments.

// Terminal case: no flags left to inspect.
__device__ __forceinline__
void PostToFlags(uint64_t) { }

// Head is a WaitFlag: nothing to post, move on.
template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(uint64_t val, WaitFlag, TAIL_Ts... rest) {
  PostToFlags(val, rest...);
}

// Head is a PostFlag: publish val on it, then handle the rest.
template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(uint64_t val, PostFlag flag, TAIL_Ts... rest) {
  flag.post(val);
  PostToFlags(val, rest...);
}


// PostSizeToFlags(step, size, flags...) calls postSize(step, size) on every
// PostFlag in the pack, in argument order, and skips WaitFlag arguments.

// Terminal case: no flags left to inspect.
__device__ __forceinline__
void PostSizeToFlags(uint64_t, int) { }

// Head is a WaitFlag: it has no size FIFO, move on.
template <typename... TAIL_Ts> __device__ __forceinline__
void PostSizeToFlags(uint64_t step, int size, WaitFlag, TAIL_Ts... rest) {
  PostSizeToFlags(step, size, rest...);
}

// Head is a PostFlag: record the size on it, then handle the rest.
template <typename... TAIL_Ts> __device__ __forceinline__
void PostSizeToFlags(uint64_t step, int size, PostFlag flag, TAIL_Ts... rest) {
  flag.postSize(step, size);
  PostSizeToFlags(step, size, rest...);
}


// ptradd(p, i) evaluates to `p + i` for ordinary pointers. For a nullptr_t
// argument it yields nullptr, so generic code can offset an optional
// ("absent" = nullptr) pointer without special-casing it.
template <typename Tptr> __device__ __forceinline__
Tptr ptradd(Tptr ptr, int i) {
  Tptr advanced = ptr + i;
  return advanced;
}

// Non-template overload, preferred for nullptr_t arguments: nothing to offset.
__device__ __forceinline__
nullptr_t ptradd(nullptr_t ptr, int i) {
  return ptr;
}


// Implementation of primitive types.
//
// Template parameters:
//   UNROLL   - forwarded to ReduceOrCopy as its UNROLL parameter.
//   SUBSTEPS - number of sub-slices each call is split into. Flags are waited
//              on and posted once per sub-slice.
//   T        - element type.
//   REDOP    - reduction functor, used when a second source is present.
template <int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
class Primitives {
 private:
  // Shared implementation of Copy / DoubleCopy / Reduce / ReduceCopy.
  //
  // Threads with tid < nthreads move the data. If any PostFlag is passed,
  // threads with tid >= nthreads publish sizes and flags after each sub-slice.
  // Presumably the caller launches one extra warp for this; confirm at the
  // call sites. nthreads is used as the thread count of named barrier 1
  // (bar.sync), so it must be a multiple of the warp size.
  //
  // The element range [0, len) is split into SUBSTEPS sub-slices of
  // len/SUBSTEPS elements each. The last sub-slice also takes the
  // len % SUBSTEPS leftover elements, so no element is silently skipped. Each
  // sub-slice is further clamped so that offsets >= maxoffset are never
  // touched.
  template <typename SRC2_T, // either const T* or nullptr_t
      typename DST2_T, // either T* or nullptr_t
      typename... SYNC_Ts> // either WaitFlag or PostFlag
  static __device__ __forceinline__ void
  GenericOp(const int tid, const int nthreads,
      const T*     src1,
      const SRC2_T src2,
      T*     dst1,
      DST2_T dst2,
      int len, int maxoffset, uint64_t step, SYNC_Ts... flags) {

    enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value };
    enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value };
    static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
        "src2 must be of type const T* or nullptr_t");
    static_assert(noDst2 || std::is_same<DST2_T, T*>::value,
        "dst2 must be of type T* or nullptr_t");

    // REDOP only matters when there is a second source. Plain copies pass
    // FuncSum<T> as the functor argument instead.
    using OpType = typename std::conditional<noSrc2, FuncSum<T>, REDOP>::type;

    const int sliceSize = len / SUBSTEPS;
    int sliceOffset = 0;

#pragma unroll 1
    for (int sub=0; sub<SUBSTEPS; ++sub) {
      // The last sub-slice absorbs the len % SUBSTEPS remainder. This is
      // identical to sliceSize when len is a multiple of SUBSTEPS.
      const int thisSliceSize = (sub == SUBSTEPS-1) ? len - sliceOffset : sliceSize;
      const int realSize = max(0, min(thisSliceSize, maxoffset-sliceOffset));
      if (tid < nthreads) {
        if (AnyAre<WaitFlag>(flags...)) {
          // A single thread spins on the wait flags. Named barrier 1 then
          // releases the remaining data-moving threads.
          if (tid == 0) {
            WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
          }
          asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
        }
        ReduceOrCopy
        <
        UNROLL,
        OpType,
        T,
        !noDst2, // HAS_DEST1: a second destination (dst2) is present
        !noSrc2  // HAS_SRC1: a second source (src2) is present
        >
        (
            tid, nthreads,
            ptradd(dst1, sliceOffset),
            ptradd(dst2, sliceOffset),
            ptradd(src1, sliceOffset),
            ptradd(src2, sliceOffset),
            realSize
        );
        if (AnyAre<PostFlag>(flags...)) {
          // Matches the __syncthreads() in the posting branch below, so the
          // data for this sub-slice is written before anything is posted.
          __syncthreads();
        }
      } else {
        if (AnyAre<PostFlag>(flags...)) {
          __syncthreads();
          PostSizeToFlags(SUBSTEPS*step+sub, realSize*sizeof(T), flags...);
          // Make the data and sizes visible system-wide before the flags
          // advance.
          __threadfence_system();
          PostToFlags(SUBSTEPS*step + sub + 1, flags...);
        }
      }
      sliceOffset += sliceSize;
    }
  }

 public:
  // Copies the elements of src to dst.
  // len is the element count for this call, split over SUBSTEPS.
  // Elements at offset >= maxOffset are skipped.
  // step selects the flag values (see GenericOp).
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  Copy(const int tid, const int nthreads, const T* src, T* dst,
      int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
    GenericOp(tid, nthreads, src, nullptr, dst, nullptr, len, maxOffset, step, flags...);
  }

  // Copies the elements of src to both dst1 and dst2.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  DoubleCopy(const int tid, const int nthreads, const T* src, T* dst1, T* dst2,
      int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
    GenericOp(tid, nthreads, src, nullptr, dst1, dst2, len, maxOffset, step, flags...);
  }

  // Combines src1 and src2 with REDOP (applied by ReduceOrCopy) into dst.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  Reduce(const int tid, const int nthreads, const T* src1, const T* src2, T* dst,
      int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
    GenericOp(tid, nthreads, src1, src2, dst, nullptr, len, maxOffset, step, flags...);
  }

  // Combines src1 and src2 with REDOP and writes the result to dst1 and dst2.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  ReduceCopy(const int tid, const int nthreads, const T* src1, const T* src2, T* dst1, T* dst2,
      int len, int maxOffset, uint64_t step, SYNC_Ts... flags) {
    GenericOp(tid, nthreads, src1, src2, dst1, dst2, len, maxOffset, step, flags...);
  }
};

#endif // end include guard