Welcome to mirror list, hosted at ThFree Co, Russian Federation.

graph_group_async.cpp « training « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f85f9cf8500488dfa5e97c712f3167a11e8ead33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#include "training/graph_group_async.h"
#include "data/corpus_base.h"
#include "functional/functional.h"
#include "tensors/tensor_operators.h"

namespace marian {

// Constructs the asynchronous graph group.
// - shardSync_ gets one mutex per device (used by fetchParams/pushGradients to
//   guard the corresponding parameter/gradient shard).
// - pool_ is a ThreadPool constructed with devices_.size() for both of its
//   size arguments; execute() enqueues one training task per batch on it.
// Aborts when run with more than one MPI process, or when --optimizer-delay is
// smaller than 1 or not a whole number.
AsyncGraphGroup::AsyncGraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
    : GraphGroup(options, mpi),
      shardSync_(devices_.size()),
      optimizerDelay_((size_t)options_->get<double>("optimizer-delay")) {
  ABORT_IF(mpi->numMPIProcesses() != 1, "AsyncGraphGroup presently does not support multiple MPI processes");
  // optimizerDelay_ is used as a modulus (t % optimizerDelay_) in execute(), so a
  // value of 0 would be an integer division by zero. Check the raw option value
  // first so that negative inputs get this message instead of the fractional one.
  ABORT_IF(options_->get<double>("optimizer-delay") < 1, "AsyncGraphGroup requires --optimizer-delay to be at least 1");
  ABORT_IF((double)optimizerDelay_ != options_->get<double>("optimizer-delay"), "AsyncGraphGroup presently does not implement fractional values for --optimizer-delay");
  pool_.reset(new ThreadPool(devices_.size(), devices_.size()));
}

// Attaches the training scheduler and subscribes observers to it.
// Registration order matters: the scheduler observes itself first and the
// optimizer shards are registered last, so they see changes of the learning rate.
void AsyncGraphGroup::setScheduler(Ptr<Scheduler> scheduler) {
  scheduler_ = scheduler;
  scheduler_->registerTrainingObserver(scheduler_);
  for(const auto& optimizerShard : optimizerShards_)
    scheduler_->registerTrainingObserver(optimizerShard);
}

// Gathers the sharded parameters into `oldParams`, one thread per shard.
// Shard k is copied into the slice of `oldParams` that starts at k * shardSize_
// and spans params[k]->size() elements. Each copy holds shardSync_[k], so it
// cannot interleave with an optimizer update of the same shard. All threads are
// joined before returning, which keeps the by-reference captures valid.
void AsyncGraphGroup::fetchParams(Tensor oldParams,
                                  const std::vector<Tensor>& params,
                                  int /*device_id*/) {
  // @TODO read guard on parameters
  auto copyShard = [&](int shard, int offset) {
    std::lock_guard<std::mutex> shardLock(shardSync_[shard]);
    auto slice = oldParams->subtensor(offset, (int)params[shard]->size());
    slice->copyFrom(params[shard]);
  };

  std::vector<std::thread> workers;
  for(int shard = 0, offset = 0; shard < devices_.size(); ++shard, offset += shardSize_)
    workers.emplace_back(copyShard, shard, offset);

  for(auto& worker : workers)
    worker.join();
}

// Scatters `newGrads` into the per-shard gradient buffers and applies each
// shard's optimizer, one thread per shard.
// Shard k receives the slice of `newGrads` starting at k * shardSize_ with
// grads_[k]->size() elements. Then optimizerShards_[k] updates params_[k] with
// `mbSize` as the mini-batch size. Both steps run under shardSync_[k]. All
// threads are joined before returning.
void AsyncGraphGroup::pushGradients(Tensor newGrads,
                                    int /*device_id*/,
                                    size_t mbSize) {
  auto updateShard = [&](int shard, int offset) {
    std::lock_guard<std::mutex> shardLock(shardSync_[shard]);
    auto slice = newGrads->subtensor(offset, (int)grads_[shard]->size());
    grads_[shard]->copyFrom(slice);
    optimizerShards_[shard]->update(params_[shard], grads_[shard], mbSize);
  };

  std::vector<std::thread> workers;
  for(int shard = 0, offset = 0; shard < devices_.size(); ++shard, offset += shardSize_)
    workers.emplace_back(updateShard, shard, offset);

  for(auto& worker : workers)
    worker.join();
}

// One-time initialization from the first batch.
// 1. Builds every model graph on `batch` and runs a forward pass, in parallel.
// 2. If not done yet, splits the parameters of graphs_[0] into one contiguous
//    shard per graph (each of at most shardSize_ elements), allocates each
//    shard on that graph's backend, and copies the initial values into it.
// 3. If not done yet, allocates a zero-filled gradient shard of matching size
//    per graph.
// 4. Calls every optimizer shard once with the (zero) gradients.
void AsyncGraphGroup::init(Ptr<data::Batch> batch) {
  // initialize the parameters
  {
    // The inner scope bounds the pool's lifetime. The lambdas capture `batch` by
    // reference, so this relies on the pool finishing its tasks before it is
    // destroyed at the closing brace.
    ThreadPool pool(graphs_.size(), graphs_.size());
    for(size_t i = 0; i < graphs_.size(); ++i) {
      auto init = [&](size_t i) {
        models_[i]->build(graphs_[i], batch);
        graphs_[i]->forward();
      };
      pool.enqueue(init, i);
    }
  }

  if(params_.empty()) {
    int totalSize = (int)graphs_[0]->params()->vals()->size();
    // Exact integer ceiling division. Going through float (24-bit significand)
    // rounds large parameter counts and could produce a shard size so small that
    // shardSize_ * numShards < totalSize, leaving trailing parameters unsharded.
    int numShards = (int)devices_.size();
    shardSize_ = (totalSize + numShards - 1) / numShards;

    int pos = 0;
    // parameter sharding
    for(auto graph : graphs_) {
      int shardElems = std::min(shardSize_, totalSize);
      totalSize -= shardElems;

      Tensor param;
      Ptr<TensorAllocator> allocator
          = New<TensorAllocator>(graph->getBackend());
      allocator->reserveExact(shardElems * sizeOf(graph->getDefaultElementType()));
      allocator->allocate(param, {1, shardElems}, graph->getDefaultElementType());
      paramsAlloc_.push_back(allocator);

      param->copyFrom(graphs_[0]->params()->vals()->subtensor(pos, shardElems));
      params_.push_back(param);

      pos += shardElems;
    }
  }
  if(grads_.empty()) {
    int totalSize = (int)graphs_[0]->params()->vals()->size();

    // Gradient shards mirror the parameter shard sizes computed above.
    for(auto graph : graphs_) {
      int shardElems = std::min(shardSize_, totalSize);
      totalSize -= shardElems;
      Tensor grad;
      Ptr<TensorAllocator> allocator
          = New<TensorAllocator>(graph->getBackend());

      allocator->reserveExact(shardElems * sizeOf(graph->getDefaultElementType()));
      allocator->allocate(grad, {1, shardElems}, graph->getDefaultElementType());
      grad->set(0.f);

      gradsAlloc_.push_back(allocator);
      grads_.push_back(grad);
    }
  }

  // Initialize optimizers with empty gradient
  for(size_t i = 0; i < params_.size(); ++i)
    optimizerShards_[i]->update(params_[i], grads_[i], batch->wordsTrg());
}

// Schedules one asynchronous training step on `batch`.
// On the first call, init() builds the graphs and allocates the parameter and
// gradient shards. The step itself runs as a task on pool_.
void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
  if(first_) {
    init(batch);
    first_ = false;
  }

  // Per-thread training step. Each pool thread gets a stable id `tid` the first
  // time it runs, and works on graphs_[tid] / models_[tid]. Its step counter,
  // loss, word/sentence counters and (for --optimizer-delay > 1) gradient
  // accumulation buffer are thread_local.
  auto task = [this](Ptr<data::Batch> batch) {
    // assign thread id safely via atomic increment
    static std::atomic<int> threadCount{0};
    thread_local int tid = -1;
    if(tid == -1)
      tid = threadCount++;

    thread_local size_t t = 0;                   // batches processed by this thread
    thread_local size_t num_seen_words = 0;      // words since the last gradient push
    thread_local size_t num_seen_sentences = 0;  // sentences since the last gradient push
    thread_local StaticLoss loss;

    thread_local Tensor accGradients;
    thread_local Ptr<TensorAllocator> accAlloc;

    ABORT_IF(costScaling_ ,"Cost-scaling not implemented for AsyncSGD");

    auto graph = graphs_[tid];
    Ptr<RationalLoss> dynamicLoss = models_[tid]->build(graph, batch);
    if(costScalingFactor_ != 1.f) {
      // it's ok to go out of scope, this will still insert the new top node into the graph
      auto costNode = dynamicLoss->loss() * costScalingFactor_;
    }

    // At the start of each accumulation window, pull the current shard values
    // into this thread's graph.
    if(t % optimizerDelay_ == 0) {
      fetchParams(graph->params()->vals(), params_, tid);
    }

    graph->forward();
    loss += *dynamicLoss; // does not add scaledLoss but original loss
    graph->backward();

    // Keep track of how many words/sentences the pending gradient was computed
    // from. The word count is handed to the optimizer as the mini-batch size, so
    // it is counted for every optimizer-delay value, including 1.
    // NOTE(review): this counts batch->words() while init() passes
    // batch->wordsTrg() to the optimizer — confirm which side is intended.
    num_seen_words += batch->words();
    num_seen_sentences += batch->size();

    Tensor gradients;
    if(optimizerDelay_ > 1) {
      if(t == 0) {
        accAlloc = New<TensorAllocator>(graph->getBackend());
        accAlloc->reserveExact(graph->params()->grads()->memory()->size());
        accAlloc->allocate(accGradients, graph->params()->grads()->shape(), graph->getDefaultElementType());
        accGradients->set(0);
      }

      using namespace functional;
      Element(_1 += _2, accGradients, graph->params()->grads());
      gradients = accGradients;
    } else {
      gradients = graph->params()->grads();
    }

    t++;

    if(t % optimizerDelay_ == 0) {
      pushGradients(gradients, tid, num_seen_words);
      // Clear the accumulation buffer after the gradient update
      if(optimizerDelay_ > 1)
        gradients->set(0);

      if(scheduler_) {
        std::unique_lock<std::mutex> lock(schedulerMutex_);

        // Wait until the thread that wants to do validation is finished.
        pool_->wait_for_one(lock);

        if(optimizerDelay_ > 1) {
          std::vector<size_t> fakeLength = {1, 1};
          std::vector<Ptr<Vocab>> vocabs;
          auto fb = data::CorpusBatch::fakeBatch(fakeLength, vocabs, num_seen_sentences, NULL);
          fb->front()->setWords(num_seen_words);

          scheduler_->update(loss, fb);
        } else {
          scheduler_->update(loss, batch);
        }

        loss.reset();

        if(scheduler_->saving() || scheduler_->validating()) {
          // Wait with validation or saving until all other threads are done with
          // update.
          // We want to reuse the graphs for validation, so they need to be in
          // a safe state.
          pool_->wait_for_others(lock);

          LOG(info, "TODO: implement exponential smoothing!");

          if(scheduler_->validating())
            scheduler_->validate(graphs_);

          if(scheduler_->saving())
            save(); // since we have waited above we can just call the generic save

          // Validation or saving is done, tell other threads to continue work.
          pool_->notify_others();
        }
      }

      // Start a fresh window. This happens after every push, whether or not a
      // scheduler is attached, so the counters never grow across updates.
      num_seen_words = 0;
      num_seen_sentences = 0;
    }
  };

  pool_->enqueue(task, batch);
}

// Shuts down training: joins the worker pool, releases it, and marks this
// graph group as finalized.
void AsyncGraphGroup::finalize() {
  // The pool has to be joined before it is destroyed.
  pool_->join_all();
  pool_ = nullptr;
  finalized_ = true;
}

}  // namespace marian