Welcome to mirror list, hosted at ThFree Co, Russian Federation.

communicator_nccl.h « training « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 55b485287606acd363ed415021cf09f41a53fcdb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
// @TODO: rename to communicator_nccl.h
// Note: This must only be included if defined(CUDA_FOUND) && defined(USE_NCCL)
#include "training/communicator.h"
#include "3rd_party/threadpool.h"
#include "tensors/gpu/cuda_helpers.h"

#include "common/timer.h"

// Generated by NCCL make files in build/nccl/include;
// include dir has been set in CMake files. NCCL add version number etc.
#include "nccl.h"
#include <cuda_runtime.h>

// Compatibility shim: when the condition below holds, ncclGetVersion(pv) becomes a macro that
// stores NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH (the values of the header macros) in *pv.
// NOTE(review): `NCCL_MAJOR<3 || NCCL_MINOR<2` is true for every header whose major version is
// below 3, so for those the macro is always used -- confirm that this is the intended condition.
// NOTE(review): with this packing a minor version >= 10 carries into the thousands digit
// (e.g. 2.10.3 -> 3003), so the three components cannot be recovered from the packed value.
#if (NCCL_MAJOR<3 || NCCL_MINOR<2)
#define ncclGetVersion(pv) (*(pv) = (NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH))
#endif

#include <signal.h> // HACK
#include <sys/types.h>
#include <sys/syscall.h>
// Returns the id of the calling thread as reported by the SYS_gettid system call.
// Declared inline because this definition lives in an include file: a non-inline definition
// would be emitted by every translation unit that includes it and clash at link time.
inline pid_t gettid(void) { return static_cast<pid_t>(syscall(SYS_gettid)); }

namespace marian {

class NCCLCommunicator : public ICommunicator {
private:
  std::vector<ncclComm_t> comms_;     // [device index]
  std::vector<cudaStream_t> streams_; // [device index]
  std::vector<int> devices_;          // [device index]
  Ptr<IMPIWrapper> mpi_; // (may be null)
  mutable ThreadPool threadPool_;
  mutable std::vector<std::future<void>> threadResults_; // [device index]

  // Helper: calls ncclGroupStart() wrapped in NCCL_CHECK, so that the returned status is always checked.
  void groupStart() const {
    NCCL_CHECK(ncclGroupStart());
  }
  // Helper: calls ncclGroupEnd() wrapped in NCCL_CHECK, so that the returned status is always checked.
  void groupEnd() const {
    NCCL_CHECK(ncclGroupEnd());
  }

  // Blocks the host until the communication stream (streams_[i]) of every local device has
  // finished the work queued on it. The device is selected before each stream is synchronized.
  void synchronizeAll() const {
    const size_t numLocal = graphs_.size();
    for(size_t dev = 0; dev < numLocal; ++dev) {
      CUDA_CHECK(cudaSetDevice(devices_[dev]));
      CUDA_CHECK(cudaStreamSynchronize(streams_[dev]));
      // @TODO: why do we sync the CPU, and not the GPU?
      //  - cudaEventRecord() an event on the nccl stream
      //  - submit a cudaStreamWaitEvent() into our compute stream (=NULL stream)
      // cf. https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.cpp
    }
  }

  void synchronizeAllOnNullStream() const {
    for (int i = 0; i < graphs_.size(); ++i) {
      auto backend = graphs_[i]->params()->vals()->getBackend();
      backend->setDevice();
      backend->synchronize(); // note: synchronize() does not set the device by itself
    }
  }

  // Identification string of this MPI process (for logging); empty when running without MPI.
  std::string mpiIdStr() const {
    if (!mpi_)
      return "";
    return mpi_->idStr();
  }

  // Maps a local device index to its global NCCL rank.
  // With MPI, the devices of MPI process r occupy the ranks r * (#local devices) + localDeviceIndex;
  // without MPI, the rank is the local device index itself.
  size_t myNcclRank(size_t localDeviceIndex) const {
    if (!mpi_)
      return localDeviceIndex;
    return mpi_->myMPIRank() * devices_.size() + localDeviceIndex;
  }

  // Total number of devices taking part in the NCCL setup:
  // (#MPI processes) x (#local devices) with MPI, otherwise just the number of local devices.
  size_t numNcclRanks() const {
    if (!mpi_)
      return devices_.size();
    return mpi_->numMPIProcesses() * devices_.size();
  }

  // Number of elements in the parameter-value tensor of the first local graph, i.e. the length of
  // the concatenated parameter (and gradient) vector that gets sharded.
  size_t dataSize() const {
    auto paramValues = graphs_[0]->params()->vals();
    return paramValues->size();
  }

  // Returns the number of elements per shard: dataSize() divided by the number of NCCL ranks,
  // rounded up. Aborts unless this divides dataSize() evenly, i.e. all shards are required to
  // have identical size (NCCL does not allow a sub-slice for the last shard, and that limitation
  // has not been worked around yet).
  size_t shardSize() const {
    const size_t shardCount = numNcclRanks();
    const size_t perShard = (dataSize() + shardCount - 1) / shardCount; // ceiling division
    ABORT_IF(perShard * shardCount != dataSize(), "presently, all shards must have the same size");
    return perShard;
  }

  // Returns the half-open element range (begin, end) of the shard that belongs to a global NCCL rank.
  // The end is clipped to dataSize(); since shardSize() enforces uniform shards, the clipping is
  // presently never in effect.
  std::pair<size_t, size_t> ncclRankShardRange(size_t rank) const {
    const size_t length = shardSize();
    const size_t first = rank * length;
    const size_t last = std::min(first + length, dataSize());
    return std::pair<size_t, size_t>(first, last);
  }

  // Returns the element range (begin, end) of the shard owned by a local device:
  // the local device index is first mapped to its global NCCL rank.
  std::pair<size_t, size_t> localShardRange(size_t localDeviceIndex) const {
    const size_t globalRank = myNcclRank(localDeviceIndex);
    return ncclRankShardRange(globalRank);
  }

  // Returns the NCCL version as a "major.minor.patch" string (used for logging).
  // The components are taken straight from the NCCL_MAJOR/NCCL_MINOR/NCCL_PATCH header macros when
  // those are defined: a value packed as major*1000 + minor*100 + patch cannot be decoded reliably
  // once minor >= 10 (e.g. 2.10.3 packs to 3003, which decodes to "3.0.3").
  static std::string ncclVersionString() {
#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && defined(NCCL_PATCH)
    return std::to_string(NCCL_MAJOR) + "." + std::to_string(NCCL_MINOR) + "." + std::to_string(NCCL_PATCH);
#else
    // Fallback for headers without the version macros: decode the packed version code.
    int ncclVersion = 0;
    ncclGetVersion(&ncclVersion);
    return std::to_string(ncclVersion/1000) + "." + std::to_string((ncclVersion/100)%10) + "." + std::to_string(ncclVersion%100);
#endif
  }

  // Runs an MPI barrier; does nothing when no MPI wrapper is present.
  void mpiBarrier() const {
    if (!mpi_)
      return;
    mpi_->barrier();
  }

  // RAII helper that temporarily blocks one UNIX signal.
  // The constructor adds `signal` to the blocked-signal mask through the supplied mask function
  // (e.g. pthread_sigmask or sigprocmask) and remembers the mask that was in effect before;
  // the destructor re-installs that remembered mask.
  class BlockSignal {
    using SigMaskFn = std::function<void(int, const sigset_t*, sigset_t*)>;
    SigMaskFn sigMaskFn_; // function to set the mask, thread or proc
    sigset_t oldSigSet_;  // mask in effect before construction; re-installed by the destructor
  public:
    // signal:    number of the signal to block
    // sigMaskFn: function with the (how, newSet, oldSet) calling convention of sigprocmask()
    BlockSignal(int signal, const SigMaskFn& sigMaskFn) : sigMaskFn_(sigMaskFn) {
      sigset_t newSigSet;
      sigemptyset(&newSigSet);
      sigaddset(&newSigSet, signal); // block signal by setting it in the blocked-signal mask
      sigMaskFn_(SIG_BLOCK, &newSigSet, &oldSigSet_);
    }
    ~BlockSignal() {
      // SIG_SETMASK replaces the current mask by the saved one. (SIG_BLOCK would merely add the
      // saved set to the current mask, which leaves the signal blocked.)
      sigMaskFn_(SIG_SETMASK, &oldSigSet_, nullptr);
    }
    // Copies would restore the same mask twice, so copying is disabled.
    BlockSignal(const BlockSignal&) = delete;
    BlockSignal& operator=(const BlockSignal&) = delete;
  };

public:
  // A NCCLCommunicator is bound to a set of graphs, one per GPU device.
  // If MPI is used, then each MPI process has an instance of this class for its specific
  // set of GPU devices, which are communicating with each other. The total number of GPUs
  // involved in the NCCL communication setup is (#MPI processes) x (#GPUs per process).
  //  - graphs: the local graphs; aborts if any of them does not live on a GPU device
  //  - mpi:    MPI wrapper (may be null, in which case only the local devices communicate)
  NCCLCommunicator(const std::vector<Ptr<ExpressionGraph>>& graphs, Ptr<IMPIWrapper> mpi)
      : ICommunicator(graphs),
        comms_(graphs.size()),
        streams_(graphs.size()),
        devices_(graphs.size()),
        mpi_(mpi),
        threadPool_(graphs.size(), graphs.size()), threadResults_(graphs.size()) {
    mpiBarrier(); // barrier to group the multiple log messages from MPI processes
    LOG(info, "[comm] Using NCCL {} {}for GPU communication",
        ncclVersionString(),
        (mpi_ && mpi_->numMPIProcesses() > 1) ? "and MPI " : "");
    mpiBarrier(); // (synchronize the log messages)

    // set up our local devices: remember the device number and create one stream per device
    for(size_t i = 0; i < graphs_.size(); ++i) {
      auto device = graphs_[i]->getBackend()->getDeviceId();

      ABORT_IF(device.type != DeviceType::gpu,
               "NCCL communicator can only be used with GPUs");

      devices_[i] = device.no;
      CUDA_CHECK(cudaSetDevice(devices_[i]));
      CUDA_CHECK(cudaStreamCreate(&streams_[i]));
    }

    // Note: due to a bug in NCCL 2.3.5, NCCL's allocation of shared memory intermittently fails with
    //          Failed, NCCL error 2 'unhandled system error' - ncclGroupEnd()
    //          include/shm.h:26 NCCL WARN Unable to allocate shared memory (4263936 bytes) : Interrupted system call
    // This is caused by SIGPROF signals being raised, causing EINTR, which NCCL does not handle.
    // Reported as Issue #137 on the NCCL Github, and supposedly fixed for 2.3.7 (to be verified).
    // To work around, we disable the SIGPROF signal during NCCL initialization.
    // The symbolic constant is used instead of a hard-coded number, since signal numbers differ between platforms.
    const int signalToBlock = SIGPROF;
    BlockSignal blockThread(signalToBlock, pthread_sigmask); // Note: I don't know yet which of these two makes the difference.
    BlockSignal blockProc(signalToBlock, sigprocmask);       // So for now just block both.

    // set up NCCL
    // Since we want to use MPI, we cannot use NCCL's handy convenience function. Instead, we must go the laborious route.
    // cf. https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/index.html#multidevprothrd

    // generate NCCL unique ID at one process and broadcast to all
    ncclUniqueId uniqueId = { 0 };
    if (!mpi_ || mpi_->myMPIRank() == 0)
      NCCL_CHECK(ncclGetUniqueId(&uniqueId));

    if (mpi_) {
      static_assert(sizeof(uniqueId) == NCCL_UNIQUE_ID_BYTES, "wrong NCCL_UNIQUE_ID_BYTES??"); // (this value is used in NVidia examples)
      mpi_->bCast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0);
    }

    // create one communicator per local device, all inside a single NCCL group
    const int totalRanks = static_cast<int>(numNcclRanks());
    groupStart();
    for (size_t localDeviceIndex = 0; localDeviceIndex < devices_.size(); localDeviceIndex++) {
      CUDA_CHECK(cudaSetDevice(devices_[localDeviceIndex]));
      NCCL_CHECK(ncclCommInitRank(&comms_[localDeviceIndex], totalRanks, uniqueId, static_cast<int>(myNcclRank(localDeviceIndex))));
    }
    groupEnd();

    mpiBarrier(); // (synchronize the log messages)
    LOG(info, "[comm] NCCLCommunicator constructed successfully.");
    mpiBarrier(); // (synchronize the log messages)
  }

  // Destroys the stream and the NCCL communicator of every local device.
  // The status codes returned by the CUDA/NCCL calls are not checked here.
  ~NCCLCommunicator() override {
    const size_t numLocal = devices_.size();
    for(size_t dev = 0; dev < numLocal; ++dev) {
      cudaSetDevice(devices_[dev]);
      cudaStreamDestroy(streams_[dev]);
      ncclCommDestroy(comms_[dev]);
    }
  }

  // Calls func(localDeviceIndex, begin, end) once per local graph, where (begin, end) is the shard
  // range of that local device. If `parallel` is set and there is more than one local graph, the
  // calls are enqueued on the thread pool and this function waits for all of them before
  // returning; otherwise they run one after another on the calling thread.
  void foreach(const ForeachFunc& func, bool parallel = true) const override {
    const size_t numLocal = graphs_.size();
    const bool useThreads = parallel && numLocal > 1;

    for(size_t dev = 0; dev < numLocal; ++dev) {
      const auto range = localShardRange(dev);
      size_t begin = range.first;
      size_t end = range.second;
      if (!useThreads) {
        func(dev, begin, end);
        continue;
      }
      threadResults_[dev] = threadPool_.enqueue(func, dev, begin, end);
    }

    if (!useThreads)
      return;
    for(size_t dev = 0; dev < numLocal; ++dev)
      threadResults_[dev].wait();
  }

  // Sum-reduces the gradients over all NCCL ranks with ncclReduceScatter, writing the result for
  // each local device into that device's own shard of its gradient tensor, and afterwards sets
  // all gradient elements outside of that shard to zero.
  void scatterReduceAndResetGrads() const override {
    synchronizeAllOnNullStream();

    const size_t numLocal = graphs_.size();
    groupStart();
    for(size_t dev = 0; dev < numLocal; ++dev) {
      const auto range = localShardRange(dev);
      const size_t shardBegin = range.first;
      const size_t shardLength = range.second - range.first;

      auto gradients = graphs_[dev]->params()->grads();
      const auto* source = gradients->data();                                     // the full gradient vector
      auto*       target = gradients->subtensor(shardBegin, shardLength)->data(); // this device's shard, reduced in place
      size_t      count  = shardSize();
      ABORT_IF(gradients->subtensor(shardBegin, shardLength)->size() != count, "unexpected subtensor size??");

      NCCL_CHECK(ncclReduceScatter(source, target, count, ncclFloat, ncclSum, comms_[dev], streams_[dev]));
    }
    groupEnd();
    synchronizeAll();

    // reset gradients
    // In the future, we can keep quantization residuals here straight in the grads themselves.
    foreach([&](size_t dev, size_t begin, size_t end) {
      auto gradients = graphs_[dev]->params()->grads();
      auto total = gradients->size();
      // zero the part in front of the shard, then the part behind it
      if (begin > 0)
        gradients->subtensor(0, begin)->set(0.f);
      if (end < total)
        gradients->subtensor(end, total - end)->set(0.f);
    });
  }

  // Distributes the parameter shard held by each device to all devices with ncclAllGather:
  // the send buffer is the device's own shard, the receive buffer is its full parameter tensor.
  // @TODO: For unknown reasons, this takes longer than any other operation incl. scatterReduceAndResetGrads().
  //        But both should have the same number of data transfers of the same size.
  void allGatherParams() const override {
    synchronizeAllOnNullStream();

    const size_t numLocal = graphs_.size();
    groupStart();
    for(size_t dev = 0; dev < numLocal; ++dev) {
      const auto range = localShardRange(dev);
      const size_t shardBegin = range.first;
      const size_t shardLength = range.second - range.first;

      auto values = graphs_[dev]->params()->vals();
      const auto* source = values->subtensor(shardBegin, shardLength)->data();
      void*       target = values->data();
      size_t      count  = shardSize();

      NCCL_CHECK(ncclAllGather(source, target, count, ncclFloat, comms_[dev], streams_[dev]));
    }
    groupEnd();
    synchronizeAll();
  }

  // Exchanges the contents of the sharded tensors in distributedParamShards with the model params().
  // It is assumed that all model params() on all devices and MPI processes are identical.
  // This is used for the smoothed parameters.
  void swapParams(const std::vector<Tensor>& distributedParamShards) const override {
    // Collect all shards (local and remote) into one CPU-side vector.
    // Afterwards all MPI processes hold an identical copy of the concatenated shards.
    auto fetchShard = [&](size_t localDeviceIndex) {
      std::vector<float> values;
      distributedParamShards[localDeviceIndex]->get(values);
      return values;
    };
    auto shardedValues = gatherState(fetchShard);

    // CPU-side copy of params(), read from the first local graph
    // (remember, we assumed all devices hold the same params()).
    std::vector<float> modelValues;
    graphs_[0]->params()->vals()->get(modelValues);
    ABORT_IF(shardedValues.size() != modelValues.size(), "distributed sharded and local params have different size??");

    // The former model parameters are sliced up and written into the shards ...
    auto storeShard = [&](size_t localDeviceIndex,
                          std::vector<float>::const_iterator begin,
                          std::vector<float>::const_iterator end) {
      ABORT_IF(distributedParamShards[localDeviceIndex]->size() != end-begin, "swapParams size mismatch??"); // @TODO: move check to set()
      distributedParamShards[localDeviceIndex]->set(std::vector<float>(begin, end)); // @TODO: directly pass iterators to set()
    };
    scatterState(modelValues, storeShard);

    // ... and the former shard contents become the parameters of every local graph.
    for (auto& graph : graphs_)
      graph->params()->vals()->set(shardedValues);
  }

  // Distribute a single CPU-side vector to shards across multiple devices and MPI processes.
  // This is used when restoring optimizer state, which is sharded, and as part of swapParams().
  // It is assumed that all MPI processes get the same data passed. Hence, no MPI transfers are needed here.
  //  - data:  the complete vector; it is cut into numNcclRanks() slices of ceil(size / numNcclRanks()) elements
  //  - setFn: called once per local device as setFn(localDeviceIndex, sliceBegin, sliceEnd)
  // Slices that would start beyond the end of data are passed as empty ranges.
  void scatterState(const std::vector<float>& data, const OptimizerBase::ScatterStateSetFunc& setFn) const override {
    const size_t dataSize = data.size();
    const size_t numShards = numNcclRanks();
    const size_t shardSize = (dataSize + numShards - 1) / numShards;
    for(size_t localDeviceIndex = 0; localDeviceIndex < graphs_.size(); localDeviceIndex++) {
      // We only slice out data that is kept in our MPI process. Remember that all MPI processes receive the same, complete data.
      const size_t ncclRank = myNcclRank(localDeviceIndex);
      // Clamp both ends: when data is shorter than (ncclRank * shardSize), an unclamped begin would
      // form an iterator beyond data.end() (undefined behavior) and a range with begin > end.
      const size_t begin = std::min(ncclRank * shardSize, dataSize);
      const size_t end   = std::min(begin + shardSize, dataSize);
      setFn(localDeviceIndex, data.begin() + begin, data.begin() + end);
    }
  }

  // Collect shards across multiple devices and MPI processes in the NCCL configuration into a single CPU-side vector.
  // This is used when persisting optimizer state, which is sharded, and as part of swapParams().
  // getFn(localDeviceIndex) must return the shard of the given local device as a vector of floats.
  std::vector<float> gatherState(const OptimizerBase::GatherStateGetFunc& getFn) const override {
    std::vector<float> scratch; // (temp buffer used multiple times)

    // Step 1: concatenate the shards of all local devices, in device order.
    std::vector<float> localData;
    for(size_t dev = 0; dev < graphs_.size(); ++dev) {
      scratch = getFn(dev);
      localData.insert(localData.end(), scratch.begin(), scratch.end());
    }

    // no MPI: localData is the complete result already
    if (!mpi_)
      return localData;

    // Step 2: concatenate across MPI processes, pushing one rank's data at a time using broadcast.
    // Note that all local devices occupy consecutive ncclRanks in order.
    std::vector<float> gathered;
    for(size_t root = 0; root < mpi_->numMPIProcesses(); ++root) {
      // the root contributes its own localData; all others receive it into the scratch buffer
      if(root == mpi_->myMPIRank())
        scratch = localData;
      mpi_->bCast(scratch, /*rootRank=*/root);
      // now all ranks have the same slice: concatenate (we will end up with the same on all MPI processes)
      gathered.insert(gathered.end(), scratch.begin(), scratch.end());
    }
    return gathered;
  }
};

}  // namespace marian