author     Roman Grundkiewicz <rogrundk@microsoft.com>   2021-04-26 14:51:43 +0300
committer  Roman Grundkiewicz <rogrundk@microsoft.com>   2021-04-26 14:51:43 +0300
commit     49e379bba5c77c1b80927b7f0db5603e171a1903 (patch)
tree       7cd89540ba86333eff3d4f07a3ab0b8ef3db324f
parent     3e51ff387232f1096e9560980f0115ac734224f5 (diff)
Merged PR 18612: Early stopping on first, all, or any validation metrics
Adds `--early-stopping-on first|all|any`, allowing the user to decide whether early stopping should take into account only the first, all, or any validation metric.
Feature request: https://github.com/marian-nmt/marian-dev/issues/850
Regression tests: https://github.com/marian-nmt/marian-regression-tests/pull/79
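For illustration, a minimal command-line sketch of the new option; the file names and the choice of metrics below are placeholders, not part of this commit:

```
# Train with two validation metrics; stop once ANY of them has failed to
# improve for 5 consecutive validation runs. The default for
# --early-stopping-on is "first", which preserves the old behaviour.
./marian --train-sets corpus.src corpus.trg \
         --valid-sets dev.src dev.trg \
         --valid-metrics cross-entropy bleu \
         --early-stopping 5 \
         --early-stopping-on any
```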
-rw-r--r--  CHANGELOG.md                    |  1
m---------  regression-tests                |  0
-rw-r--r--  src/common/config_parser.cpp    | 14
-rw-r--r--  src/common/config_validator.cpp |  9
-rw-r--r--  src/training/scheduler.h        | 69
-rw-r--r--  src/training/training_state.h   |  4
-rw-r--r--  src/training/validator.cpp      | 14
-rw-r--r--  src/training/validator.h        |  2

8 files changed, 73 insertions, 40 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 752847e1..7f41b8d1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
 ### Added
+- Early stopping based on first, all, or any validation metrics via `--early-stopping-on`
 - Support for RMSNorm as drop-in replace for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
 - Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
 - Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off
diff --git a/regression-tests b/regression-tests
-Subproject 7d612ca5e4b27a76f92584dad76d240e34f216d
+Subproject 1afd4eb1014ac451c6a3d6f9b5d34c322902e62
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 6495db0e..f29b3630 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -244,7 +244,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
       "Tie all embedding layers and output layer");
   cli.add<bool>("--output-omit-bias",
       "Do not use a bias vector in decoder output layer");
-  
+
   // Transformer options
   cli.add<int>("--transformer-heads",
       "Number of heads in multi-head attention (transformer)",
@@ -529,13 +529,13 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
       "Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
       "After this many updates about 90% of the mass of the exponential average comes from these updates",
       100);
-  cli.add<std::vector<std::string>>("--dynamic-gradient-scaling", 
+  cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
      "Re-scale gradient to have average gradient norm if (log) gradient norm diverges from average by arg1 sigmas. "
      "If arg2 = \"log\" the statistics are recorded for the log of the gradient norm else use plain norm")
    ->implicit_val("2.f log");
-  cli.add<bool>("--check-gradient-nan", 
+  cli.add<bool>("--check-gradient-nan",
      "Skip parameter update in case of NaNs in gradient");
-  cli.add<bool>("--normalize-gradient", 
+  cli.add<bool>("--normalize-gradient",
      "Normalize gradient by multiplying with no. devices / total labels (not recommended and to be removed in the future)");
 
   cli.add<std::vector<std::string>>("--train-embedder-rank",
@@ -574,6 +574,10 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
   cli.add<size_t>("--early-stopping",
      "Stop if the first validation metric does not improve for arg consecutive validation steps",
      10);
+  cli.add<std::string>("--early-stopping-on",
+     "Decide if early stopping should take into account first, all, or any validation metrics. "
+     "Possible values: first, all, any",
+     "first");
 
   // decoding options
   cli.add<size_t>("--beam-size,-b",
@@ -586,7 +590,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
      "Maximum target length as source length times factor",
      3);
   cli.add<float>("--word-penalty",
-     "Subtract (arg * translation length) from translation score ");
+     "Subtract (arg * translation length) from translation score");
   cli.add<bool>("--allow-unk",
      "Allow unknown words to appear in output");
   cli.add<bool>("--n-best",
diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp
index b2400145..fea7578f 100644
--- a/src/common/config_validator.cpp
+++ b/src/common/config_validator.cpp
@@ -4,6 +4,8 @@
 #include "common/utils.h"
 #include "common/filesystem.h"
 
+#include <set>
+
 namespace marian {
 
 bool ConfigValidator::has(const std::string& key) const {
@@ -129,6 +131,11 @@ void ConfigValidator::validateOptionsTraining() const {
             && !get<std::vector<std::string>>("valid-sets").empty(),
           errorMsg);
 
+  // check if --early-stopping-on has proper value
+  std::set<std::string> supportedStops = {"first", "all", "any"};
+  ABORT_IF(supportedStops.find(get<std::string>("early-stopping-on")) == supportedStops.end(),
+           "Supported options for --early-stopping-on are: first, all, any");
+
   // validations for learning rate decaying
   ABORT_IF(get<float>("lr-decay") > 1.f,
            "Learning rate decay factor greater than 1.0 is unusual");
@@ -145,7 +152,7 @@ void ConfigValidator::validateOptionsTraining() const {
   // validate ULR options
   ABORT_IF((has("ulr") && get<bool>("ulr")
             && (get<std::string>("ulr-query-vectors") == "" || get<std::string>("ulr-keys-vectors") == "")),
-           "ULR enablign requires query and keys vectors specified with --ulr-query-vectors and "
+           "ULR requires query and keys vectors specified with --ulr-query-vectors and "
           "--ulr-keys-vectors option");
 
   // validate model quantization
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index 9d2500f9..8d4fa30c 100644
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
@@ -28,7 +28,7 @@ private:
   // (regardless if it's the 1st or nth epoch and if it's a new or continued training),
   // which indicates the end of the training data stream from STDIN
   bool endOfStdin_{false}; // true at the end of the epoch if training from STDIN;
-  
+
   // @TODO: figure out how to compute this with regard to updates as well, although maybe harder since no final value
   // determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
   float getScheduledLRDecayFactor(const TrainingState& state) const {
@@ -133,7 +133,7 @@ public:
   Scheduler(Ptr<Options> options, Ptr<TrainingState> state, Ptr<IMPIWrapper> mpi = nullptr)
       : options_(options), state_(state), mpi_(mpi),
         gradientNormAvgWindow_(options_->get<size_t>("gradient-norm-average-window", 100)) {
-  
+
     // parse logical-epoch parameters
     auto logicalEpochStr = options->get<std::vector<std::string>>("logical-epoch", {"1e", "0"});
     ABORT_IF(logicalEpochStr.empty(), "Logical epoch information is missing?");
@@ -174,7 +174,7 @@
       size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
       auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
       // if unit is labels, then account for the fact that our increment itself is not constant
-#if 1 // this seems to hurt convergence quite a bit compared to when updates is used 
+#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
       if (mbWarmup.unit == SchedulingUnit::trgLabels)
         progressRatio = std::sqrt(progressRatio);
 #endif
@@ -207,7 +207,7 @@ public:
     if(saveAndExitRequested()) // via SIGTERM
       return false;
 
-#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches 
+#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
     // stop if it reached the maximum number of epochs
     size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
     if(stopAfterEpochs > 0 && calculateLogicalEpoch() > stopAfterEpochs)
@@ -231,10 +231,9 @@
      }
    }
 
-    // stop if the first validator did not improve for a given number of checks
+    // stop if the first/all/any validators did not improve for a given number of checks
     size_t stopAfterStalled = options_->get<size_t>("early-stopping");
-    if(stopAfterStalled > 0 && !validators_.empty()
-       && stalled() >= stopAfterStalled)
+    if(stopAfterStalled > 0 && stalled() >= stopAfterStalled)
      return false;
 
     // stop if data streaming from STDIN is stopped
@@ -297,12 +296,11 @@
        || (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
      return;
 
-    bool firstValidator = true;
+    size_t stalledPrev = stalled();
     for(auto validator : validators_) {
       if(!validator)
         continue;
 
-      size_t stalledPrev = validator->stalled();
       float value = 0;
       if(!mpi_ || mpi_->isMainProcess()) {
         // We run validation only in the main process, but this is risky with MPI.
@@ -330,34 +328,60 @@
       if(mpi_) {
         // collect and broadcast validation result to all processes and bring validator up-to-date
         mpi_->bCast(&value, 1, IMPIWrapper::getDataType(&value));
-        
+
         // @TODO: add function to validator?
         mpi_->bCast(&validator->stalled(), 1, IMPIWrapper::getDataType(&validator->stalled()));
         mpi_->bCast(&validator->lastBest(), 1, IMPIWrapper::getDataType(&validator->lastBest()));
       }
 
-      if(firstValidator)
-        state_->validBest = value;
-
       state_->validators[validator->type()]["last-best"] = validator->lastBest();
       state_->validators[validator->type()]["stalled"] = validator->stalled();
-
-      // notify training observers if the first validator did not improve
-      if(firstValidator && validator->stalled() > stalledPrev)
-        state_->newStalled(validator->stalled());
-      firstValidator = false;
     }
 
+    // notify training observers about stalled validation
+    size_t stalledNew = stalled();
+    if(stalledNew > stalledPrev)
+      state_->newStalled(stalledNew);
+
     state_->validated = true;
   }
 
+  // Returns the proper number of stalled validations w.r.t. early-stopping-on
   size_t stalled() {
+    std::string stopOn = options_->get<std::string>("early-stopping-on");
+    if(stopOn == "any")
+      return stalledMax();
+    if(stopOn == "all")
+      return stalledMin();
+    return stalled1st();
+  }
+
+  // Returns the number of stalled validations for the first validator
+  size_t stalled1st() {
    if(!validators_.empty())
      if(validators_[0])
        return validators_[0]->stalled();
    return 0;
  }
 
+  // Returns the largest number of stalled validations across validators or 0 if there are no validators
+  size_t stalledMax() {
+    size_t max = 0;
+    for(auto validator : validators_)
+      if(validator && validator->stalled() > max)
+        max = validator->stalled();
+    return max;
+  }
+
+  // Returns the lowest number of stalled validations across validators or 0 if there are no validators
+  size_t stalledMin() {
+    size_t min = std::numeric_limits<std::size_t>::max();
+    for(auto validator : validators_)
+      if(validator && validator->stalled() < min)
+        min = validator->stalled();
+    return min == std::numeric_limits<std::size_t>::max() ? 0 : min;
+  }
+
   void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
     update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg(), /*gradientNorm=*/0.f);
   }
@@ -397,8 +421,8 @@
 
     if(gradientNorm) {
       size_t range = std::min(gradientNormAvgWindow_, state_->batches);
-      float alpha = 2.f / (float)(range + 1); 
-      
+      float alpha = 2.f / (float)(range + 1);
+
       float delta = gradientNorm - state_->gradientNormAvg;
       state_->gradientNormAvg = state_->gradientNormAvg + alpha * delta;
       state_->gradientNormVar = (1.0f - alpha) * (state_->gradientNormVar + alpha * delta * delta);
@@ -440,7 +464,7 @@
              formatLogicalEpoch(),
              state_->batches,
              utils::withCommas(state_->samplesEpoch),
-             formatLoss(lossType, dispLabelCounts, batchLabels, state_), 
+             formatLoss(lossType, dispLabelCounts, batchLabels, state_),
              timer_.elapsed(),
              state_->wordsDisp / timer_.elapsed(),
              state_->gradientNormAvg);
@@ -627,7 +651,8 @@
 
      if(options_->get<bool>("lr-decay-repeat-warmup")) {
        LOG(info, "Restarting learning rate warmup");
-        state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
+        state.warmupStart.n = state.getProgressIn(
+            SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
      }
    }
  }
diff --git a/src/training/training_state.h b/src/training/training_state.h
index 7d62f060..e0c1ba5d 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -43,8 +43,6 @@ public:
   size_t stalled{0};
   // The largest number of stalled validations so far
   size_t maxStalled{0};
-  // Last best validation score
-  float validBest{0.f};
   std::string validator;
   // List of validators
   YAML::Node validators;
@@ -217,7 +215,6 @@
 
    stalled = config["stalled"].as<size_t>();
    maxStalled = config["stalled-max"].as<size_t>();
-    validBest = config["valid-best"].as<float>();
    validator = config["validator"].as<std::string>();
    validators = config["validators"];
    reset = config["reset"].as<bool>();
@@ -259,7 +256,6 @@
 
    config["stalled"] = stalled;
    config["stalled-max"] = maxStalled;
-    config["valid-best"] = validBest;
    config["validator"] = validator;
    config["validators"] = validators;
    config["reset"] = reset;
diff --git a/src/training/validator.cpp b/src/training/validator.cpp
index d824052f..ef1bac3d 100644
--- a/src/training/validator.cpp
+++ b/src/training/validator.cpp
@@ -447,7 +447,7 @@ SacreBleuValidator::SacreBleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, const std::string& metric)
   ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids");
 
   // should not really happen, but let's check.
-  if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF, 
+  if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
     order_ = 6;    // we compute stats over character ngrams up to length 6
 
   // @TODO: remove, only used for saving?
@@ -613,12 +613,12 @@ void SacreBleuValidator::updateStats(std::vector<float>& stats,
   LOG_VALID_ONCE(info, "First sentence's tokens as scored:");
   LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false)));
   LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false)));
-  
+
   if(useWordIds_)
     updateStats(stats, cand, ref);
   else
     updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false));
-  
+
 }
 
 // Re-implementation of BLEU metric from SacreBLEU
@@ -627,7 +627,7 @@ float SacreBleuValidator::calcBLEU(const std::vector<float>& stats) {
   for(int i = 0; i < order_; ++i) {
     float commonNgrams = stats[statsPerOrder * i + 0];
     float hypothesesNgrams = stats[statsPerOrder * i + 1];
-    
+
     if(commonNgrams == 0.f)
       return 0.f;
     logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams);
@@ -653,7 +653,7 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
     float commonNgrams = stats[statsPerOrder * i + 0];
     float hypothesesNgrams = stats[statsPerOrder * i + 1];
     float referencesNgrams = stats[statsPerOrder * i + 2];
-    
+
     if(hypothesesNgrams > 0 && referencesNgrams > 0) {
       avgPrecision += commonNgrams / hypothesesNgrams;
       avgRecall += commonNgrams / referencesNgrams;
@@ -666,10 +666,10 @@
 
   avgPrecision /= effectiveOrder;
   avgRecall /= effectiveOrder;
-  
+
   if(avgPrecision + avgRecall == 0.f)
     return 0.f;
-  
+
   auto betaSquare = beta * beta;
   auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall);
   return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU
diff --git a/src/training/validator.h b/src/training/validator.h
index d6e64d69..16bfd245 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -352,7 +352,7 @@ protected:
 private:
   const std::string metric_; // allowed values are: bleu, bleu-detok (same as bleu), bleu-segmented, chrf
   bool computeChrF_{ false }; // should we compute ChrF instead of BLEU (BLEU by default)?
-  
+
   size_t order_{ 4 }; // 4-grams for BLEU by default
   static const size_t statsPerOrder = 3; // 0: common ngrams, 1: candidate ngrams, 2: reference ngrams
   bool useWordIds_{ false }; // compute BLEU score by matching numeric segment ids
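A note on the semantics added in `scheduler.h` above: `any` maps to the largest stalled count across validators (training stops as soon as any single metric has stalled long enough), `all` maps to the smallest (training stops only once every metric has stalled), and `first` keeps the previous behaviour. Below is a self-contained C++ sketch of that mapping; it is not part of the commit, and the stalled counts are made up for illustration.

```cpp
// Standalone sketch mirroring the first/all/any selection in Scheduler::stalled().
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

size_t stalled(const std::string& stopOn, const std::vector<size_t>& counts) {
  if(counts.empty())
    return 0;
  if(stopOn == "any")  // stop as soon as ANY metric has stalled long enough -> max
    return *std::max_element(counts.begin(), counts.end());
  if(stopOn == "all")  // stop only when ALL metrics have stalled -> min
    return *std::min_element(counts.begin(), counts.end());
  return counts.front();  // "first": consider the first validation metric only
}

int main() {
  // e.g. cross-entropy stalled 3 times, BLEU stalled 7 times
  std::vector<size_t> counts = {3, 7};
  std::cout << stalled("first", counts) << "\n";  // 3
  std::cout << stalled("any", counts) << "\n";    // 7: reaches --early-stopping soonest
  std::cout << stalled("all", counts) << "\n";    // 3: waits for every metric
}
```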