Welcome to mirror list, hosted at ThFree Co, Russian Federation.

scheduler.h « training « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 65bb88207e9d812164399a348b58dae2edb8b7a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
#pragma once

#include "common/options.h"
#include "training/training_state.h"
#include "training/validator.h"
#include "training/communicator.h"
#include "layers/loss.h"

namespace marian {

class Scheduler : public TrainingObserver {
private:
  // Option store; every scheduling decision in this class is read from it by option name.
  Ptr<Options> options_;
  // Training state that this scheduler updates (counters, learning rate, validator results).
  Ptr<TrainingState> state_;
  // Validators in the order they were added; the first one is the one consulted by stalled().
  std::vector<Ptr<ValidatorBase>> validators_;

  // True until the first call to actAfterBatches() has finished; gates --lr-warmup-at-reload.
  bool first_{true};

  // timer_ is restarted after every progress display; heartBeatTimer_ after every heartbeat print.
  timer::Timer timer_, heartBeatTimer_;

  // determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
  // The option holds one or two scheduling values: the decay constant and, optionally,
  // a separate starting point (which must use the same unit). Returns 1 until progress
  // exceeds the starting point, then sqrt(n / shifted progress).
  float getScheduledLRDecayFactor(const TrainingState& state) const {
    const auto params = options_->get<std::vector<std::string>>("lr-decay-inv-sqrt");
    ABORT_IF(params.empty() || params.size() > 2, "--lr-decay-inv-sqrt argument must be one or two numbers with units");

    auto decay = SchedulingParameter::parse(params[0]);
    const size_t current = state.getProgressIn(decay.unit);

    // Decay begins at the decay constant itself unless a second value overrides it.
    size_t decayBegin = decay.n;
    if (params.size() >= 2) {
      auto explicitBegin = SchedulingParameter::parse(params[1]);
      ABORT_IF(explicitBegin && explicitBegin.unit != decay.unit, "both --lr-decay-inv-sqrt arguments must have the same unit");
      decayBegin = explicitBegin.n;
    }

    // No decay configured, or not yet past the starting point.
    if (!decay || !(current > decayBegin))
      return 1.f;

    // Shift progress so that the factor would be exactly 1 at current == decayBegin.
    const size_t shifted = current - decayBegin + decay.n;
    return (float)(std::sqrt((double)decay.n / (double)shifted));
  }

  // update current learning rate in state.eta
  // This considers
  //  - base LR (--learn-rate)
  //  - LR warm-up (--lr-warmup, --lr-warmup-start-rate)
  //  - scheduled LR decay (--lr-decay-inv-sqrt)
  //  - state-based LR decay (--lr-decay, --lr-decay-strategy)
  // The rate handed to state.updateEta() is the warm-up interpolated base rate
  // multiplied by the scheduled decay factor; updateEta() applies the state-based part.
  void updateLearningRate(TrainingState& state) const {
    const float targetLR = options_->get<float>("learn-rate");

    // Fraction of the warm-up period completed, capped at 1 (1 if warm-up is disabled).
    float warmupProgress = 1.f;
    auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
    if(warmup) {
      ABORT_IF(state.warmupStart && state.warmupStart.unit != warmup.unit, "lr-warmup and warmup-start must have the same unit");
      auto sinceWarmupStart = state.getProgressIn(warmup.unit) - state.warmupStart.n;
      warmupProgress = std::min(1.f, (float)sinceWarmupStart / (float)warmup.n);
    }

    // Linear interpolation from lr-warmup-start-rate towards learn-rate.
    const float startLR = options_->get<float>("lr-warmup-start-rate");
    const float warmedUpLR = startLR + (targetLR - startLR) * warmupProgress;

    // Schedule-based decay (--lr-decay-inv-sqrt).
    const float scheduledLR = warmedUpLR * getScheduledLRDecayFactor(state);

    // Let the state factor in its own decay and store the result as state.eta.
    state.updateEta(scheduledLR);
  }

  // Builds the "Cost ..." fragment of the progress line from the display accumulators
  // in `state`, according to `lossType`. Aborts on an unknown loss type.
  // `batchLabels` is only shown for "ce-sum" with label counts, and only when non-zero.
  std::string formatLoss(std::string lossType,
                         bool dispLabelCounts,
                         size_t batchLabels,
                         Ptr<TrainingState> state) {
    std::stringstream out;
    out << "Cost " << std::setprecision(8) << std::fixed;

    // @TODO: put a single loss formatting function into loss.h and reuse here to avoid code duplication
    // @TODO: use dispLabelCounts with any display type?
    // @TODO: bugbug cost-type ce-mean-words with multi-loss-type mean divides too much in display
    const bool isCeSum = (lossType == "ce-sum");
    if(lossType == "ce-mean-words") {
      out << state->costSum / state->costCount;
    } else if(isCeSum) {
      if(dispLabelCounts) {
        // "cost per label * number of labels [@ batch labels] after total labels"
        out << state->costSum / state->costCount
            << " * " << utils::withCommas(state->costCount);
        if(batchLabels > 0)
          out << " @ " << utils::withCommas(batchLabels);
        out << " after " << utils::withCommas(state->labelsTotal);
      } else {
        out << state->costSum / state->updatesDisp; // average over batches
      }
    } else if(lossType == "perplexity") {
      out << std::exp(state->costSum / state->costCount);
    } else if(lossType == "cross-entropy" || lossType == "ce-mean") { // backwards-compat, @TODO: get rid of this?
      out << state->costSum / state->samplesDisp;
    } else {
      ABORT("Unknown loss type {}", lossType);
    }

    return out.str();
  }

public:
  // test if any parameters specify dynamic MB scaling
  // True if --mini-batch-warmup parses to a non-empty value or --mini-batch-track-lr is set.
  bool isDynamicMBSizeScaling() const {
    auto warmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
    const bool tracksLR = options_->get<bool>("mini-batch-track-lr");
    if(warmup)
      return true;
    return tracksLR;
  }

  // determine dynamic MB scaling factor
  double getDynamicMBSizeMultiplier() const {
    double ratio = 1.0;

    auto mbWarmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
    if (mbWarmup) {
      // mini-batch-warmup
      LOG_ONCE(info, "[scheduler] Mini-batch size warmup {}", std::string(mbWarmup));
      // This ramps up MB size at start, relative to progress within warm-up period.
      size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
      auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
      // if unit is labels, then account for the fact that our increment itself is not constant
      if (mbWarmup.unit == SchedulingUnit::trgLabels)
        progressRatio = std::sqrt(progressRatio);
      if (progressRatio < 1)
        ratio *= progressRatio;
    }

    // dynamic MB-size tracking with learning rate
    // As LR goes down, MB gets ramped up by the same ratio, which has been found to be safe.
    auto mbTracking = options_->get<bool>("mini-batch-track-lr");
    if (mbTracking) {
      auto lrFactor = getScheduledLRDecayFactor(*state_) * state_->factor; // (don't include lr-warmup)
      if (lrFactor != 1)
        LOG_ONCE(info, "[scheduler] Dynamic mini-batch size adjustment enabled and kicking in");
      ratio /= lrFactor;
    }
    return ratio;
  }

  // Stores the option and state handles, checks that the state's decay factor is
  // still 1, and computes the initial learning rate into the state.
  Scheduler(Ptr<Options> options, Ptr<TrainingState> state)
      : options_(options),
        state_(state) {
    ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??");

    // state_ was just initialized from `state`, so both refer to the same object.
    updateLearningRate(*state_);
  }

  // Returns false as soon as one of the configured stopping criteria is met
  // (--after-epochs, --after-batches, --early-stopping); a value of 0 disables a criterion.
  bool keepGoing() {
    // Maximum number of epochs.
    const size_t maxEpochs = options_->get<size_t>("after-epochs");
    const bool epochsExhausted = maxEpochs > 0 && state_->epochs > maxEpochs;
    if(epochsExhausted)
      return false;

    // Maximum number of batch updates.
    const size_t maxBatches = options_->get<size_t>("after-batches");
    const bool batchesExhausted = maxBatches > 0 && state_->batches >= maxBatches;
    if(batchesExhausted)
      return false;

    // Early stopping: the first validator did not improve for the given number of checks.
    const size_t patience = options_->get<size_t>("early-stopping");
    if(patience == 0 || validators_.empty())
      return true;

    return stalled() < patience;
  }

  // Logs the number of samples seen in the epoch that just ended, advances the
  // training state to the next epoch, and logs the new epoch number.
  void increaseEpoch() {
    auto& trainingState = *state_;

    LOG(info, "Seen {} samples", trainingState.samplesEpoch);
    trainingState.newEpoch();
    LOG(info, "Starting epoch {}", trainingState.epochs);
  }

  // Logs that training has started.
  void started() {
    LOG(info, "Training started");
  }
  // Logs that training has finished.
  void finished() {
    LOG(info, "Training finished");
  }

  // Appends a validator, registers it as a training observer, and (unless the state
  // was loaded) seeds its entries in the training state. The first validator added
  // also has its type recorded as state_->validator.
  void addValidator(Ptr<ValidatorBase> validator) {
    validators_.push_back(validator);
    registerTrainingObserver(validators_.back());

    const bool freshState = !state_->loaded;
    if(freshState) {
      state_->validators[validator->type()]["last-best"] = validator->initScore();
      state_->validators[validator->type()]["stalled"] = 0;
    }

    const bool isFirstValidator = (validators_.size() == 1);
    if(isFirstValidator)
      state_->validator = validator->type();
  }

  // True if there is at least one validator, a new --valid-freq period has been
  // entered, and training has not hit a stopping criterion.
  bool validating() {
    if(validators_.empty())
      return false;

    if(!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")))
      return false;

    return keepGoing();
  }

  // True if a new --save-freq period has been entered.
  bool saving() {
    const auto saveFreq = options_->get<std::string>("save-freq");
    return state_->enteredNewPeriodOf(saveFreq);
  }

  // Runs every registered validator on `graphs`, logs the outcome, and records each
  // validator's last-best score and stalled count in the training state.
  //
  // Nothing happens if the state is already marked as validated (for instance, after
  // the model is loaded), or if no new --valid-freq period was entered and `final`
  // is false. The first non-null validator is the reference one: its new best value
  // is stored in state_->validBest, and observers are notified via newStalled() when
  // its stalled count increased.
  void validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
                bool final = false) {
    if(state_->validated)
      return;
    if(!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !final)
      return;

    bool firstValidator = true;
    // Bind by const reference: iterating by value would copy a Ptr on every iteration.
    for(const auto& validator : validators_) {
      if(!validator)
        continue;

      size_t stalledPrev = validator->stalled();
      float value = validator->validate(graphs);
      if(validator->stalled() > 0) {
        LOG_VALID(info,
                  "Ep. {} : Up. {} : {} : {} : stalled {} times (last best: {})",
                  state_->epochs,
                  state_->batches,
                  validator->type(),
                  value,
                  validator->stalled(), validator->lastBest());
      } else {
        LOG_VALID(info,
                  "Ep. {} : Up. {} : {} : {} : new best",
                  state_->epochs,
                  state_->batches,
                  validator->type(),
                  value);

        if(firstValidator)
          state_->validBest = value;
      }

      // Persist this validator's results in the training state.
      state_->validators[validator->type()]["last-best"]
          = validator->lastBest();
      state_->validators[validator->type()]["stalled"] = validator->stalled();

      // notify training observers if the first validator did not improve
      if(firstValidator && validator->stalled() > stalledPrev)
        state_->newStalled(validator->stalled());
      firstValidator = false;
    }

    state_->validated = true;
  }

  // Stalled count of the first validator, or 0 if there is no (non-null) first validator.
  size_t stalled() {
    if(validators_.empty() || !validators_[0])
      return 0;
    return validators_[0]->stalled();
  }

  // Convenience overload: counts `batch` as one read batch and takes the sentence
  // and target-word counts from the batch itself.
  void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
    const size_t sentences   = batch->size();
    const size_t targetWords = batch->wordsTrg();
    update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/sentences, /*batchLabels=*/targetWords);
  }

  // @TODO: go back to function which takes batch as an argument? The current arguments make it hard to choose
  // which subbatch should be used for speed display. For sequence-classifiers it's more interesting to see the
  // source-words consumed rather than the labels.
  // Accounts for one finished update: accumulates loss and throughput counters in the
  // training state, registers the update with the state, prints the optional cluster
  // heartbeat, and logs a progress line when a new --disp-freq period was entered or
  // the update count is still within --disp-first. After a progress line the display
  // accumulators and the display timer are reset.
  // If `mpi` is given, the loss is aggregated across processes for display, and only
  // MPI rank 0 logs the progress line and prints the heartbeat.
  void update(StaticLoss rationalLoss,
              size_t numReadBatches, // number of batches read by the reader (for seeking in case of restart)
              size_t batchSize,      // total number of sentences in batch
              size_t batchLabels,    // total number of target words in batch
              Ptr<IMPIWrapper> mpi = nullptr) {

    state_->rememberPreviousProgress(); // note: epoch increases happen at the wrong place, hence -freq parameters do not support epoch units
    state_->validated = false;

    // Since batchLabels is counted across all MPI processes, we also should temporarily
    // extrapolate cost across MPI processes, to have numbers in the right range.
    // When doing the actual log, we then aggregate across MPI processes to get the accurate number.
    if (mpi)
      rationalLoss.loss *= mpi->numMPIProcesses();

    state_->costSum      += rationalLoss.loss;   // aggregate sum cost since last display
    state_->costCount    += rationalLoss.count; // cost gets normalized w.r.t. this in display

    state_->updatesDisp  += 1;
    state_->samplesDisp  += batchSize;
    state_->wordsDisp    += batchLabels;  //@TODO: this is wrong        // words at given input processed since last display, for speed display

    state_->samplesEpoch += batchSize;           // sentences processed in this epoch
    state_->labelsTotal  += rationalLoss.count; // total labels processed

    state_->newUpdate(numReadBatches);

    // progress heartbeat for MS-internal Philly compute cluster
    // This environment variable exists when running on the cluster.
    // This must run before the display block below: that block zeroes costSum and
    // costCount, so reading them afterwards would evaluate 0/0 on updates where a
    // heartbeat and a progress display coincide.
    if((!mpi || mpi->myMPIRank() == 0) && getenv("PHILLY_JOB_ID")
       && heartBeatTimer_.elapsed<std::chrono::minutes>() >= 10) {
      printf("PROGRESS: %.2f%%\nEVALERR: %.7f%%\n",
          (double)state_->epochs,
          state_->costSum / state_->costCount / (mpi ? mpi->numMPIProcesses() : 1));
      fflush(stdout);
      // std::endl already flushes the stream.
      std::cout << "MBSIZE: " << batchLabels << " after " << state_->batches << " updates = " << state_->labelsTotal << " labels" << std::endl;
      heartBeatTimer_.start();
    }

    // reconstruct sum cost, for displaying epoch-level averages instead of minibatch-level
    auto lossType = options_->get<std::string>("cost-type");
    auto dispLabelCounts = options_->get<bool>("disp-label-counts");  // if true then show as "cost per label * number of labels"

    if(state_->enteredNewPeriodOf(options_->get<std::string>("disp-freq")) ||
       state_->batches <= options_->get<size_t>("disp-first")) {
      // if MPI then aggregate precise cost across workers
      if (mpi) {
        state_->costSum /= mpi->numMPIProcesses(); // undo the extra scaling
        mpi->allReduce(&state_->costSum, &state_->costSum, 1, MPI_FLOAT, MPI_SUM);
      }

      // skip the report on alternate worker processes
      const bool reportHere = !mpi || mpi->myMPIRank() == 0;
      if(reportHere) {
        // Read the timer once so that the logged time and the words/s rate agree.
        const auto elapsed = timer_.elapsed();
        if(options_->get<bool>("lr-report")) {
          LOG(info,
              "Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s : L.r. {:.4e}",
              state_->epochs,
              state_->batches,
              utils::withCommas(state_->samplesEpoch),
              formatLoss(lossType, dispLabelCounts, batchLabels, state_),
              elapsed,
              state_->wordsDisp / elapsed,
              state_->eta);
        } else {
          LOG(info,
              "Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s",
              state_->epochs,
              state_->batches,
              utils::withCommas(state_->samplesEpoch),
              formatLoss(lossType, dispLabelCounts, 0, state_), // ignore batchLabels
              elapsed,
              state_->wordsDisp / elapsed);
        }
      }

      // Start a fresh display period.
      timer_.start();
      state_->costSum      = 0;
      state_->costCount    = 0;

      state_->updatesDisp  = 0;
      state_->samplesDisp  = 0;
      state_->wordsDisp    = 0;
    }
  }

  // Restores training progress from "<name>.progress.yml" if that file exists.
  // With --no-restore-corpus the per-epoch sample count and the display
  // accumulators are zeroed afterwards. Always ends by calling state_->newLoad().
  void load(const std::string& name) {
    const std::string progressFile = name + ".progress.yml";
    if(filesystem::exists(progressFile))
      state_->load(progressFile);

    const bool discardCorpusPosition = options_->get<bool>("no-restore-corpus");
    if(discardCorpusPosition) {
      state_->samplesEpoch = 0;
      state_->costSum      = 0;
      state_->costCount    = 0;

      state_->updatesDisp  = 0;
      state_->samplesDisp  = 0;
      state_->wordsDisp    = 0;
    }

    state_->newLoad();
  }

  void save(const std::string& name) {
    // Save config options
    YAML::Node yaml = options_->getYaml();
    std::ofstream fout(name + ".yml");
    fout << yaml;
    // Save training progress
    state_->save(name + ".progress.yml");
  }

  // Number of batch updates recorded in the training state.
  size_t numberOfBatches() {
    return state_->batches;
  }

  // Forwards `observer` to the training state's observer registration.
  void registerTrainingObserver(Ptr<TrainingObserver> observer) {
    auto& trainingState = *state_;
    trainingState.registerObserver(observer);
  }

  // Observer hook: refreshes the learning rate and, for the epoch-based
  // --lr-decay-strategy values ("epoch", "epoch+batches", "epoch+stalled"), applies
  // the --lr-decay factor once the thresholds in --lr-decay-start are reached.
  // Does nothing beyond the refresh when --lr-decay is not positive.
  void actAfterEpoch(TrainingState& state) override {
    const float decayFactor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?

    updateLearningRate(state);

    if(!(decayFactor > 0.0))
      return;

    auto strategy = options_->get<std::string>("lr-decay-strategy");
    state.reset = false;

    const bool withBatches = (strategy == "epoch+batches");
    const bool withStalled = (strategy == "epoch+stalled");

    bool decay = false;
    if(strategy == "epoch" || withBatches || withStalled) {
      // First entry of --lr-decay-start: epoch threshold (0 disables it).
      const size_t startEpoch
          = options_->get<std::vector<size_t>>("lr-decay-start").front();
      decay = startEpoch && state.epochs >= startEpoch;
    }

    if(withBatches) {
      // Second entry: batch-count threshold.
      const size_t startBatches
          = options_->get<std::vector<size_t>>("lr-decay-start")[1];
      if(startBatches && state.batches >= startBatches)
        decay = true;
    }

    if(withStalled) {
      // Second entry: threshold on the maximum stalled count.
      const size_t startStalled
          = options_->get<std::vector<size_t>>("lr-decay-start")[1];
      if(startStalled && state.maxStalled >= startStalled)
        decay = true;
    }

    if(!decay)
      return;

    state.factor *= decayFactor;
    updateLearningRate(state);
    LOG(info,
        "Decaying learning rate to {} in epoch {}",
        state.eta,
        state.epochs);

    state.reset = options_->get<bool>("lr-decay-reset-optimizer");
    if(state.reset)
      LOG(info, "Resetting optimizer statistics");

    if(options_->get<bool>("lr-decay-repeat-warmup")) {
      LOG(info, "Restarting learning rate warmup");
      auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
      state.warmupStart.n = state.getProgressIn(warmup.unit);
    }
  }

  // Observer hook: clears state.reset, refreshes the learning rate, applies the
  // "batches" decay strategy when due, and handles the warm-up restarts triggered
  // by --lr-warmup-at-reload (first call only) and --lr-warmup-cycle.
  void actAfterBatches(TrainingState& state) override {
    const float decayFactor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?
    state.reset = false;

    updateLearningRate(state);

    if(decayFactor > 0.0 && options_->get<std::string>("lr-decay-strategy") == "batches") {
      const size_t firstDecay = options_->get<std::vector<size_t>>("lr-decay-start").front();
      const size_t interval   = options_->get<size_t>("lr-decay-freq"); // note: unlike e.g. disp-freq, this is always in batches

      // Decay at the first threshold and then every `interval` batches after it.
      const bool onSchedule = firstDecay > 0 && interval > 0 && state.batches >= firstDecay
                              && ((state.batches - firstDecay) % interval == 0);
      if(onSchedule) {
        state.factor *= decayFactor;
        updateLearningRate(state);
        LOG(info,
            "Decaying learning rate to {} after {} batches",
            state.eta,
            state.batches);

        state.reset = options_->get<bool>("lr-decay-reset-optimizer");
        if(state.reset)
          LOG(info, "Resetting optimizer statistics");

        if(options_->get<bool>("lr-decay-repeat-warmup")) {
          LOG(info, "Restarting learning rate warmup");
          auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
          state.warmupStart.n = state.getProgressIn(warmup.unit);
        }
      }
    }

    // Restart the warm-up once, on the first call after construction.
    if(first_ && options_->get<bool>("lr-warmup-at-reload")) {
      LOG(info, "Restarting learning rate warmup");
      auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
      state.warmupStart.n = state.getProgressIn(warmup.unit);
    }

    // Cyclic warm-up: restart whenever a new --lr-warmup period has been entered.
    if(options_->get<bool>("lr-warmup-cycle")
       && state_->enteredNewPeriodOf(options_->get<std::string>("lr-warmup"))) {
      auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
      state.warmupStart.n = state.getProgressIn(warmup.unit);
    }

    first_ = false;
  }

  // Observer hook: clears state.reset, refreshes the learning rate, and for the
  // "stalled" decay strategy applies the --lr-decay factor each time the stalled
  // count is a non-zero multiple of the first --lr-decay-start value.
  void actAfterStalled(TrainingState& state) override {
    const float decayFactor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?
    state.reset = false;

    updateLearningRate(state);

    if(!(decayFactor > 0.0))
      return;
    if(options_->get<std::string>("lr-decay-strategy") != "stalled")
      return;

    const size_t everyStalled
        = options_->get<std::vector<size_t>>("lr-decay-start").front();
    const bool due = everyStalled && state.stalled && state.stalled % everyStalled == 0;
    if(!due)
      return;

    state.factor *= decayFactor;
    updateLearningRate(state);
    LOG(info,
        "Decaying learning rate to {} after having stalled {} time(s)",
        state.eta,
        state.stalled);

    state.reset = options_->get<bool>("lr-decay-reset-optimizer");
    if(state.reset)
      LOG(info, "Resetting optimizer statistics");

    if(options_->get<bool>("lr-decay-repeat-warmup")) {
      LOG(info, "Restarting learning rate warmup");
      auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
      state.warmupStart.n = state.getProgressIn(warmup.unit);
    }
  }
};
}  // namespace marian