Welcome to mirror list, hosted at ThFree Co, Russian Federation.

communicator.cpp « training « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0ee6ebd6b5263e248a5dbf104e81228cd005e6e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#include "training/communicator.h"

#if defined(CUDA_FOUND) && defined(USE_NCCL)
#include "training/communicator.cu"
//#include "training/communicator_nccl.h" // @TODO: rename to this
#endif

#if MPI_FOUND
#include "mpi.h"
#endif

namespace marian {

/**
 * Creates the communicator used to exchange data between the given per-device graphs.
 *
 * @param graphs  one expression graph per local device
 * @param noNccl  if true, never use the NCCL communicator (only meaningful when built with CUDA + NCCL)
 * @param mpi     MPI wrapper forwarded to the chosen communicator
 * @return a NCCLCommunicator when compiled with CUDA and NCCL, not overridden, and all devices are
 *         GPUs; otherwise a DefaultCommunicator.
 */
Ptr<ICommunicator> createCommunicator(
  const std::vector<Ptr<ExpressionGraph>>& graphs,
  bool noNccl, Ptr<IMPIWrapper> mpi) {
#if defined(CUDA_FOUND) && defined(USE_NCCL)
  // An explicit user request (--no-nccl) takes precedence over the NCCL path.
  if(noNccl) {
    LOG(warn, "[comm] NCCL communicator overridden");
    return New<DefaultCommunicator>(graphs, mpi);
  }

  // if at least one of the devices is not a gpu, fall-back to default
  for(auto& graph : graphs) {
    if(graph->getBackend()->getDeviceId().type == DeviceType::cpu) {
      return New<DefaultCommunicator>(graphs, mpi);
    }
  }

  // (d & (d - 1)) is non-zero exactly when d is not a power of two (d == 0 yields zero).
  size_t d = graphs.size();
  if((d & (d - 1)) != 0) {
    LOG(warn,
        "[comm] Number of devices {} is not a power of 2 and communication "
        "might be slow with NCCL",
        d);
    LOG(warn, "[comm] You can switch off NCCL with --no-nccl option");
  }

  // the actual implementation is inside communicator.cu
  return New<NCCLCommunicator>(graphs, mpi);
#else // no CUDA or no NCCL
  (void)noNccl; // unused in this configuration; the cast silences unused-parameter warnings portably
  return New<DefaultCommunicator>(graphs, mpi);
#endif
}

#if MPI_FOUND
// wrapper for MPI calls
// Since MPI can only be initialized once, only one instance of this class can exist.
class MPIWrapper : public IMPIWrapper
{
  int my_rank_{-1};        // MPI rank of this node; -1 until MPI_Comm_rank() has succeeded
  int comm_world_size_{0}; // Number of nodes in MPI world (cluster)

  // Call this with the return value of every MPI call. On failure it aborts with the MPI error
  // text, this node's rank (-1 if not yet known), and the text of the failed expression.
  void handleError(int mpiRetval, const char* exprString) const {
    if (mpiRetval != MPI_SUCCESS) {
      char errStr[MPI_MAX_ERROR_STRING + 1] = { 0 };
      int resultLen = 0;
      MPI_Error_string(mpiRetval, &errStr[0], &resultLen);
      errStr[resultLen] = 0; // (@TODO: needed?)
      ABORT("MPI call failed with code {} '{}' on node {}: {}", mpiRetval, errStr, my_rank_, exprString); // @TODO: also log host name, which is involved on Windows
    }
  }
#define HANDLE_MPI_ERROR(expr) (handleError(expr, #expr)) // call through a macro so we can also log the failed expression itself

public:
  /**
   * Initializes MPI and caches this process's rank and the world size.
   *
   * @param multiThreaded  if true, MPI_THREAD_MULTIPLE support is required (abort otherwise);
   *                       if false, MPI_THREAD_SINGLE is the minimum accepted level.
   */
  MPIWrapper(bool multiThreaded) {
    int requiredThreadingMode = multiThreaded ? MPI_THREAD_MULTIPLE : MPI_THREAD_SINGLE;

    // dummy argc/argv since MPI_Init needs something here; argv is NULL-terminated like main()'s
    int argc = 1;
    char* argv[] = { const_cast<char*>("this.exe"), nullptr };
    char** argvp = argv;
    int providedThreadingMode = MPI_THREAD_SINGLE;
    // We always request the highest level (MPI_THREAD_MULTIPLE) and then verify below that what
    // the implementation actually provides satisfies requiredThreadingMode.
    HANDLE_MPI_ERROR(MPI_Init_thread(&argc, &argvp, MPI_THREAD_MULTIPLE, &providedThreadingMode));
    HANDLE_MPI_ERROR(MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN)); // have errors reported as return codes

    ABORT_IF(
      providedThreadingMode < requiredThreadingMode,
      "Your version of MPI does not support multi-threaded communication.");

    HANDLE_MPI_ERROR(MPI_Comm_size(MPI_COMM_WORLD, &comm_world_size_));
    HANDLE_MPI_ERROR(MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_));

    // patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
    if (numMPIProcesses() > 1)
      switchtoMultinodeLogging(std::to_string(MPIWrapper::myMPIRank()));

    // Each rank logs in turn, in rank order, separated by barriers. This also exercises the
    // communicator early: a failing barrier aborts via HANDLE_MPI_ERROR.
    for (size_t r = 0; r < numMPIProcesses(); r++) {
      MPIWrapper::barrier();
      if (r == MPIWrapper::myMPIRank())
        LOG(info, "[mpi] initialized {} processes", MPIWrapper::numMPIProcesses());
      MPIWrapper::barrier();
    }
  }

  virtual size_t myMPIRank()        const override { return (size_t)my_rank_; };
  virtual size_t numMPIProcesses() const override { return (size_t)comm_world_size_; };

  // Thin wrappers around the corresponding MPI calls; every return code goes through handleError().
  virtual void barrier(MPI_Comm comm = MPI_COMM_WORLD) const override {
    HANDLE_MPI_ERROR(MPI_Barrier(comm));
  }
  virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm = MPI_COMM_WORLD) const override {
    HANDLE_MPI_ERROR(MPI_Bcast(buf, (int)count, datatype, (int)rootRank, comm));
  }
  virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
    HANDLE_MPI_ERROR(MPI_Ssend(buf, (int)count, datatype, (int)destRank, tag, comm));
  }
  virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
    HANDLE_MPI_ERROR(MPI_Recv(buf, (int)count, datatype, (int)sourceRank, tag, comm, status));
  }
  virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
    HANDLE_MPI_ERROR(MPI_Allreduce(sendbuf, recvbuf, (int)count, datatype, op, comm));
  }
  virtual void finalize() override {
    HANDLE_MPI_ERROR(MPI_Finalize());
  }
};
#endif

// dummy MPI wrapper that implements only one process without actual operations
// This is used when not compiling under MPI.
// @TODO: Complete this.
class FakeMPIWrapper : public IMPIWrapper
{
public:
  explicit FakeMPIWrapper(bool /*multiThreaded*/) {
    LOG(warn, "compiled without MPI support; using FakeMPIWrapper to allow debugging");
  }

  virtual size_t myMPIRank() const override { return 0; };
  virtual size_t numMPIProcesses() const override { return 1; };

  // Most functions are no-ops when applied to a single process.
  // Unused parameter names are commented out rather than silencing the unused-parameter
  // warning with a compiler-specific #pragma (which previously was never popped).
  virtual void barrier(MPI_Comm /*comm*/) const override {
  }
  virtual void bCast(void* /*buf*/, size_t /*count*/, MPI_Datatype /*datatype*/, size_t /*rootRank*/, MPI_Comm /*comm*/) const override {
  }
  virtual void sSend(void* /*buf*/, size_t /*count*/, MPI_Datatype /*datatype*/, size_t /*destRank*/, int /*tag*/, MPI_Comm /*comm*/) const override {
  }
  virtual void recv(void* /*buf*/, size_t /*count*/, MPI_Datatype /*datatype*/, size_t /*sourceRank*/, int /*tag*/, MPI_Comm /*comm*/, MPI_Status* status) const override {
    // @TODO: fill in status
    ABORT_IF(status != MPI_STATUS_IGNORE, "fake recv not implemented when passing a status");
  }
  virtual void allReduce(const void* sendbuf, void* recvbuf, size_t /*count*/, MPI_Datatype /*datatype*/, MPI_Op /*op*/, MPI_Comm /*comm*/) const override {
    // In-place reduction over a single process leaves the buffer unchanged.
    ABORT_IF(sendbuf != recvbuf, "fake allReduce only implemented for in-place operation"); // otherwise it's not a no-op, we must copy data
  }
  virtual void finalize() override { }
};

// Process-wide bookkeeping for the MPI wrapper singleton managed by initMPI()/finalizeMPI().
static Ptr<IMPIWrapper> s_mpi;            // the shared wrapper instance, or null while none exists
static size_t s_mpiUseCount = 0;          // initMPI() calls not yet balanced by a finalizeMPI() call
static bool s_mpiIsMultiThreaded = false; // threading mode the current instance was created with

/**
 * Returns the process-wide MPI wrapper, creating it on first use.
 * Each call must be balanced by one finalizeMPI() call.
 *
 * @param multiThreaded  requested threading mode; all callers must pass the same value
 */
Ptr<IMPIWrapper> initMPI(bool multiThreaded) {
  if (s_mpi) {
    // An instance already exists; it cannot be re-created, so the threading mode must match.
    ABORT_IF(s_mpiIsMultiThreaded != multiThreaded, "attempted to reinitialize MPI with different multi-threading mode");
  }
  else {
    // First user: construct the singleton (real MPI if available, otherwise the single-process fake).
    // @TODO: This will be extended in the future to create other types, e.g. NCCL and fake for debugging
#if MPI_FOUND
    s_mpi = New<MPIWrapper>(multiThreaded);
#else
    s_mpi = New<FakeMPIWrapper>(multiThreaded);
#endif
    s_mpiIsMultiThreaded = multiThreaded;
  }
  ++s_mpiUseCount;
  return s_mpi;
}

/**
 * Releases one reference obtained from initMPI(). The caller's handle is reset to null.
 * The call that balances the last outstanding initMPI() finalizes MPI and destroys the singleton.
 *
 * @param mpi  the handle previously returned by initMPI(); must be the current singleton
 */
void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
  ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
  mpi.reset(); // drop the caller's handle so only the singleton keeps the instance alive
  ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");

  const bool isLastUser = (s_mpiUseCount == 1);
  if (isLastUser) {
    // Only the singleton itself may still own the wrapper at this point.
    ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
    s_mpi->finalize(); // tell MPI that the computation completed successfully
    s_mpi.reset();     // release the singleton instance upon last finalization
  }
  --s_mpiUseCount;
}

}  // namespace marian